From 9d60973ea1662f7b77ccc9f2c3256e1463ba39fe Mon Sep 17 00:00:00 2001 From: Martin Geno Date: Fri, 27 Oct 2017 19:40:42 +0200 Subject: [PATCH] [TEST] improve websocket --- database/main.go | 8 +++--- websocket/client.go | 14 ++++++---- websocket/client_test.go | 46 +++++++++++++++++++++++++++++++ websocket/msg_test.go | 34 +++++++++++++++++++++-- websocket/server.go | 17 ++++++++---- websocket/server_test.go | 33 ++++++++++++++++++++++ websocket/session.go | 28 +++++++++++-------- websocket/session_test.go | 58 +++++++++++++++++++++++++++++++++++---- worker/main_test.go | 5 +++- 9 files changed, 207 insertions(+), 36 deletions(-) create mode 100644 websocket/client_test.go create mode 100644 websocket/server_test.go diff --git a/database/main.go b/database/main.go index b5e5b6d..02f08e8 100644 --- a/database/main.go +++ b/database/main.go @@ -3,11 +3,11 @@ package database import ( "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/mysql" _ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/sqlite" - _ "github.com/jinzhu/gorm/dialects/mysql" - "github.com/genofire/golang-lib/log" + log "github.com/sirupsen/logrus" ) // Database connection for writing purposes @@ -36,7 +36,7 @@ type Config struct { // Function to open a database and set the given configuration func Open(c Config) (err error) { - writeLog := log.Log.WithField("db", "write") + writeLog := log.WithField("db", "write") config = &c Write, err = gorm.Open(config.Type, config.Connection) if err != nil { @@ -48,7 +48,7 @@ func Open(c Config) (err error) { Write.Callback().Create().Remove("gorm:update_time_stamp") Write.Callback().Update().Remove("gorm:update_time_stamp") if len(config.ReadConnection) > 0 { - readLog := log.Log.WithField("db", "read") + readLog := log.WithField("db", "read") Read, err = gorm.Open(config.Type, config.ReadConnection) if err != nil { return diff --git a/websocket/client.go b/websocket/client.go index 904fcd6..91a8ce6 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -3,6 +3,7 @@ package websocket import ( "io" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/gorilla/websocket" @@ -11,6 +12,7 @@ import ( const channelBufSize = 100 type Client struct { + id uuid.UUID server *Server ws *websocket.Conn out chan *Message @@ -27,6 +29,7 @@ func NewClient(s *Server, ws *websocket.Conn) *Client { return &Client{ server: s, ws: ws, + id: uuid.New(), // fallback id (for testing) out: make(chan *Message, channelBufSize), writeQuit: make(chan bool), readQuit: make(chan bool), @@ -34,14 +37,16 @@ func NewClient(s *Server, ws *websocket.Conn) *Client { } func (c *Client) GetID() string { - return c.ws.RemoteAddr().String() + if c.ws != nil { + return c.ws.RemoteAddr().String() + } + return c.id.String() } func (c *Client) Write(msg *Message) { select { case c.out <- msg: default: - c.server.SessionManager.Remove(c) c.server.DelClient(c) c.Close() } @@ -57,13 +62,12 @@ func (c *Client) Close() { func (c *Client) Listen() { go c.listenWrite() c.server.AddClient(c) - c.server.SessionManager.Init(c) c.listenRead() } func (c *Client) handleInput(msg *Message) { msg.From = c - if c.server.SessionManager.HandleMessage(msg) { + if sm := c.server.sessionManager; sm != nil && sm.HandleMessage(msg) { return } if ok, err := msg.Validate(); ok { @@ -81,7 +85,6 @@ func (c *Client) listenWrite() { websocket.WriteJSON(c.ws, msg) case <-c.writeQuit: - c.server.SessionManager.Remove(c) c.server.DelClient(c) close(c.out) close(c.writeQuit) @@ -96,7 +99,6 @@ func (c *Client) listenRead() { select { case <-c.readQuit: - c.server.SessionManager.Remove(c) c.server.DelClient(c) close(c.readQuit) return diff --git a/websocket/client_test.go b/websocket/client_test.go new file mode 100644 index 0000000..62b22a7 --- /dev/null +++ b/websocket/client_test.go @@ -0,0 +1,46 @@ +package websocket + +import ( + "testing" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +func TestClient(t *testing.T) { + assert := assert.New(t) + + chanMsg := make(chan *Message) + sm := NewSessionManager() + + srv := NewServer(chanMsg, sm) + + assert.Panics(func() { + NewClient(srv, nil) + }) + + client := NewClient(srv, &websocket.Conn{}) + assert.NotNil(client) + + client = &Client{ + server: srv, + id: uuid.New(), + out: make(chan *Message, channelBufSize), + writeQuit: make(chan bool), + readQuit: make(chan bool), + } + + client.handleInput(&Message{}) + + go client.handleInput(&Message{Subject: "a"}) + msg := <-chanMsg + assert.Equal("a", msg.Subject) + + // msg catched by sessionManager -> not read from chanMsg needed + client.handleInput(&Message{ + ID: uuid.New(), + Subject: SessionMessageInit, + }) + +} diff --git a/websocket/msg_test.go b/websocket/msg_test.go index 943bcf2..efdf046 100644 --- a/websocket/msg_test.go +++ b/websocket/msg_test.go @@ -1,18 +1,18 @@ package websocket - import ( "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) func TestMSGValidate(t *testing.T) { assert := assert.New(t) - + msg := &Message{} assert.False(msg.Validate()) - + msg.Subject = "login" assert.False(msg.Validate()) @@ -22,3 +22,31 @@ func TestMSGValidate(t *testing.T) { msg.Subject = "" assert.False(msg.Validate()) } + +func TestMSGAnswer(t *testing.T) { + assert := assert.New(t) + + out := make(chan *Message, channelBufSize) + client := &Client{ + id: uuid.New(), + out: out, + writeQuit: make(chan bool), + readQuit: make(chan bool), + } + + conversationID := uuid.New() + + msg := &Message{ + From: client, + ID: conversationID, + } + + go msg.Answer("hi", nil) + msg = <-out + + assert.Equal(conversationID, msg.ID) + assert.Equal(uuid.Nil, msg.Session) + assert.Equal(client, msg.From) + assert.Equal("hi", msg.Subject) + assert.Nil(msg.Body) +} diff --git a/websocket/server.go b/websocket/server.go index 60262bf..3a9d70f 100644 --- a/websocket/server.go +++ b/websocket/server.go @@ -12,15 +12,15 @@ type Server struct { msgChanIn chan *Message clients map[string]*Client clientsMutex sync.Mutex - SessionManager *SessionManager + sessionManager *SessionManager upgrader websocket.Upgrader } -func NewServer(msgChanIn chan *Message) *Server { +func NewServer(msgChanIn chan *Message, sessionManager *SessionManager) *Server { return &Server{ clients: make(map[string]*Client), msgChanIn: msgChanIn, - SessionManager: NewSessionManager(), + sessionManager: sessionManager, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -31,7 +31,7 @@ func NewServer(msgChanIn chan *Message) *Server { func (s *Server) Handler(w http.ResponseWriter, r *http.Request) { conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { - log.Println(err) + log.Info(err) return } client := NewClient(s, conn) @@ -45,8 +45,11 @@ func (s *Server) AddClient(c *Client) { } if id := c.GetID(); id != "" { s.clientsMutex.Lock() - defer s.clientsMutex.Unlock() s.clients[id] = c + s.clientsMutex.Unlock() + if s.sessionManager != nil { + s.sessionManager.Init(c) + } } } @@ -58,6 +61,8 @@ func (s *Server) DelClient(c *Client) { s.clientsMutex.Lock() delete(s.clients, id) s.clientsMutex.Unlock() - s.SessionManager.Remove(c) + if s.sessionManager != nil { + s.sessionManager.Remove(c) + } } } diff --git a/websocket/server_test.go b/websocket/server_test.go new file mode 100644 index 0000000..fa2b972 --- /dev/null +++ b/websocket/server_test.go @@ -0,0 +1,33 @@ +package websocket + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestServer(t *testing.T) { + assert := assert.New(t) + + srv := NewServer(nil, NewSessionManager()) + assert.NotNil(srv) + + req, _ := http.NewRequest("GET", "url", nil) + w := httptest.NewRecorder() + srv.Handler(w, req) + + out := make(chan *Message) + c := &Client{ + out: out, + server: srv, + } + srv.AddClient(nil) + go srv.AddClient(c) + msg := <-out + assert.Equal(SessionMessageInit, msg.Subject) + + srv.DelClient(nil) + srv.DelClient(c) +} diff --git a/websocket/session.go b/websocket/session.go index 94dde0e..e855e1e 100644 --- a/websocket/session.go +++ b/websocket/session.go @@ -40,38 +40,44 @@ func (s *SessionManager) HandleMessage(msg *Message) bool { s.clientToSession[id] = msg.ID s.sessionToClient[msg.ID] = list return true - } else { + } else if msg.From != nil { id := msg.From.GetID() msg.Session = s.clientToSession[id] } return false } -func (s *SessionManager) Remove(c *Client) { + +// Remove clients from SessionManagerer +// - 1. result: clients removed from session manager +// - 2. result: session closed and all clients removed +func (s *SessionManager) Remove(c *Client) (client bool, session bool) { if c == nil { - return + return false, false } if id := c.GetID(); id != "" { session := s.clientToSession[id] + defer delete(s.clientToSession, id) if session != uuid.Nil { s.Lock() defer s.Unlock() - list := s.sessionToClient[session] - delete(list, id) - if len(list) > 0 { - s.sessionToClient[session] = list + clients := s.sessionToClient[session] + delete(clients, id) + if len(clients) > 0 { + s.sessionToClient[session] = clients + return true, false } else { delete(s.sessionToClient, session) + return true, true } } - delete(s.clientToSession, id) } - + return false, false } func (s *SessionManager) Send(id uuid.UUID, msg *Message) { - session := s.sessionToClient[id] - for _, c := range session { + clients := s.sessionToClient[id] + for _, c := range clients { c.Write(msg) } } diff --git a/websocket/session_test.go b/websocket/session_test.go index efbb0e2..747082e 100644 --- a/websocket/session_test.go +++ b/websocket/session_test.go @@ -3,8 +3,6 @@ package websocket import ( "testing" - "github.com/gorilla/websocket" - "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -17,13 +15,13 @@ func TestSessionManager(t *testing.T) { out := make(chan *Message, channelBufSize) client := &Client{ + id: uuid.New(), out: out, writeQuit: make(chan bool), readQuit: make(chan bool), - ws: &websocket.Conn{}, } - session.Init(client) + go session.Init(client) msg := <-out assert.Equal(SessionMessageInit, msg.Subject) @@ -35,9 +33,59 @@ func TestSessionManager(t *testing.T) { assert.False(result) result = session.HandleMessage(&Message{ - ID: uuid.New(), + ID: uuid.New(), + From: client, + }) + assert.False(result) + + sessionID := uuid.New() + result = session.HandleMessage(&Message{ + ID: sessionID, From: client, Subject: SessionMessageInit, }) assert.True(result) + + go session.Send(sessionID, &Message{ + Subject: "some trash", + }) + msg = <-out + assert.Equal("some trash", msg.Subject) + + // a client need to disconnected + c, s := session.Remove(nil) + assert.False(c) + assert.False(s) + + out2 := make(chan *Message, channelBufSize) + client2 := &Client{ + id: uuid.New(), + out: out2, + writeQuit: make(chan bool), + readQuit: make(chan bool), + } + + go session.Init(client2) + msg = <-out2 + result = session.HandleMessage(&Message{ + ID: sessionID, + From: client2, + Subject: SessionMessageInit, + }) + assert.True(result) + + // remove first client of session + c, s = session.Remove(client) + assert.True(c) + assert.False(s) + + // remove last client of session + c, s = session.Remove(client2) + assert.True(c) + assert.True(s) + + // all client disconnected already + c, s = session.Remove(client2) + assert.False(c) + assert.False(s) } diff --git a/worker/main_test.go b/worker/main_test.go index 6dedd08..f6be1e7 100644 --- a/worker/main_test.go +++ b/worker/main_test.go @@ -20,6 +20,9 @@ func TestWorker(t *testing.T) { go w.Start() time.Sleep(time.Duration(18) * time.Millisecond) w.Close() - + time.Sleep(time.Duration(18) * time.Millisecond) assert.Equal(3, runtime) + assert.Panics(func() { + w.Close() + }) }