[TEST] improve websocket
This commit is contained in:
parent
80290a8a17
commit
9d60973ea1
|
@ -3,11 +3,11 @@ package database
|
|||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/postgres"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
|
||||
"github.com/genofire/golang-lib/log"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Database connection for writing purposes
|
||||
|
@ -36,7 +36,7 @@ type Config struct {
|
|||
|
||||
// Function to open a database and set the given configuration
|
||||
func Open(c Config) (err error) {
|
||||
writeLog := log.Log.WithField("db", "write")
|
||||
writeLog := log.WithField("db", "write")
|
||||
config = &c
|
||||
Write, err = gorm.Open(config.Type, config.Connection)
|
||||
if err != nil {
|
||||
|
@ -48,7 +48,7 @@ func Open(c Config) (err error) {
|
|||
Write.Callback().Create().Remove("gorm:update_time_stamp")
|
||||
Write.Callback().Update().Remove("gorm:update_time_stamp")
|
||||
if len(config.ReadConnection) > 0 {
|
||||
readLog := log.Log.WithField("db", "read")
|
||||
readLog := log.WithField("db", "read")
|
||||
Read, err = gorm.Open(config.Type, config.ReadConnection)
|
||||
if err != nil {
|
||||
return
|
||||
|
|
|
@ -3,6 +3,7 @@ package websocket
|
|||
import (
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
@ -11,6 +12,7 @@ import (
|
|||
const channelBufSize = 100
|
||||
|
||||
type Client struct {
|
||||
id uuid.UUID
|
||||
server *Server
|
||||
ws *websocket.Conn
|
||||
out chan *Message
|
||||
|
@ -27,6 +29,7 @@ func NewClient(s *Server, ws *websocket.Conn) *Client {
|
|||
return &Client{
|
||||
server: s,
|
||||
ws: ws,
|
||||
id: uuid.New(), // fallback id (for testing)
|
||||
out: make(chan *Message, channelBufSize),
|
||||
writeQuit: make(chan bool),
|
||||
readQuit: make(chan bool),
|
||||
|
@ -34,14 +37,16 @@ func NewClient(s *Server, ws *websocket.Conn) *Client {
|
|||
}
|
||||
|
||||
func (c *Client) GetID() string {
|
||||
return c.ws.RemoteAddr().String()
|
||||
if c.ws != nil {
|
||||
return c.ws.RemoteAddr().String()
|
||||
}
|
||||
return c.id.String()
|
||||
}
|
||||
|
||||
func (c *Client) Write(msg *Message) {
|
||||
select {
|
||||
case c.out <- msg:
|
||||
default:
|
||||
c.server.SessionManager.Remove(c)
|
||||
c.server.DelClient(c)
|
||||
c.Close()
|
||||
}
|
||||
|
@ -57,13 +62,12 @@ func (c *Client) Close() {
|
|||
func (c *Client) Listen() {
|
||||
go c.listenWrite()
|
||||
c.server.AddClient(c)
|
||||
c.server.SessionManager.Init(c)
|
||||
c.listenRead()
|
||||
}
|
||||
|
||||
func (c *Client) handleInput(msg *Message) {
|
||||
msg.From = c
|
||||
if c.server.SessionManager.HandleMessage(msg) {
|
||||
if sm := c.server.sessionManager; sm != nil && sm.HandleMessage(msg) {
|
||||
return
|
||||
}
|
||||
if ok, err := msg.Validate(); ok {
|
||||
|
@ -81,7 +85,6 @@ func (c *Client) listenWrite() {
|
|||
websocket.WriteJSON(c.ws, msg)
|
||||
|
||||
case <-c.writeQuit:
|
||||
c.server.SessionManager.Remove(c)
|
||||
c.server.DelClient(c)
|
||||
close(c.out)
|
||||
close(c.writeQuit)
|
||||
|
@ -96,7 +99,6 @@ func (c *Client) listenRead() {
|
|||
select {
|
||||
|
||||
case <-c.readQuit:
|
||||
c.server.SessionManager.Remove(c)
|
||||
c.server.DelClient(c)
|
||||
close(c.readQuit)
|
||||
return
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
chanMsg := make(chan *Message)
|
||||
sm := NewSessionManager()
|
||||
|
||||
srv := NewServer(chanMsg, sm)
|
||||
|
||||
assert.Panics(func() {
|
||||
NewClient(srv, nil)
|
||||
})
|
||||
|
||||
client := NewClient(srv, &websocket.Conn{})
|
||||
assert.NotNil(client)
|
||||
|
||||
client = &Client{
|
||||
server: srv,
|
||||
id: uuid.New(),
|
||||
out: make(chan *Message, channelBufSize),
|
||||
writeQuit: make(chan bool),
|
||||
readQuit: make(chan bool),
|
||||
}
|
||||
|
||||
client.handleInput(&Message{})
|
||||
|
||||
go client.handleInput(&Message{Subject: "a"})
|
||||
msg := <-chanMsg
|
||||
assert.Equal("a", msg.Subject)
|
||||
|
||||
// msg catched by sessionManager -> not read from chanMsg needed
|
||||
client.handleInput(&Message{
|
||||
ID: uuid.New(),
|
||||
Subject: SessionMessageInit,
|
||||
})
|
||||
|
||||
}
|
|
@ -1,18 +1,18 @@
|
|||
package websocket
|
||||
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMSGValidate(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
|
||||
msg := &Message{}
|
||||
assert.False(msg.Validate())
|
||||
|
||||
|
||||
msg.Subject = "login"
|
||||
assert.False(msg.Validate())
|
||||
|
||||
|
@ -22,3 +22,31 @@ func TestMSGValidate(t *testing.T) {
|
|||
msg.Subject = ""
|
||||
assert.False(msg.Validate())
|
||||
}
|
||||
|
||||
func TestMSGAnswer(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
out := make(chan *Message, channelBufSize)
|
||||
client := &Client{
|
||||
id: uuid.New(),
|
||||
out: out,
|
||||
writeQuit: make(chan bool),
|
||||
readQuit: make(chan bool),
|
||||
}
|
||||
|
||||
conversationID := uuid.New()
|
||||
|
||||
msg := &Message{
|
||||
From: client,
|
||||
ID: conversationID,
|
||||
}
|
||||
|
||||
go msg.Answer("hi", nil)
|
||||
msg = <-out
|
||||
|
||||
assert.Equal(conversationID, msg.ID)
|
||||
assert.Equal(uuid.Nil, msg.Session)
|
||||
assert.Equal(client, msg.From)
|
||||
assert.Equal("hi", msg.Subject)
|
||||
assert.Nil(msg.Body)
|
||||
}
|
||||
|
|
|
@ -12,15 +12,15 @@ type Server struct {
|
|||
msgChanIn chan *Message
|
||||
clients map[string]*Client
|
||||
clientsMutex sync.Mutex
|
||||
SessionManager *SessionManager
|
||||
sessionManager *SessionManager
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
func NewServer(msgChanIn chan *Message) *Server {
|
||||
func NewServer(msgChanIn chan *Message, sessionManager *SessionManager) *Server {
|
||||
return &Server{
|
||||
clients: make(map[string]*Client),
|
||||
msgChanIn: msgChanIn,
|
||||
SessionManager: NewSessionManager(),
|
||||
sessionManager: sessionManager,
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
|
@ -31,7 +31,7 @@ func NewServer(msgChanIn chan *Message) *Server {
|
|||
func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
log.Info(err)
|
||||
return
|
||||
}
|
||||
client := NewClient(s, conn)
|
||||
|
@ -45,8 +45,11 @@ func (s *Server) AddClient(c *Client) {
|
|||
}
|
||||
if id := c.GetID(); id != "" {
|
||||
s.clientsMutex.Lock()
|
||||
defer s.clientsMutex.Unlock()
|
||||
s.clients[id] = c
|
||||
s.clientsMutex.Unlock()
|
||||
if s.sessionManager != nil {
|
||||
s.sessionManager.Init(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,6 +61,8 @@ func (s *Server) DelClient(c *Client) {
|
|||
s.clientsMutex.Lock()
|
||||
delete(s.clients, id)
|
||||
s.clientsMutex.Unlock()
|
||||
s.SessionManager.Remove(c)
|
||||
if s.sessionManager != nil {
|
||||
s.sessionManager.Remove(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
srv := NewServer(nil, NewSessionManager())
|
||||
assert.NotNil(srv)
|
||||
|
||||
req, _ := http.NewRequest("GET", "url", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.Handler(w, req)
|
||||
|
||||
out := make(chan *Message)
|
||||
c := &Client{
|
||||
out: out,
|
||||
server: srv,
|
||||
}
|
||||
srv.AddClient(nil)
|
||||
go srv.AddClient(c)
|
||||
msg := <-out
|
||||
assert.Equal(SessionMessageInit, msg.Subject)
|
||||
|
||||
srv.DelClient(nil)
|
||||
srv.DelClient(c)
|
||||
}
|
|
@ -40,38 +40,44 @@ func (s *SessionManager) HandleMessage(msg *Message) bool {
|
|||
s.clientToSession[id] = msg.ID
|
||||
s.sessionToClient[msg.ID] = list
|
||||
return true
|
||||
} else {
|
||||
} else if msg.From != nil {
|
||||
id := msg.From.GetID()
|
||||
msg.Session = s.clientToSession[id]
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
func (s *SessionManager) Remove(c *Client) {
|
||||
|
||||
// Remove clients from SessionManagerer
|
||||
// - 1. result: clients removed from session manager
|
||||
// - 2. result: session closed and all clients removed
|
||||
func (s *SessionManager) Remove(c *Client) (client bool, session bool) {
|
||||
if c == nil {
|
||||
return
|
||||
return false, false
|
||||
}
|
||||
if id := c.GetID(); id != "" {
|
||||
session := s.clientToSession[id]
|
||||
defer delete(s.clientToSession, id)
|
||||
if session != uuid.Nil {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
list := s.sessionToClient[session]
|
||||
delete(list, id)
|
||||
if len(list) > 0 {
|
||||
s.sessionToClient[session] = list
|
||||
clients := s.sessionToClient[session]
|
||||
delete(clients, id)
|
||||
if len(clients) > 0 {
|
||||
s.sessionToClient[session] = clients
|
||||
return true, false
|
||||
} else {
|
||||
delete(s.sessionToClient, session)
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
delete(s.clientToSession, id)
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (s *SessionManager) Send(id uuid.UUID, msg *Message) {
|
||||
session := s.sessionToClient[id]
|
||||
for _, c := range session {
|
||||
clients := s.sessionToClient[id]
|
||||
for _, c := range clients {
|
||||
c.Write(msg)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@ package websocket
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -17,13 +15,13 @@ func TestSessionManager(t *testing.T) {
|
|||
|
||||
out := make(chan *Message, channelBufSize)
|
||||
client := &Client{
|
||||
id: uuid.New(),
|
||||
out: out,
|
||||
writeQuit: make(chan bool),
|
||||
readQuit: make(chan bool),
|
||||
ws: &websocket.Conn{},
|
||||
}
|
||||
|
||||
session.Init(client)
|
||||
go session.Init(client)
|
||||
msg := <-out
|
||||
assert.Equal(SessionMessageInit, msg.Subject)
|
||||
|
||||
|
@ -35,9 +33,59 @@ func TestSessionManager(t *testing.T) {
|
|||
assert.False(result)
|
||||
|
||||
result = session.HandleMessage(&Message{
|
||||
ID: uuid.New(),
|
||||
ID: uuid.New(),
|
||||
From: client,
|
||||
})
|
||||
assert.False(result)
|
||||
|
||||
sessionID := uuid.New()
|
||||
result = session.HandleMessage(&Message{
|
||||
ID: sessionID,
|
||||
From: client,
|
||||
Subject: SessionMessageInit,
|
||||
})
|
||||
assert.True(result)
|
||||
|
||||
go session.Send(sessionID, &Message{
|
||||
Subject: "some trash",
|
||||
})
|
||||
msg = <-out
|
||||
assert.Equal("some trash", msg.Subject)
|
||||
|
||||
// a client need to disconnected
|
||||
c, s := session.Remove(nil)
|
||||
assert.False(c)
|
||||
assert.False(s)
|
||||
|
||||
out2 := make(chan *Message, channelBufSize)
|
||||
client2 := &Client{
|
||||
id: uuid.New(),
|
||||
out: out2,
|
||||
writeQuit: make(chan bool),
|
||||
readQuit: make(chan bool),
|
||||
}
|
||||
|
||||
go session.Init(client2)
|
||||
msg = <-out2
|
||||
result = session.HandleMessage(&Message{
|
||||
ID: sessionID,
|
||||
From: client2,
|
||||
Subject: SessionMessageInit,
|
||||
})
|
||||
assert.True(result)
|
||||
|
||||
// remove first client of session
|
||||
c, s = session.Remove(client)
|
||||
assert.True(c)
|
||||
assert.False(s)
|
||||
|
||||
// remove last client of session
|
||||
c, s = session.Remove(client2)
|
||||
assert.True(c)
|
||||
assert.True(s)
|
||||
|
||||
// all client disconnected already
|
||||
c, s = session.Remove(client2)
|
||||
assert.False(c)
|
||||
assert.False(s)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,9 @@ func TestWorker(t *testing.T) {
|
|||
go w.Start()
|
||||
time.Sleep(time.Duration(18) * time.Millisecond)
|
||||
w.Close()
|
||||
|
||||
time.Sleep(time.Duration(18) * time.Millisecond)
|
||||
assert.Equal(3, runtime)
|
||||
assert.Panics(func() {
|
||||
w.Close()
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue