[TEST] improve websocket

This commit is contained in:
Martin Geno 2017-10-27 19:40:42 +02:00
parent 80290a8a17
commit 9d60973ea1
No known key found for this signature in database
GPG Key ID: F0D39A37E925E941
9 changed files with 207 additions and 36 deletions

View File

@ -3,11 +3,11 @@ package database
import ( import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite" _ "github.com/jinzhu/gorm/dialects/sqlite"
_ "github.com/jinzhu/gorm/dialects/mysql"
"github.com/genofire/golang-lib/log" log "github.com/sirupsen/logrus"
) )
// Database connection for writing purposes // Database connection for writing purposes
@ -36,7 +36,7 @@ type Config struct {
// Function to open a database and set the given configuration // Function to open a database and set the given configuration
func Open(c Config) (err error) { func Open(c Config) (err error) {
writeLog := log.Log.WithField("db", "write") writeLog := log.WithField("db", "write")
config = &c config = &c
Write, err = gorm.Open(config.Type, config.Connection) Write, err = gorm.Open(config.Type, config.Connection)
if err != nil { if err != nil {
@ -48,7 +48,7 @@ func Open(c Config) (err error) {
Write.Callback().Create().Remove("gorm:update_time_stamp") Write.Callback().Create().Remove("gorm:update_time_stamp")
Write.Callback().Update().Remove("gorm:update_time_stamp") Write.Callback().Update().Remove("gorm:update_time_stamp")
if len(config.ReadConnection) > 0 { if len(config.ReadConnection) > 0 {
readLog := log.Log.WithField("db", "read") readLog := log.WithField("db", "read")
Read, err = gorm.Open(config.Type, config.ReadConnection) Read, err = gorm.Open(config.Type, config.ReadConnection)
if err != nil { if err != nil {
return return

View File

@ -3,6 +3,7 @@ package websocket
import ( import (
"io" "io"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -11,6 +12,7 @@ import (
const channelBufSize = 100 const channelBufSize = 100
type Client struct { type Client struct {
id uuid.UUID
server *Server server *Server
ws *websocket.Conn ws *websocket.Conn
out chan *Message out chan *Message
@ -27,6 +29,7 @@ func NewClient(s *Server, ws *websocket.Conn) *Client {
return &Client{ return &Client{
server: s, server: s,
ws: ws, ws: ws,
id: uuid.New(), // fallback id (for testing)
out: make(chan *Message, channelBufSize), out: make(chan *Message, channelBufSize),
writeQuit: make(chan bool), writeQuit: make(chan bool),
readQuit: make(chan bool), readQuit: make(chan bool),
@ -34,14 +37,16 @@ func NewClient(s *Server, ws *websocket.Conn) *Client {
} }
func (c *Client) GetID() string { func (c *Client) GetID() string {
return c.ws.RemoteAddr().String() if c.ws != nil {
return c.ws.RemoteAddr().String()
}
return c.id.String()
} }
func (c *Client) Write(msg *Message) { func (c *Client) Write(msg *Message) {
select { select {
case c.out <- msg: case c.out <- msg:
default: default:
c.server.SessionManager.Remove(c)
c.server.DelClient(c) c.server.DelClient(c)
c.Close() c.Close()
} }
@ -57,13 +62,12 @@ func (c *Client) Close() {
func (c *Client) Listen() { func (c *Client) Listen() {
go c.listenWrite() go c.listenWrite()
c.server.AddClient(c) c.server.AddClient(c)
c.server.SessionManager.Init(c)
c.listenRead() c.listenRead()
} }
func (c *Client) handleInput(msg *Message) { func (c *Client) handleInput(msg *Message) {
msg.From = c msg.From = c
if c.server.SessionManager.HandleMessage(msg) { if sm := c.server.sessionManager; sm != nil && sm.HandleMessage(msg) {
return return
} }
if ok, err := msg.Validate(); ok { if ok, err := msg.Validate(); ok {
@ -81,7 +85,6 @@ func (c *Client) listenWrite() {
websocket.WriteJSON(c.ws, msg) websocket.WriteJSON(c.ws, msg)
case <-c.writeQuit: case <-c.writeQuit:
c.server.SessionManager.Remove(c)
c.server.DelClient(c) c.server.DelClient(c)
close(c.out) close(c.out)
close(c.writeQuit) close(c.writeQuit)
@ -96,7 +99,6 @@ func (c *Client) listenRead() {
select { select {
case <-c.readQuit: case <-c.readQuit:
c.server.SessionManager.Remove(c)
c.server.DelClient(c) c.server.DelClient(c)
close(c.readQuit) close(c.readQuit)
return return

46
websocket/client_test.go Normal file
View File

@ -0,0 +1,46 @@
package websocket
import (
"testing"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)
func TestClient(t *testing.T) {
assert := assert.New(t)
chanMsg := make(chan *Message)
sm := NewSessionManager()
srv := NewServer(chanMsg, sm)
assert.Panics(func() {
NewClient(srv, nil)
})
client := NewClient(srv, &websocket.Conn{})
assert.NotNil(client)
client = &Client{
server: srv,
id: uuid.New(),
out: make(chan *Message, channelBufSize),
writeQuit: make(chan bool),
readQuit: make(chan bool),
}
client.handleInput(&Message{})
go client.handleInput(&Message{Subject: "a"})
msg := <-chanMsg
assert.Equal("a", msg.Subject)
// msg catched by sessionManager -> not read from chanMsg needed
client.handleInput(&Message{
ID: uuid.New(),
Subject: SessionMessageInit,
})
}

View File

@ -1,18 +1,18 @@
package websocket package websocket
import ( import (
"testing" "testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestMSGValidate(t *testing.T) { func TestMSGValidate(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
msg := &Message{} msg := &Message{}
assert.False(msg.Validate()) assert.False(msg.Validate())
msg.Subject = "login" msg.Subject = "login"
assert.False(msg.Validate()) assert.False(msg.Validate())
@ -22,3 +22,31 @@ func TestMSGValidate(t *testing.T) {
msg.Subject = "" msg.Subject = ""
assert.False(msg.Validate()) assert.False(msg.Validate())
} }
func TestMSGAnswer(t *testing.T) {
assert := assert.New(t)
out := make(chan *Message, channelBufSize)
client := &Client{
id: uuid.New(),
out: out,
writeQuit: make(chan bool),
readQuit: make(chan bool),
}
conversationID := uuid.New()
msg := &Message{
From: client,
ID: conversationID,
}
go msg.Answer("hi", nil)
msg = <-out
assert.Equal(conversationID, msg.ID)
assert.Equal(uuid.Nil, msg.Session)
assert.Equal(client, msg.From)
assert.Equal("hi", msg.Subject)
assert.Nil(msg.Body)
}

View File

@ -12,15 +12,15 @@ type Server struct {
msgChanIn chan *Message msgChanIn chan *Message
clients map[string]*Client clients map[string]*Client
clientsMutex sync.Mutex clientsMutex sync.Mutex
SessionManager *SessionManager sessionManager *SessionManager
upgrader websocket.Upgrader upgrader websocket.Upgrader
} }
func NewServer(msgChanIn chan *Message) *Server { func NewServer(msgChanIn chan *Message, sessionManager *SessionManager) *Server {
return &Server{ return &Server{
clients: make(map[string]*Client), clients: make(map[string]*Client),
msgChanIn: msgChanIn, msgChanIn: msgChanIn,
SessionManager: NewSessionManager(), sessionManager: sessionManager,
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
@ -31,7 +31,7 @@ func NewServer(msgChanIn chan *Message) *Server {
func (s *Server) Handler(w http.ResponseWriter, r *http.Request) { func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
conn, err := s.upgrader.Upgrade(w, r, nil) conn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Println(err) log.Info(err)
return return
} }
client := NewClient(s, conn) client := NewClient(s, conn)
@ -45,8 +45,11 @@ func (s *Server) AddClient(c *Client) {
} }
if id := c.GetID(); id != "" { if id := c.GetID(); id != "" {
s.clientsMutex.Lock() s.clientsMutex.Lock()
defer s.clientsMutex.Unlock()
s.clients[id] = c s.clients[id] = c
s.clientsMutex.Unlock()
if s.sessionManager != nil {
s.sessionManager.Init(c)
}
} }
} }
@ -58,6 +61,8 @@ func (s *Server) DelClient(c *Client) {
s.clientsMutex.Lock() s.clientsMutex.Lock()
delete(s.clients, id) delete(s.clients, id)
s.clientsMutex.Unlock() s.clientsMutex.Unlock()
s.SessionManager.Remove(c) if s.sessionManager != nil {
s.sessionManager.Remove(c)
}
} }
} }

33
websocket/server_test.go Normal file
View File

@ -0,0 +1,33 @@
package websocket
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestServer(t *testing.T) {
assert := assert.New(t)
srv := NewServer(nil, NewSessionManager())
assert.NotNil(srv)
req, _ := http.NewRequest("GET", "url", nil)
w := httptest.NewRecorder()
srv.Handler(w, req)
out := make(chan *Message)
c := &Client{
out: out,
server: srv,
}
srv.AddClient(nil)
go srv.AddClient(c)
msg := <-out
assert.Equal(SessionMessageInit, msg.Subject)
srv.DelClient(nil)
srv.DelClient(c)
}

View File

@ -40,38 +40,44 @@ func (s *SessionManager) HandleMessage(msg *Message) bool {
s.clientToSession[id] = msg.ID s.clientToSession[id] = msg.ID
s.sessionToClient[msg.ID] = list s.sessionToClient[msg.ID] = list
return true return true
} else { } else if msg.From != nil {
id := msg.From.GetID() id := msg.From.GetID()
msg.Session = s.clientToSession[id] msg.Session = s.clientToSession[id]
} }
return false return false
} }
func (s *SessionManager) Remove(c *Client) {
// Remove clients from SessionManagerer
// - 1. result: clients removed from session manager
// - 2. result: session closed and all clients removed
func (s *SessionManager) Remove(c *Client) (client bool, session bool) {
if c == nil { if c == nil {
return return false, false
} }
if id := c.GetID(); id != "" { if id := c.GetID(); id != "" {
session := s.clientToSession[id] session := s.clientToSession[id]
defer delete(s.clientToSession, id)
if session != uuid.Nil { if session != uuid.Nil {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
list := s.sessionToClient[session] clients := s.sessionToClient[session]
delete(list, id) delete(clients, id)
if len(list) > 0 { if len(clients) > 0 {
s.sessionToClient[session] = list s.sessionToClient[session] = clients
return true, false
} else { } else {
delete(s.sessionToClient, session) delete(s.sessionToClient, session)
return true, true
} }
} }
delete(s.clientToSession, id)
} }
return false, false
} }
func (s *SessionManager) Send(id uuid.UUID, msg *Message) { func (s *SessionManager) Send(id uuid.UUID, msg *Message) {
session := s.sessionToClient[id] clients := s.sessionToClient[id]
for _, c := range session { for _, c := range clients {
c.Write(msg) c.Write(msg)
} }
} }

View File

@ -3,8 +3,6 @@ package websocket
import ( import (
"testing" "testing"
"github.com/gorilla/websocket"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -17,13 +15,13 @@ func TestSessionManager(t *testing.T) {
out := make(chan *Message, channelBufSize) out := make(chan *Message, channelBufSize)
client := &Client{ client := &Client{
id: uuid.New(),
out: out, out: out,
writeQuit: make(chan bool), writeQuit: make(chan bool),
readQuit: make(chan bool), readQuit: make(chan bool),
ws: &websocket.Conn{},
} }
session.Init(client) go session.Init(client)
msg := <-out msg := <-out
assert.Equal(SessionMessageInit, msg.Subject) assert.Equal(SessionMessageInit, msg.Subject)
@ -35,9 +33,59 @@ func TestSessionManager(t *testing.T) {
assert.False(result) assert.False(result)
result = session.HandleMessage(&Message{ result = session.HandleMessage(&Message{
ID: uuid.New(), ID: uuid.New(),
From: client,
})
assert.False(result)
sessionID := uuid.New()
result = session.HandleMessage(&Message{
ID: sessionID,
From: client, From: client,
Subject: SessionMessageInit, Subject: SessionMessageInit,
}) })
assert.True(result) assert.True(result)
go session.Send(sessionID, &Message{
Subject: "some trash",
})
msg = <-out
assert.Equal("some trash", msg.Subject)
// a client need to disconnected
c, s := session.Remove(nil)
assert.False(c)
assert.False(s)
out2 := make(chan *Message, channelBufSize)
client2 := &Client{
id: uuid.New(),
out: out2,
writeQuit: make(chan bool),
readQuit: make(chan bool),
}
go session.Init(client2)
msg = <-out2
result = session.HandleMessage(&Message{
ID: sessionID,
From: client2,
Subject: SessionMessageInit,
})
assert.True(result)
// remove first client of session
c, s = session.Remove(client)
assert.True(c)
assert.False(s)
// remove last client of session
c, s = session.Remove(client2)
assert.True(c)
assert.True(s)
// all client disconnected already
c, s = session.Remove(client2)
assert.False(c)
assert.False(s)
} }

View File

@ -20,6 +20,9 @@ func TestWorker(t *testing.T) {
go w.Start() go w.Start()
time.Sleep(time.Duration(18) * time.Millisecond) time.Sleep(time.Duration(18) * time.Millisecond)
w.Close() w.Close()
time.Sleep(time.Duration(18) * time.Millisecond)
assert.Equal(3, runtime) assert.Equal(3, runtime)
assert.Panics(func() {
w.Close()
})
} }