diff --git a/net/src/test/go/znet/conn.go b/net/src/test/go/znet/conn.go deleted file mode 100644 index e00726b4..00000000 --- a/net/src/test/go/znet/conn.go +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright (C) 2020 The zfoo Authors - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and limitations under the License. - */ - -package znet - -import ( - "bytes" - "context" - "encoding/binary" - "io" - "net" - "time" -) - -// Conn wrap net.Conn -type Conn struct { - sid string - rawConn net.Conn - sendCh chan []byte - messageCh chan any - done chan error - hbTimer *time.Timer - name string - hbInterval time.Duration - hbTimeout time.Duration -} - -// GetName Get conn name -func (c *Conn) GetName() string { - return c.name -} - -// NewConn create new conn -func NewConn(c net.Conn, hbInterval time.Duration, hbTimeout time.Duration) *Conn { - conn := &Conn{ - rawConn: c, - sendCh: make(chan []byte, 100), - done: make(chan error), - messageCh: make(chan any, 100), - hbInterval: hbInterval, - hbTimeout: hbTimeout, - } - - conn.name = c.RemoteAddr().String() - conn.hbTimer = time.NewTimer(conn.hbInterval) - - if conn.hbInterval == 0 { - conn.hbTimer.Stop() - } - - return conn -} - -// Close close connection -func (c *Conn) Close() { - c.hbTimer.Stop() - c.rawConn.Close() -} - -// SendMessage send message -func (c *Conn) SendMessage(msg *Packet) error { - var buffer = Encode(msg) - c.sendCh <- buffer.ToBytes() - return nil -} - -// writeCoroutine write coroutine -func (c *Conn) writeCoroutine(ctx context.Context) { - hbData := make([]byte, 0) - - for { - select { - case <-ctx.Done(): - return - - case pkt := <-c.sendCh: - - if pkt == nil { - continue - } - - if _, err := c.rawConn.Write(pkt); err != nil { - c.done <- err - } - - case <-c.hbTimer.C: - hbMessage := NewMessage(MsgHeartbeat, hbData) - c.SendMessage(hbMessage) - // 设置心跳timer - if c.hbInterval > 0 { - c.hbTimer.Reset(c.hbInterval) - } - } - } -} - -// readCoroutine read coroutine -func (c *Conn) readCoroutine(ctx context.Context) { - - for { - select { - case <-ctx.Done(): - return - - default: - // 设置超时 - if c.hbInterval > 0 { - err := c.rawConn.SetReadDeadline(time.Now().Add(c.hbTimeout)) - if err != nil { - c.done <- err - continue - } - } - // 读取长度 - buf := make([]byte, 4) - _, err := io.ReadFull(c.rawConn, buf) - if err != nil { - c.done <- err - continue - } - - bufReader := bytes.NewReader(buf) - - var dataSize int32 - err = binary.Read(bufReader, binary.BigEndian, &dataSize) - if err != nil { - c.done <- err - continue - } - - // 读取数据 - var bytes = make([]byte, dataSize) - _, err = io.ReadFull(c.rawConn, bytes) - if err != nil { - c.done <- err - continue - } - - // 解码 - var packet = Decode(bytes) - c.messageCh <- packet - } - } -} diff --git a/net/src/test/go/znet/def.go b/net/src/test/go/znet/def.go deleted file mode 100644 index 65e94808..00000000 --- a/net/src/test/go/znet/def.go +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (C) 2020 The zfoo Authors - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and limitations under the License. - */ - -package znet - -const ( - // STUnknown Unknown - STUnknown = iota - // STInited Inited - STInited - // STRunning Running - STRunning - // STStop Stop - STStop -) - -const ( - // MsgHeartbeat heartbeat - MsgHeartbeat = iota -) diff --git a/net/src/test/go/znet/packet.go b/net/src/test/go/znet/packet.go deleted file mode 100644 index 6158d8c1..00000000 --- a/net/src/test/go/znet/packet.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (C) 2020 The zfoo Authors - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and limitations under the License. - */ -package znet - -import ( - "fmt" -) - -// Packet struct -type Packet struct { - length int32 - protocolId int16 - data []byte -} - -// NewMessage create a new message -func NewMessage(protocolId int16, data []byte) *Packet { - msg := &Packet{ - length: int32(len(data)) + 2 + 4, - protocolId: protocolId, - data: data, - } - return msg -} - - -func (msg *Packet) String() string { - return fmt.Sprintf("Size=%d ID=%d DataLen=%d", msg.length, msg.protocolId, len(msg.data)) -} diff --git a/net/src/test/go/znet/server.go b/net/src/test/go/znet/server.go new file mode 100644 index 00000000..2709aa70 --- /dev/null +++ b/net/src/test/go/znet/server.go @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2020 The zfoo Authors + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. + */ +package znet + +import ( + "context" + "net" + "sync" +) + +// Server struct +type Server struct { + onMessage func(*Session, any) + onConnect func(*Session) + onDisconnect func(*Session, error) + sessions *sync.Map + address string + listener net.Listener +} + +// NewServer create a new socket service +func NewServer(addr string) *Server { + listen, _ := net.Listen("tcp", addr) + server := &Server{ + sessions: &sync.Map{}, + address: addr, + listener: listen, + } + return server +} + + +// Start Start socket service +func (s *Server) Start() { + + ctx, cancel := context.WithCancel(context.Background()) + + defer func() { + cancel() + s.listener.Close() + }() + + s.acceptHandler(ctx) +} + +func (s *Server) acceptHandler(ctx context.Context) { + for { + conn, _ := s.listener.Accept() + go s.connectHandler(ctx, conn) + } +} + +func (s *Server) connectHandler(ctx context.Context, c net.Conn) { + var session = NewSession(c) + s.sessions.Store(session.sid, session) + + connctx, cancel := context.WithCancel(ctx) + + defer func() { + cancel() + session.Close() + s.sessions.Delete(session.sid) + }() + + go session.readCoroutine(connctx) + go session.writeCoroutine(connctx) + + if s.onConnect != nil { + s.onConnect(session) + } + + for { + select { + case err := <-session.done: + + if s.onDisconnect != nil { + s.onDisconnect(session, err) + } + return + + case packet := <-session.messageCh: + if s.onMessage != nil { + s.onMessage(session, packet) + } + } + } +} diff --git a/net/src/test/go/znet/service.go b/net/src/test/go/znet/service.go deleted file mode 100644 index d0f74118..00000000 --- a/net/src/test/go/znet/service.go +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright (C) 2020 The zfoo Authors - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and limitations under the License. - */ -package znet - -import ( - "context" - "errors" - "net" - "sync" - "time" -) - -// SocketService struct -type SocketService struct { - onMessage func(*Session, any) - onConnect func(*Session) - onDisconnect func(*Session, error) - sessions *sync.Map - hbInterval time.Duration - hbTimeout time.Duration - laddr string - status int - listener net.Listener - stopCh chan error -} - -// Server create a new socket service -func Server(laddr string) *SocketService { - - listen, _ := net.Listen("tcp", laddr) - - server := &SocketService{ - sessions: &sync.Map{}, - stopCh: make(chan error), - hbInterval: 0 * time.Second, - hbTimeout: 0 * time.Second, - laddr: laddr, - status: STInited, - listener: listen, - } - - return server -} - -// RegMessageHandler register message handler -func (s *SocketService) RegMessageHandler(handler func(*Session, any)) { - s.onMessage = handler -} - -// RegConnectHandler register connect handler -func (s *SocketService) RegConnectHandler(handler func(*Session)) { - s.onConnect = handler -} - -// RegDisconnectHandler register disconnect handler -func (s *SocketService) RegDisconnectHandler(handler func(*Session, error)) { - s.onDisconnect = handler -} - -// Start Start socket service -func (s *SocketService) Start() { - - s.status = STRunning - ctx, cancel := context.WithCancel(context.Background()) - - defer func() { - s.status = STStop - cancel() - s.listener.Close() - }() - - go s.acceptHandler(ctx) - - for { - select { - - case <-s.stopCh: - return - } - } -} - -func (s *SocketService) acceptHandler(ctx context.Context) { - for { - c, err := s.listener.Accept() - if err != nil { - s.stopCh <- err - return - } - - go s.connectHandler(ctx, c) - } -} - -func (s *SocketService) connectHandler(ctx context.Context, c net.Conn) { - conn := NewConn(c, s.hbInterval, s.hbTimeout) - session := NewSession(conn) - s.sessions.Store(session.sid, session) - - connctx, cancel := context.WithCancel(ctx) - - defer func() { - cancel() - conn.Close() - s.sessions.Delete(session.sid) - }() - - go conn.readCoroutine(connctx) - go conn.writeCoroutine(connctx) - - if s.onConnect != nil { - s.onConnect(session) - } - - for { - select { - case err := <-conn.done: - - if s.onDisconnect != nil { - s.onDisconnect(session, err) - } - return - - case packet := <-conn.messageCh: - if s.onMessage != nil { - s.onMessage(session, packet) - } - } - } -} - -// GetStatus get socket service status -func (s *SocketService) GetStatus() int { - return s.status -} - -// Stop stop socket service with reason -func (s *SocketService) Stop(reason string) { - s.stopCh <- errors.New(reason) -} - -// SetHeartBeat set heart beat -func (s *SocketService) SetHeartBeat(hbInterval time.Duration, hbTimeout time.Duration) error { - if s.status == STRunning { - return errors.New("Can't set heart beat on service running") - } - - s.hbInterval = hbInterval - s.hbTimeout = hbTimeout - - return nil -} - -// GetConnsCount get connect count -func (s *SocketService) GetConnsCount() int { - var count int - s.sessions.Range(func(k, v interface{}) bool { - count++ - return true - }) - return count -} - -// Unicast Unicast with session ID -func (s *SocketService) Unicast(sid string, msg *Packet) { - v, ok := s.sessions.Load(sid) - if ok { - session := v.(*Session) - err := session.conn.SendMessage(msg) - if err != nil { - return - } - } -} - -// Broadcast Broadcast to all connections -func (s *SocketService) Broadcast(msg *Packet) { - s.sessions.Range(func(k, v interface{}) bool { - s := v.(*Session) - if err := s.conn.SendMessage(msg); err != nil { - // log.Println(err) - } - return true - }) -} diff --git a/net/src/test/go/znet/session.go b/net/src/test/go/znet/session.go index b145c00f..c9553abc 100644 --- a/net/src/test/go/znet/session.go +++ b/net/src/test/go/znet/session.go @@ -11,26 +11,116 @@ */ package znet -import "sync/atomic" +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net" + "sync/atomic" +) // Session struct type Session struct { sid uint64 uid uint64 - conn *Conn + + rawConn net.Conn + sendCh chan []byte + messageCh chan any + done chan error } var uuid uint64 // NewSession create a new session -func NewSession(conn *Conn) *Session { +func NewSession(conn net.Conn) *Session { var suuid = atomic.AddUint64(&uuid, 1) session := &Session{ sid: suuid, - uid: 0,// 可以为用户的id - conn: conn, + uid: 0, // 可以为用户的id + + rawConn: conn, + sendCh: make(chan []byte, 100), + done: make(chan error), + messageCh: make(chan any, 100), } return session } + + + +// Close close connection +func (session *Session) Close() { + session.rawConn.Close() +} + +// SendMessage send message +func (session *Session) SendMessage(msg any) error { + var buffer = Encode(msg) + session.sendCh <- buffer.ToBytes() + return nil +} + +// writeCoroutine write coroutine +func (session *Session) writeCoroutine(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + + case pkt := <-session.sendCh: + + if pkt == nil { + continue + } + + if _, err := session.rawConn.Write(pkt); err != nil { + session.done <- err + } + } + } +} + +// readCoroutine read coroutine +func (session *Session) readCoroutine(ctx context.Context) { + + for { + select { + case <-ctx.Done(): + return + + default: + // 读取长度 + buf := make([]byte, 4) + _, err := io.ReadFull(session.rawConn, buf) + if err != nil { + session.done <- err + continue + } + + bufReader := bytes.NewReader(buf) + + var dataSize int32 + err = binary.Read(bufReader, binary.BigEndian, &dataSize) + if err != nil { + session.done <- err + continue + } + + // 读取数据 + var bytes = make([]byte, dataSize) + _, err = io.ReadFull(session.rawConn, bytes) + if err != nil { + session.done <- err + continue + } + + // 解码 + var packet = Decode(bytes) + session.messageCh <- packet + } + } +} diff --git a/net/src/test/go/znet/zne_test.go b/net/src/test/go/znet/zne_test.go index c557da07..e6c2e62a 100644 --- a/net/src/test/go/znet/zne_test.go +++ b/net/src/test/go/znet/zne_test.go @@ -19,34 +19,18 @@ import ( "time" ) -func TestService(t *testing.T) { - host := "127.0.0.1:9000" +func TestServer(t *testing.T) { + var host = "127.0.0.1:9000" - server, _ := Server(host) - server.RegMessageHandler(HandleMessage) - server.RegConnectHandler(HandleConnect) - server.RegDisconnectHandler(HandleDisconnect) + var server = NewServer(host) + server.onMessage = HandleMessage + server.onConnect = HandleConnect + server.onDisconnect = HandleDisconnect server.Start() - - // clientTest() } - -func HandleMessage(s *Session, packet any) { - fmt.Println("receive packet") - fmt.Println(packet) -} - -func HandleDisconnect(s *Session, err error) { - fmt.Println(s.conn.GetName() + " lost.") -} - -func HandleConnect(s *Session) { - fmt.Println(s.conn.GetName() + " connected.") -} - -func clientTest() { +func TestClient(t *testing.T) { host := "127.0.0.1:9000" tcpAddr, _ := net.ResolveTCPAddr("tcp", host) @@ -62,3 +46,21 @@ func clientTest() { time.Sleep(time.Millisecond * 5000) } + +func HandleMessage(session *Session, packet any) { + fmt.Println("receive packet") + fmt.Println(packet) + + session.SendMessage(packet) +} + +func HandleDisconnect(session *Session, err error) { + fmt.Println("disconnect") + fmt.Println(session.sid) +} + +func HandleConnect(session *Session) { + fmt.Println("connected.") + fmt.Println(session.sid) +} +