mirror of
https://github.com/tiennm99/zfoo.git
synced 2026-05-21 04:25:00 +00:00
perf[go]: 简化golang的zfoo服务器
This commit is contained in:
@@ -1,153 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2020 The zfoo Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
|
||||
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
|
||||
package znet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Conn wrap net.Conn
|
||||
type Conn struct {
|
||||
sid string
|
||||
rawConn net.Conn
|
||||
sendCh chan []byte
|
||||
messageCh chan any
|
||||
done chan error
|
||||
hbTimer *time.Timer
|
||||
name string
|
||||
hbInterval time.Duration
|
||||
hbTimeout time.Duration
|
||||
}
|
||||
|
||||
// GetName Get conn name
|
||||
func (c *Conn) GetName() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// NewConn create new conn
|
||||
func NewConn(c net.Conn, hbInterval time.Duration, hbTimeout time.Duration) *Conn {
|
||||
conn := &Conn{
|
||||
rawConn: c,
|
||||
sendCh: make(chan []byte, 100),
|
||||
done: make(chan error),
|
||||
messageCh: make(chan any, 100),
|
||||
hbInterval: hbInterval,
|
||||
hbTimeout: hbTimeout,
|
||||
}
|
||||
|
||||
conn.name = c.RemoteAddr().String()
|
||||
conn.hbTimer = time.NewTimer(conn.hbInterval)
|
||||
|
||||
if conn.hbInterval == 0 {
|
||||
conn.hbTimer.Stop()
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// Close close connection
|
||||
func (c *Conn) Close() {
|
||||
c.hbTimer.Stop()
|
||||
c.rawConn.Close()
|
||||
}
|
||||
|
||||
// SendMessage send message
|
||||
func (c *Conn) SendMessage(msg *Packet) error {
|
||||
var buffer = Encode(msg)
|
||||
c.sendCh <- buffer.ToBytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCoroutine write coroutine
|
||||
func (c *Conn) writeCoroutine(ctx context.Context) {
|
||||
hbData := make([]byte, 0)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case pkt := <-c.sendCh:
|
||||
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := c.rawConn.Write(pkt); err != nil {
|
||||
c.done <- err
|
||||
}
|
||||
|
||||
case <-c.hbTimer.C:
|
||||
hbMessage := NewMessage(MsgHeartbeat, hbData)
|
||||
c.SendMessage(hbMessage)
|
||||
// 设置心跳timer
|
||||
if c.hbInterval > 0 {
|
||||
c.hbTimer.Reset(c.hbInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readCoroutine read coroutine
|
||||
func (c *Conn) readCoroutine(ctx context.Context) {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
default:
|
||||
// 设置超时
|
||||
if c.hbInterval > 0 {
|
||||
err := c.rawConn.SetReadDeadline(time.Now().Add(c.hbTimeout))
|
||||
if err != nil {
|
||||
c.done <- err
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 读取长度
|
||||
buf := make([]byte, 4)
|
||||
_, err := io.ReadFull(c.rawConn, buf)
|
||||
if err != nil {
|
||||
c.done <- err
|
||||
continue
|
||||
}
|
||||
|
||||
bufReader := bytes.NewReader(buf)
|
||||
|
||||
var dataSize int32
|
||||
err = binary.Read(bufReader, binary.BigEndian, &dataSize)
|
||||
if err != nil {
|
||||
c.done <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var bytes = make([]byte, dataSize)
|
||||
_, err = io.ReadFull(c.rawConn, bytes)
|
||||
if err != nil {
|
||||
c.done <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// 解码
|
||||
var packet = Decode(bytes)
|
||||
c.messageCh <- packet
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2020 The zfoo Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
|
||||
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
|
||||
package znet
|
||||
|
||||
const (
|
||||
// STUnknown Unknown
|
||||
STUnknown = iota
|
||||
// STInited Inited
|
||||
STInited
|
||||
// STRunning Running
|
||||
STRunning
|
||||
// STStop Stop
|
||||
STStop
|
||||
)
|
||||
|
||||
const (
|
||||
// MsgHeartbeat heartbeat
|
||||
MsgHeartbeat = iota
|
||||
)
|
||||
@@ -1,38 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2020 The zfoo Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
|
||||
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
package znet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Packet struct
|
||||
type Packet struct {
|
||||
length int32
|
||||
protocolId int16
|
||||
data []byte
|
||||
}
|
||||
|
||||
// NewMessage create a new message
|
||||
func NewMessage(protocolId int16, data []byte) *Packet {
|
||||
msg := &Packet{
|
||||
length: int32(len(data)) + 2 + 4,
|
||||
protocolId: protocolId,
|
||||
data: data,
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
|
||||
func (msg *Packet) String() string {
|
||||
return fmt.Sprintf("Size=%d ID=%d DataLen=%d", msg.length, msg.protocolId, len(msg.data))
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
/*
|
||||
* Copyright (C) 2020 The zfoo Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
|
||||
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
package znet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Server struct
|
||||
type Server struct {
|
||||
onMessage func(*Session, any)
|
||||
onConnect func(*Session)
|
||||
onDisconnect func(*Session, error)
|
||||
sessions *sync.Map
|
||||
address string
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
// NewServer create a new socket service
|
||||
func NewServer(addr string) *Server {
|
||||
listen, _ := net.Listen("tcp", addr)
|
||||
server := &Server{
|
||||
sessions: &sync.Map{},
|
||||
address: addr,
|
||||
listener: listen,
|
||||
}
|
||||
return server
|
||||
}
|
||||
|
||||
|
||||
// Start Start socket service
|
||||
func (s *Server) Start() {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
defer func() {
|
||||
cancel()
|
||||
s.listener.Close()
|
||||
}()
|
||||
|
||||
s.acceptHandler(ctx)
|
||||
}
|
||||
|
||||
func (s *Server) acceptHandler(ctx context.Context) {
|
||||
for {
|
||||
conn, _ := s.listener.Accept()
|
||||
go s.connectHandler(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) connectHandler(ctx context.Context, c net.Conn) {
|
||||
var session = NewSession(c)
|
||||
s.sessions.Store(session.sid, session)
|
||||
|
||||
connctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
defer func() {
|
||||
cancel()
|
||||
session.Close()
|
||||
s.sessions.Delete(session.sid)
|
||||
}()
|
||||
|
||||
go session.readCoroutine(connctx)
|
||||
go session.writeCoroutine(connctx)
|
||||
|
||||
if s.onConnect != nil {
|
||||
s.onConnect(session)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-session.done:
|
||||
|
||||
if s.onDisconnect != nil {
|
||||
s.onDisconnect(session, err)
|
||||
}
|
||||
return
|
||||
|
||||
case packet := <-session.messageCh:
|
||||
if s.onMessage != nil {
|
||||
s.onMessage(session, packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2020 The zfoo Authors
|
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
|
||||
* in compliance with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
|
||||
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and limitations under the License.
|
||||
*/
|
||||
package znet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SocketService struct
|
||||
type SocketService struct {
|
||||
onMessage func(*Session, any)
|
||||
onConnect func(*Session)
|
||||
onDisconnect func(*Session, error)
|
||||
sessions *sync.Map
|
||||
hbInterval time.Duration
|
||||
hbTimeout time.Duration
|
||||
laddr string
|
||||
status int
|
||||
listener net.Listener
|
||||
stopCh chan error
|
||||
}
|
||||
|
||||
// Server create a new socket service
|
||||
func Server(laddr string) *SocketService {
|
||||
|
||||
listen, _ := net.Listen("tcp", laddr)
|
||||
|
||||
server := &SocketService{
|
||||
sessions: &sync.Map{},
|
||||
stopCh: make(chan error),
|
||||
hbInterval: 0 * time.Second,
|
||||
hbTimeout: 0 * time.Second,
|
||||
laddr: laddr,
|
||||
status: STInited,
|
||||
listener: listen,
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// RegMessageHandler register message handler
|
||||
func (s *SocketService) RegMessageHandler(handler func(*Session, any)) {
|
||||
s.onMessage = handler
|
||||
}
|
||||
|
||||
// RegConnectHandler register connect handler
|
||||
func (s *SocketService) RegConnectHandler(handler func(*Session)) {
|
||||
s.onConnect = handler
|
||||
}
|
||||
|
||||
// RegDisconnectHandler register disconnect handler
|
||||
func (s *SocketService) RegDisconnectHandler(handler func(*Session, error)) {
|
||||
s.onDisconnect = handler
|
||||
}
|
||||
|
||||
// Start Start socket service
|
||||
func (s *SocketService) Start() {
|
||||
|
||||
s.status = STRunning
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
defer func() {
|
||||
s.status = STStop
|
||||
cancel()
|
||||
s.listener.Close()
|
||||
}()
|
||||
|
||||
go s.acceptHandler(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocketService) acceptHandler(ctx context.Context) {
|
||||
for {
|
||||
c, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
s.stopCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
go s.connectHandler(ctx, c)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocketService) connectHandler(ctx context.Context, c net.Conn) {
|
||||
conn := NewConn(c, s.hbInterval, s.hbTimeout)
|
||||
session := NewSession(conn)
|
||||
s.sessions.Store(session.sid, session)
|
||||
|
||||
connctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
defer func() {
|
||||
cancel()
|
||||
conn.Close()
|
||||
s.sessions.Delete(session.sid)
|
||||
}()
|
||||
|
||||
go conn.readCoroutine(connctx)
|
||||
go conn.writeCoroutine(connctx)
|
||||
|
||||
if s.onConnect != nil {
|
||||
s.onConnect(session)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-conn.done:
|
||||
|
||||
if s.onDisconnect != nil {
|
||||
s.onDisconnect(session, err)
|
||||
}
|
||||
return
|
||||
|
||||
case packet := <-conn.messageCh:
|
||||
if s.onMessage != nil {
|
||||
s.onMessage(session, packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus get socket service status
|
||||
func (s *SocketService) GetStatus() int {
|
||||
return s.status
|
||||
}
|
||||
|
||||
// Stop stop socket service with reason
|
||||
func (s *SocketService) Stop(reason string) {
|
||||
s.stopCh <- errors.New(reason)
|
||||
}
|
||||
|
||||
// SetHeartBeat set heart beat
|
||||
func (s *SocketService) SetHeartBeat(hbInterval time.Duration, hbTimeout time.Duration) error {
|
||||
if s.status == STRunning {
|
||||
return errors.New("Can't set heart beat on service running")
|
||||
}
|
||||
|
||||
s.hbInterval = hbInterval
|
||||
s.hbTimeout = hbTimeout
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConnsCount get connect count
|
||||
func (s *SocketService) GetConnsCount() int {
|
||||
var count int
|
||||
s.sessions.Range(func(k, v interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
// Unicast Unicast with session ID
|
||||
func (s *SocketService) Unicast(sid string, msg *Packet) {
|
||||
v, ok := s.sessions.Load(sid)
|
||||
if ok {
|
||||
session := v.(*Session)
|
||||
err := session.conn.SendMessage(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast Broadcast to all connections
|
||||
func (s *SocketService) Broadcast(msg *Packet) {
|
||||
s.sessions.Range(func(k, v interface{}) bool {
|
||||
s := v.(*Session)
|
||||
if err := s.conn.SendMessage(msg); err != nil {
|
||||
// log.Println(err)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -11,26 +11,116 @@
|
||||
*/
|
||||
package znet
|
||||
|
||||
import "sync/atomic"
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Session struct
|
||||
type Session struct {
|
||||
sid uint64
|
||||
uid uint64
|
||||
conn *Conn
|
||||
|
||||
rawConn net.Conn
|
||||
sendCh chan []byte
|
||||
messageCh chan any
|
||||
done chan error
|
||||
}
|
||||
|
||||
var uuid uint64
|
||||
|
||||
// NewSession create a new session
|
||||
func NewSession(conn *Conn) *Session {
|
||||
func NewSession(conn net.Conn) *Session {
|
||||
var suuid = atomic.AddUint64(&uuid, 1)
|
||||
|
||||
session := &Session{
|
||||
sid: suuid,
|
||||
uid: 0,// 可以为用户的id
|
||||
conn: conn,
|
||||
uid: 0, // 可以为用户的id
|
||||
|
||||
rawConn: conn,
|
||||
sendCh: make(chan []byte, 100),
|
||||
done: make(chan error),
|
||||
messageCh: make(chan any, 100),
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Close close connection
|
||||
func (session *Session) Close() {
|
||||
session.rawConn.Close()
|
||||
}
|
||||
|
||||
// SendMessage send message
|
||||
func (session *Session) SendMessage(msg any) error {
|
||||
var buffer = Encode(msg)
|
||||
session.sendCh <- buffer.ToBytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeCoroutine write coroutine
|
||||
func (session *Session) writeCoroutine(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case pkt := <-session.sendCh:
|
||||
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := session.rawConn.Write(pkt); err != nil {
|
||||
session.done <- err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readCoroutine read coroutine
|
||||
func (session *Session) readCoroutine(ctx context.Context) {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
default:
|
||||
// 读取长度
|
||||
buf := make([]byte, 4)
|
||||
_, err := io.ReadFull(session.rawConn, buf)
|
||||
if err != nil {
|
||||
session.done <- err
|
||||
continue
|
||||
}
|
||||
|
||||
bufReader := bytes.NewReader(buf)
|
||||
|
||||
var dataSize int32
|
||||
err = binary.Read(bufReader, binary.BigEndian, &dataSize)
|
||||
if err != nil {
|
||||
session.done <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// 读取数据
|
||||
var bytes = make([]byte, dataSize)
|
||||
_, err = io.ReadFull(session.rawConn, bytes)
|
||||
if err != nil {
|
||||
session.done <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// 解码
|
||||
var packet = Decode(bytes)
|
||||
session.messageCh <- packet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,34 +19,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
host := "127.0.0.1:9000"
|
||||
func TestServer(t *testing.T) {
|
||||
var host = "127.0.0.1:9000"
|
||||
|
||||
server, _ := Server(host)
|
||||
server.RegMessageHandler(HandleMessage)
|
||||
server.RegConnectHandler(HandleConnect)
|
||||
server.RegDisconnectHandler(HandleDisconnect)
|
||||
var server = NewServer(host)
|
||||
server.onMessage = HandleMessage
|
||||
server.onConnect = HandleConnect
|
||||
server.onDisconnect = HandleDisconnect
|
||||
|
||||
server.Start()
|
||||
|
||||
// clientTest()
|
||||
}
|
||||
|
||||
|
||||
func HandleMessage(s *Session, packet any) {
|
||||
fmt.Println("receive packet")
|
||||
fmt.Println(packet)
|
||||
}
|
||||
|
||||
func HandleDisconnect(s *Session, err error) {
|
||||
fmt.Println(s.conn.GetName() + " lost.")
|
||||
}
|
||||
|
||||
func HandleConnect(s *Session) {
|
||||
fmt.Println(s.conn.GetName() + " connected.")
|
||||
}
|
||||
|
||||
func clientTest() {
|
||||
func TestClient(t *testing.T) {
|
||||
host := "127.0.0.1:9000"
|
||||
tcpAddr, _ := net.ResolveTCPAddr("tcp", host)
|
||||
|
||||
@@ -62,3 +46,21 @@ func clientTest() {
|
||||
|
||||
time.Sleep(time.Millisecond * 5000)
|
||||
}
|
||||
|
||||
func HandleMessage(session *Session, packet any) {
|
||||
fmt.Println("receive packet")
|
||||
fmt.Println(packet)
|
||||
|
||||
session.SendMessage(packet)
|
||||
}
|
||||
|
||||
func HandleDisconnect(session *Session, err error) {
|
||||
fmt.Println("disconnect")
|
||||
fmt.Println(session.sid)
|
||||
}
|
||||
|
||||
func HandleConnect(session *Session) {
|
||||
fmt.Println("connected.")
|
||||
fmt.Println(session.sid)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user