[TEST] improve websocket
This commit is contained in:
parent
80290a8a17
commit
9d60973ea1
|
@ -3,11 +3,11 @@ package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
_ "github.com/jinzhu/gorm/dialects/postgres"
|
_ "github.com/jinzhu/gorm/dialects/postgres"
|
||||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
|
||||||
|
|
||||||
"github.com/genofire/golang-lib/log"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database connection for writing purposes
|
// Database connection for writing purposes
|
||||||
|
@ -36,7 +36,7 @@ type Config struct {
|
||||||
|
|
||||||
// Function to open a database and set the given configuration
|
// Function to open a database and set the given configuration
|
||||||
func Open(c Config) (err error) {
|
func Open(c Config) (err error) {
|
||||||
writeLog := log.Log.WithField("db", "write")
|
writeLog := log.WithField("db", "write")
|
||||||
config = &c
|
config = &c
|
||||||
Write, err = gorm.Open(config.Type, config.Connection)
|
Write, err = gorm.Open(config.Type, config.Connection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -48,7 +48,7 @@ func Open(c Config) (err error) {
|
||||||
Write.Callback().Create().Remove("gorm:update_time_stamp")
|
Write.Callback().Create().Remove("gorm:update_time_stamp")
|
||||||
Write.Callback().Update().Remove("gorm:update_time_stamp")
|
Write.Callback().Update().Remove("gorm:update_time_stamp")
|
||||||
if len(config.ReadConnection) > 0 {
|
if len(config.ReadConnection) > 0 {
|
||||||
readLog := log.Log.WithField("db", "read")
|
readLog := log.WithField("db", "read")
|
||||||
Read, err = gorm.Open(config.Type, config.ReadConnection)
|
Read, err = gorm.Open(config.Type, config.ReadConnection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
|
@ -3,6 +3,7 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
@ -11,6 +12,7 @@ import (
|
||||||
const channelBufSize = 100
|
const channelBufSize = 100
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
id uuid.UUID
|
||||||
server *Server
|
server *Server
|
||||||
ws *websocket.Conn
|
ws *websocket.Conn
|
||||||
out chan *Message
|
out chan *Message
|
||||||
|
@ -27,6 +29,7 @@ func NewClient(s *Server, ws *websocket.Conn) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
server: s,
|
server: s,
|
||||||
ws: ws,
|
ws: ws,
|
||||||
|
id: uuid.New(), // fallback id (for testing)
|
||||||
out: make(chan *Message, channelBufSize),
|
out: make(chan *Message, channelBufSize),
|
||||||
writeQuit: make(chan bool),
|
writeQuit: make(chan bool),
|
||||||
readQuit: make(chan bool),
|
readQuit: make(chan bool),
|
||||||
|
@ -34,14 +37,16 @@ func NewClient(s *Server, ws *websocket.Conn) *Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) GetID() string {
|
func (c *Client) GetID() string {
|
||||||
return c.ws.RemoteAddr().String()
|
if c.ws != nil {
|
||||||
|
return c.ws.RemoteAddr().String()
|
||||||
|
}
|
||||||
|
return c.id.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Write(msg *Message) {
|
func (c *Client) Write(msg *Message) {
|
||||||
select {
|
select {
|
||||||
case c.out <- msg:
|
case c.out <- msg:
|
||||||
default:
|
default:
|
||||||
c.server.SessionManager.Remove(c)
|
|
||||||
c.server.DelClient(c)
|
c.server.DelClient(c)
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
@ -57,13 +62,12 @@ func (c *Client) Close() {
|
||||||
func (c *Client) Listen() {
|
func (c *Client) Listen() {
|
||||||
go c.listenWrite()
|
go c.listenWrite()
|
||||||
c.server.AddClient(c)
|
c.server.AddClient(c)
|
||||||
c.server.SessionManager.Init(c)
|
|
||||||
c.listenRead()
|
c.listenRead()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handleInput(msg *Message) {
|
func (c *Client) handleInput(msg *Message) {
|
||||||
msg.From = c
|
msg.From = c
|
||||||
if c.server.SessionManager.HandleMessage(msg) {
|
if sm := c.server.sessionManager; sm != nil && sm.HandleMessage(msg) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if ok, err := msg.Validate(); ok {
|
if ok, err := msg.Validate(); ok {
|
||||||
|
@ -81,7 +85,6 @@ func (c *Client) listenWrite() {
|
||||||
websocket.WriteJSON(c.ws, msg)
|
websocket.WriteJSON(c.ws, msg)
|
||||||
|
|
||||||
case <-c.writeQuit:
|
case <-c.writeQuit:
|
||||||
c.server.SessionManager.Remove(c)
|
|
||||||
c.server.DelClient(c)
|
c.server.DelClient(c)
|
||||||
close(c.out)
|
close(c.out)
|
||||||
close(c.writeQuit)
|
close(c.writeQuit)
|
||||||
|
@ -96,7 +99,6 @@ func (c *Client) listenRead() {
|
||||||
select {
|
select {
|
||||||
|
|
||||||
case <-c.readQuit:
|
case <-c.readQuit:
|
||||||
c.server.SessionManager.Remove(c)
|
|
||||||
c.server.DelClient(c)
|
c.server.DelClient(c)
|
||||||
close(c.readQuit)
|
close(c.readQuit)
|
||||||
return
|
return
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClient(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
chanMsg := make(chan *Message)
|
||||||
|
sm := NewSessionManager()
|
||||||
|
|
||||||
|
srv := NewServer(chanMsg, sm)
|
||||||
|
|
||||||
|
assert.Panics(func() {
|
||||||
|
NewClient(srv, nil)
|
||||||
|
})
|
||||||
|
|
||||||
|
client := NewClient(srv, &websocket.Conn{})
|
||||||
|
assert.NotNil(client)
|
||||||
|
|
||||||
|
client = &Client{
|
||||||
|
server: srv,
|
||||||
|
id: uuid.New(),
|
||||||
|
out: make(chan *Message, channelBufSize),
|
||||||
|
writeQuit: make(chan bool),
|
||||||
|
readQuit: make(chan bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
client.handleInput(&Message{})
|
||||||
|
|
||||||
|
go client.handleInput(&Message{Subject: "a"})
|
||||||
|
msg := <-chanMsg
|
||||||
|
assert.Equal("a", msg.Subject)
|
||||||
|
|
||||||
|
// msg catched by sessionManager -> not read from chanMsg needed
|
||||||
|
client.handleInput(&Message{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Subject: SessionMessageInit,
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
|
@ -1,18 +1,18 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMSGValidate(t *testing.T) {
|
func TestMSGValidate(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
|
|
||||||
msg := &Message{}
|
msg := &Message{}
|
||||||
assert.False(msg.Validate())
|
assert.False(msg.Validate())
|
||||||
|
|
||||||
msg.Subject = "login"
|
msg.Subject = "login"
|
||||||
assert.False(msg.Validate())
|
assert.False(msg.Validate())
|
||||||
|
|
||||||
|
@ -22,3 +22,31 @@ func TestMSGValidate(t *testing.T) {
|
||||||
msg.Subject = ""
|
msg.Subject = ""
|
||||||
assert.False(msg.Validate())
|
assert.False(msg.Validate())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMSGAnswer(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
out := make(chan *Message, channelBufSize)
|
||||||
|
client := &Client{
|
||||||
|
id: uuid.New(),
|
||||||
|
out: out,
|
||||||
|
writeQuit: make(chan bool),
|
||||||
|
readQuit: make(chan bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationID := uuid.New()
|
||||||
|
|
||||||
|
msg := &Message{
|
||||||
|
From: client,
|
||||||
|
ID: conversationID,
|
||||||
|
}
|
||||||
|
|
||||||
|
go msg.Answer("hi", nil)
|
||||||
|
msg = <-out
|
||||||
|
|
||||||
|
assert.Equal(conversationID, msg.ID)
|
||||||
|
assert.Equal(uuid.Nil, msg.Session)
|
||||||
|
assert.Equal(client, msg.From)
|
||||||
|
assert.Equal("hi", msg.Subject)
|
||||||
|
assert.Nil(msg.Body)
|
||||||
|
}
|
||||||
|
|
|
@ -12,15 +12,15 @@ type Server struct {
|
||||||
msgChanIn chan *Message
|
msgChanIn chan *Message
|
||||||
clients map[string]*Client
|
clients map[string]*Client
|
||||||
clientsMutex sync.Mutex
|
clientsMutex sync.Mutex
|
||||||
SessionManager *SessionManager
|
sessionManager *SessionManager
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(msgChanIn chan *Message) *Server {
|
func NewServer(msgChanIn chan *Message, sessionManager *SessionManager) *Server {
|
||||||
return &Server{
|
return &Server{
|
||||||
clients: make(map[string]*Client),
|
clients: make(map[string]*Client),
|
||||||
msgChanIn: msgChanIn,
|
msgChanIn: msgChanIn,
|
||||||
SessionManager: NewSessionManager(),
|
sessionManager: sessionManager,
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
ReadBufferSize: 1024,
|
ReadBufferSize: 1024,
|
||||||
WriteBufferSize: 1024,
|
WriteBufferSize: 1024,
|
||||||
|
@ -31,7 +31,7 @@ func NewServer(msgChanIn chan *Message) *Server {
|
||||||
func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Info(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
client := NewClient(s, conn)
|
client := NewClient(s, conn)
|
||||||
|
@ -45,8 +45,11 @@ func (s *Server) AddClient(c *Client) {
|
||||||
}
|
}
|
||||||
if id := c.GetID(); id != "" {
|
if id := c.GetID(); id != "" {
|
||||||
s.clientsMutex.Lock()
|
s.clientsMutex.Lock()
|
||||||
defer s.clientsMutex.Unlock()
|
|
||||||
s.clients[id] = c
|
s.clients[id] = c
|
||||||
|
s.clientsMutex.Unlock()
|
||||||
|
if s.sessionManager != nil {
|
||||||
|
s.sessionManager.Init(c)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,6 +61,8 @@ func (s *Server) DelClient(c *Client) {
|
||||||
s.clientsMutex.Lock()
|
s.clientsMutex.Lock()
|
||||||
delete(s.clients, id)
|
delete(s.clients, id)
|
||||||
s.clientsMutex.Unlock()
|
s.clientsMutex.Unlock()
|
||||||
s.SessionManager.Remove(c)
|
if s.sessionManager != nil {
|
||||||
|
s.sessionManager.Remove(c)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServer(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
srv := NewServer(nil, NewSessionManager())
|
||||||
|
assert.NotNil(srv)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "url", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
srv.Handler(w, req)
|
||||||
|
|
||||||
|
out := make(chan *Message)
|
||||||
|
c := &Client{
|
||||||
|
out: out,
|
||||||
|
server: srv,
|
||||||
|
}
|
||||||
|
srv.AddClient(nil)
|
||||||
|
go srv.AddClient(c)
|
||||||
|
msg := <-out
|
||||||
|
assert.Equal(SessionMessageInit, msg.Subject)
|
||||||
|
|
||||||
|
srv.DelClient(nil)
|
||||||
|
srv.DelClient(c)
|
||||||
|
}
|
|
@ -40,38 +40,44 @@ func (s *SessionManager) HandleMessage(msg *Message) bool {
|
||||||
s.clientToSession[id] = msg.ID
|
s.clientToSession[id] = msg.ID
|
||||||
s.sessionToClient[msg.ID] = list
|
s.sessionToClient[msg.ID] = list
|
||||||
return true
|
return true
|
||||||
} else {
|
} else if msg.From != nil {
|
||||||
id := msg.From.GetID()
|
id := msg.From.GetID()
|
||||||
msg.Session = s.clientToSession[id]
|
msg.Session = s.clientToSession[id]
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
func (s *SessionManager) Remove(c *Client) {
|
|
||||||
|
// Remove clients from SessionManagerer
|
||||||
|
// - 1. result: clients removed from session manager
|
||||||
|
// - 2. result: session closed and all clients removed
|
||||||
|
func (s *SessionManager) Remove(c *Client) (client bool, session bool) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return
|
return false, false
|
||||||
}
|
}
|
||||||
if id := c.GetID(); id != "" {
|
if id := c.GetID(); id != "" {
|
||||||
session := s.clientToSession[id]
|
session := s.clientToSession[id]
|
||||||
|
defer delete(s.clientToSession, id)
|
||||||
if session != uuid.Nil {
|
if session != uuid.Nil {
|
||||||
s.Lock()
|
s.Lock()
|
||||||
defer s.Unlock()
|
defer s.Unlock()
|
||||||
list := s.sessionToClient[session]
|
clients := s.sessionToClient[session]
|
||||||
delete(list, id)
|
delete(clients, id)
|
||||||
if len(list) > 0 {
|
if len(clients) > 0 {
|
||||||
s.sessionToClient[session] = list
|
s.sessionToClient[session] = clients
|
||||||
|
return true, false
|
||||||
} else {
|
} else {
|
||||||
delete(s.sessionToClient, session)
|
delete(s.sessionToClient, session)
|
||||||
|
return true, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(s.clientToSession, id)
|
|
||||||
}
|
}
|
||||||
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionManager) Send(id uuid.UUID, msg *Message) {
|
func (s *SessionManager) Send(id uuid.UUID, msg *Message) {
|
||||||
session := s.sessionToClient[id]
|
clients := s.sessionToClient[id]
|
||||||
for _, c := range session {
|
for _, c := range clients {
|
||||||
c.Write(msg)
|
c.Write(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,6 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
@ -17,13 +15,13 @@ func TestSessionManager(t *testing.T) {
|
||||||
|
|
||||||
out := make(chan *Message, channelBufSize)
|
out := make(chan *Message, channelBufSize)
|
||||||
client := &Client{
|
client := &Client{
|
||||||
|
id: uuid.New(),
|
||||||
out: out,
|
out: out,
|
||||||
writeQuit: make(chan bool),
|
writeQuit: make(chan bool),
|
||||||
readQuit: make(chan bool),
|
readQuit: make(chan bool),
|
||||||
ws: &websocket.Conn{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Init(client)
|
go session.Init(client)
|
||||||
msg := <-out
|
msg := <-out
|
||||||
assert.Equal(SessionMessageInit, msg.Subject)
|
assert.Equal(SessionMessageInit, msg.Subject)
|
||||||
|
|
||||||
|
@ -35,9 +33,59 @@ func TestSessionManager(t *testing.T) {
|
||||||
assert.False(result)
|
assert.False(result)
|
||||||
|
|
||||||
result = session.HandleMessage(&Message{
|
result = session.HandleMessage(&Message{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
|
From: client,
|
||||||
|
})
|
||||||
|
assert.False(result)
|
||||||
|
|
||||||
|
sessionID := uuid.New()
|
||||||
|
result = session.HandleMessage(&Message{
|
||||||
|
ID: sessionID,
|
||||||
From: client,
|
From: client,
|
||||||
Subject: SessionMessageInit,
|
Subject: SessionMessageInit,
|
||||||
})
|
})
|
||||||
assert.True(result)
|
assert.True(result)
|
||||||
|
|
||||||
|
go session.Send(sessionID, &Message{
|
||||||
|
Subject: "some trash",
|
||||||
|
})
|
||||||
|
msg = <-out
|
||||||
|
assert.Equal("some trash", msg.Subject)
|
||||||
|
|
||||||
|
// a client need to disconnected
|
||||||
|
c, s := session.Remove(nil)
|
||||||
|
assert.False(c)
|
||||||
|
assert.False(s)
|
||||||
|
|
||||||
|
out2 := make(chan *Message, channelBufSize)
|
||||||
|
client2 := &Client{
|
||||||
|
id: uuid.New(),
|
||||||
|
out: out2,
|
||||||
|
writeQuit: make(chan bool),
|
||||||
|
readQuit: make(chan bool),
|
||||||
|
}
|
||||||
|
|
||||||
|
go session.Init(client2)
|
||||||
|
msg = <-out2
|
||||||
|
result = session.HandleMessage(&Message{
|
||||||
|
ID: sessionID,
|
||||||
|
From: client2,
|
||||||
|
Subject: SessionMessageInit,
|
||||||
|
})
|
||||||
|
assert.True(result)
|
||||||
|
|
||||||
|
// remove first client of session
|
||||||
|
c, s = session.Remove(client)
|
||||||
|
assert.True(c)
|
||||||
|
assert.False(s)
|
||||||
|
|
||||||
|
// remove last client of session
|
||||||
|
c, s = session.Remove(client2)
|
||||||
|
assert.True(c)
|
||||||
|
assert.True(s)
|
||||||
|
|
||||||
|
// all client disconnected already
|
||||||
|
c, s = session.Remove(client2)
|
||||||
|
assert.False(c)
|
||||||
|
assert.False(s)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,9 @@ func TestWorker(t *testing.T) {
|
||||||
go w.Start()
|
go w.Start()
|
||||||
time.Sleep(time.Duration(18) * time.Millisecond)
|
time.Sleep(time.Duration(18) * time.Millisecond)
|
||||||
w.Close()
|
w.Close()
|
||||||
|
time.Sleep(time.Duration(18) * time.Millisecond)
|
||||||
assert.Equal(3, runtime)
|
assert.Equal(3, runtime)
|
||||||
|
assert.Panics(func() {
|
||||||
|
w.Close()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue