perf[go]: 简化golang的zfoo服务器

This commit is contained in:
godotg
2022-09-15 22:44:52 +08:00
parent de37816d5c
commit 0c4029ffad
7 changed files with 216 additions and 442 deletions
-153
View File
@@ -1,153 +0,0 @@
/*
* Copyright (C) 2020 The zfoo Authors
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package znet
import (
"bytes"
"context"
"encoding/binary"
"io"
"net"
"time"
)
// Conn wrap net.Conn
type Conn struct {
sid string
rawConn net.Conn
sendCh chan []byte
messageCh chan any
done chan error
hbTimer *time.Timer
name string
hbInterval time.Duration
hbTimeout time.Duration
}
// GetName Get conn name
func (c *Conn) GetName() string {
return c.name
}
// NewConn create new conn
func NewConn(c net.Conn, hbInterval time.Duration, hbTimeout time.Duration) *Conn {
conn := &Conn{
rawConn: c,
sendCh: make(chan []byte, 100),
done: make(chan error),
messageCh: make(chan any, 100),
hbInterval: hbInterval,
hbTimeout: hbTimeout,
}
conn.name = c.RemoteAddr().String()
conn.hbTimer = time.NewTimer(conn.hbInterval)
if conn.hbInterval == 0 {
conn.hbTimer.Stop()
}
return conn
}
// Close close connection
func (c *Conn) Close() {
c.hbTimer.Stop()
c.rawConn.Close()
}
// SendMessage send message
func (c *Conn) SendMessage(msg *Packet) error {
var buffer = Encode(msg)
c.sendCh <- buffer.ToBytes()
return nil
}
// writeCoroutine write coroutine
func (c *Conn) writeCoroutine(ctx context.Context) {
hbData := make([]byte, 0)
for {
select {
case <-ctx.Done():
return
case pkt := <-c.sendCh:
if pkt == nil {
continue
}
if _, err := c.rawConn.Write(pkt); err != nil {
c.done <- err
}
case <-c.hbTimer.C:
hbMessage := NewMessage(MsgHeartbeat, hbData)
c.SendMessage(hbMessage)
// 设置心跳timer
if c.hbInterval > 0 {
c.hbTimer.Reset(c.hbInterval)
}
}
}
}
// readCoroutine read coroutine
func (c *Conn) readCoroutine(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
default:
// 设置超时
if c.hbInterval > 0 {
err := c.rawConn.SetReadDeadline(time.Now().Add(c.hbTimeout))
if err != nil {
c.done <- err
continue
}
}
// 读取长度
buf := make([]byte, 4)
_, err := io.ReadFull(c.rawConn, buf)
if err != nil {
c.done <- err
continue
}
bufReader := bytes.NewReader(buf)
var dataSize int32
err = binary.Read(bufReader, binary.BigEndian, &dataSize)
if err != nil {
c.done <- err
continue
}
// 读取数据
var bytes = make([]byte, dataSize)
_, err = io.ReadFull(c.rawConn, bytes)
if err != nil {
c.done <- err
continue
}
// 解码
var packet = Decode(bytes)
c.messageCh <- packet
}
}
}
-29
View File
@@ -1,29 +0,0 @@
/*
* Copyright (C) 2020 The zfoo Authors
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package znet
const (
// STUnknown Unknown
STUnknown = iota
// STInited Inited
STInited
// STRunning Running
STRunning
// STStop Stop
STStop
)
const (
// MsgHeartbeat heartbeat
MsgHeartbeat = iota
)
-38
View File
@@ -1,38 +0,0 @@
/*
* Copyright (C) 2020 The zfoo Authors
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package znet
import (
"fmt"
)
// Packet struct
type Packet struct {
length int32
protocolId int16
data []byte
}
// NewMessage create a new message
func NewMessage(protocolId int16, data []byte) *Packet {
msg := &Packet{
length: int32(len(data)) + 2 + 4,
protocolId: protocolId,
data: data,
}
return msg
}
func (msg *Packet) String() string {
return fmt.Sprintf("Size=%d ID=%d DataLen=%d", msg.length, msg.protocolId, len(msg.data))
}
+96
View File
@@ -0,0 +1,96 @@
/*
* Copyright (C) 2020 The zfoo Authors
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package znet
import (
"context"
"net"
"sync"
)
// Server struct
type Server struct {
onMessage func(*Session, any)
onConnect func(*Session)
onDisconnect func(*Session, error)
sessions *sync.Map
address string
listener net.Listener
}
// NewServer create a new socket service
func NewServer(addr string) *Server {
listen, _ := net.Listen("tcp", addr)
server := &Server{
sessions: &sync.Map{},
address: addr,
listener: listen,
}
return server
}
// Start Start socket service
func (s *Server) Start() {
ctx, cancel := context.WithCancel(context.Background())
defer func() {
cancel()
s.listener.Close()
}()
s.acceptHandler(ctx)
}
func (s *Server) acceptHandler(ctx context.Context) {
for {
conn, _ := s.listener.Accept()
go s.connectHandler(ctx, conn)
}
}
func (s *Server) connectHandler(ctx context.Context, c net.Conn) {
var session = NewSession(c)
s.sessions.Store(session.sid, session)
connctx, cancel := context.WithCancel(ctx)
defer func() {
cancel()
session.Close()
s.sessions.Delete(session.sid)
}()
go session.readCoroutine(connctx)
go session.writeCoroutine(connctx)
if s.onConnect != nil {
s.onConnect(session)
}
for {
select {
case err := <-session.done:
if s.onDisconnect != nil {
s.onDisconnect(session, err)
}
return
case packet := <-session.messageCh:
if s.onMessage != nil {
s.onMessage(session, packet)
}
}
}
}
-194
View File
@@ -1,194 +0,0 @@
/*
* Copyright (C) 2020 The zfoo Authors
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/
package znet
import (
"context"
"errors"
"net"
"sync"
"time"
)
// SocketService struct
type SocketService struct {
onMessage func(*Session, any)
onConnect func(*Session)
onDisconnect func(*Session, error)
sessions *sync.Map
hbInterval time.Duration
hbTimeout time.Duration
laddr string
status int
listener net.Listener
stopCh chan error
}
// Server create a new socket service
func Server(laddr string) *SocketService {
listen, _ := net.Listen("tcp", laddr)
server := &SocketService{
sessions: &sync.Map{},
stopCh: make(chan error),
hbInterval: 0 * time.Second,
hbTimeout: 0 * time.Second,
laddr: laddr,
status: STInited,
listener: listen,
}
return server
}
// RegMessageHandler register message handler
func (s *SocketService) RegMessageHandler(handler func(*Session, any)) {
s.onMessage = handler
}
// RegConnectHandler register connect handler
func (s *SocketService) RegConnectHandler(handler func(*Session)) {
s.onConnect = handler
}
// RegDisconnectHandler register disconnect handler
func (s *SocketService) RegDisconnectHandler(handler func(*Session, error)) {
s.onDisconnect = handler
}
// Start Start socket service
func (s *SocketService) Start() {
s.status = STRunning
ctx, cancel := context.WithCancel(context.Background())
defer func() {
s.status = STStop
cancel()
s.listener.Close()
}()
go s.acceptHandler(ctx)
for {
select {
case <-s.stopCh:
return
}
}
}
func (s *SocketService) acceptHandler(ctx context.Context) {
for {
c, err := s.listener.Accept()
if err != nil {
s.stopCh <- err
return
}
go s.connectHandler(ctx, c)
}
}
func (s *SocketService) connectHandler(ctx context.Context, c net.Conn) {
conn := NewConn(c, s.hbInterval, s.hbTimeout)
session := NewSession(conn)
s.sessions.Store(session.sid, session)
connctx, cancel := context.WithCancel(ctx)
defer func() {
cancel()
conn.Close()
s.sessions.Delete(session.sid)
}()
go conn.readCoroutine(connctx)
go conn.writeCoroutine(connctx)
if s.onConnect != nil {
s.onConnect(session)
}
for {
select {
case err := <-conn.done:
if s.onDisconnect != nil {
s.onDisconnect(session, err)
}
return
case packet := <-conn.messageCh:
if s.onMessage != nil {
s.onMessage(session, packet)
}
}
}
}
// GetStatus get socket service status
func (s *SocketService) GetStatus() int {
return s.status
}
// Stop stop socket service with reason
func (s *SocketService) Stop(reason string) {
s.stopCh <- errors.New(reason)
}
// SetHeartBeat set heart beat
func (s *SocketService) SetHeartBeat(hbInterval time.Duration, hbTimeout time.Duration) error {
if s.status == STRunning {
return errors.New("Can't set heart beat on service running")
}
s.hbInterval = hbInterval
s.hbTimeout = hbTimeout
return nil
}
// GetConnsCount get connect count
func (s *SocketService) GetConnsCount() int {
var count int
s.sessions.Range(func(k, v interface{}) bool {
count++
return true
})
return count
}
// Unicast Unicast with session ID
func (s *SocketService) Unicast(sid string, msg *Packet) {
v, ok := s.sessions.Load(sid)
if ok {
session := v.(*Session)
err := session.conn.SendMessage(msg)
if err != nil {
return
}
}
}
// Broadcast Broadcast to all connections
func (s *SocketService) Broadcast(msg *Packet) {
s.sessions.Range(func(k, v interface{}) bool {
s := v.(*Session)
if err := s.conn.SendMessage(msg); err != nil {
// log.Println(err)
}
return true
})
}
+95 -5
View File
@@ -11,26 +11,116 @@
*/
package znet
import "sync/atomic"
import (
"bytes"
"context"
"encoding/binary"
"io"
"net"
"sync/atomic"
)
// Session struct
type Session struct {
sid uint64
uid uint64
conn *Conn
rawConn net.Conn
sendCh chan []byte
messageCh chan any
done chan error
}
var uuid uint64
// NewSession create a new session
func NewSession(conn *Conn) *Session {
func NewSession(conn net.Conn) *Session {
var suuid = atomic.AddUint64(&uuid, 1)
session := &Session{
sid: suuid,
uid: 0,// 可以为用户的id
conn: conn,
uid: 0, // 可以为用户的id
rawConn: conn,
sendCh: make(chan []byte, 100),
done: make(chan error),
messageCh: make(chan any, 100),
}
return session
}
// Close close connection
func (session *Session) Close() {
session.rawConn.Close()
}
// SendMessage send message
func (session *Session) SendMessage(msg any) error {
var buffer = Encode(msg)
session.sendCh <- buffer.ToBytes()
return nil
}
// writeCoroutine write coroutine
func (session *Session) writeCoroutine(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case pkt := <-session.sendCh:
if pkt == nil {
continue
}
if _, err := session.rawConn.Write(pkt); err != nil {
session.done <- err
}
}
}
}
// readCoroutine read coroutine
func (session *Session) readCoroutine(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
default:
// 读取长度
buf := make([]byte, 4)
_, err := io.ReadFull(session.rawConn, buf)
if err != nil {
session.done <- err
continue
}
bufReader := bytes.NewReader(buf)
var dataSize int32
err = binary.Read(bufReader, binary.BigEndian, &dataSize)
if err != nil {
session.done <- err
continue
}
// 读取数据
var bytes = make([]byte, dataSize)
_, err = io.ReadFull(session.rawConn, bytes)
if err != nil {
session.done <- err
continue
}
// 解码
var packet = Decode(bytes)
session.messageCh <- packet
}
}
}
+25 -23
View File
@@ -19,34 +19,18 @@ import (
"time"
)
func TestService(t *testing.T) {
host := "127.0.0.1:9000"
func TestServer(t *testing.T) {
var host = "127.0.0.1:9000"
server, _ := Server(host)
server.RegMessageHandler(HandleMessage)
server.RegConnectHandler(HandleConnect)
server.RegDisconnectHandler(HandleDisconnect)
var server = NewServer(host)
server.onMessage = HandleMessage
server.onConnect = HandleConnect
server.onDisconnect = HandleDisconnect
server.Start()
// clientTest()
}
func HandleMessage(s *Session, packet any) {
fmt.Println("receive packet")
fmt.Println(packet)
}
func HandleDisconnect(s *Session, err error) {
fmt.Println(s.conn.GetName() + " lost.")
}
func HandleConnect(s *Session) {
fmt.Println(s.conn.GetName() + " connected.")
}
func clientTest() {
func TestClient(t *testing.T) {
host := "127.0.0.1:9000"
tcpAddr, _ := net.ResolveTCPAddr("tcp", host)
@@ -62,3 +46,21 @@ func clientTest() {
time.Sleep(time.Millisecond * 5000)
}
func HandleMessage(session *Session, packet any) {
fmt.Println("receive packet")
fmt.Println(packet)
session.SendMessage(packet)
}
func HandleDisconnect(session *Session, err error) {
fmt.Println("disconnect")
fmt.Println(session.sid)
}
func HandleConnect(session *Session) {
fmt.Println("connected.")
fmt.Println(session.sid)
}