From 0d64ba751ed66653e5d4ca0fff532acfddca5a12 Mon Sep 17 00:00:00 2001 From: Martin Geno Date: Wed, 25 Oct 2017 18:42:44 +0200 Subject: [PATCH] add websocket --- .test-coverage | 1 + websocket/client.go | 116 ++++++++++++++++++++++++++++++++++++++ websocket/msg.go | 35 ++++++++++++ websocket/msg_test.go | 24 ++++++++ websocket/server.go | 63 +++++++++++++++++++++ websocket/session.go | 77 +++++++++++++++++++++++++ websocket/session_test.go | 43 ++++++++++++++ worker/main.go | 9 ++- worker/main_test.go | 1 - 9 files changed, 367 insertions(+), 2 deletions(-) create mode 100644 websocket/client.go create mode 100644 websocket/msg.go create mode 100644 websocket/msg_test.go create mode 100644 websocket/server.go create mode 100644 websocket/session.go create mode 100644 websocket/session_test.go diff --git a/.test-coverage b/.test-coverage index cd46eda..73b6e7f 100755 --- a/.test-coverage +++ b/.test-coverage @@ -26,6 +26,7 @@ done # Failures have incomplete results, so don't send if [ "$FAIL" -eq 0 ]; then + bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN -f profile.cov goveralls -v -coverprofile=profile.cov -service=$CI -repotoken=$COVERALLS_REPO_TOKEN fi diff --git a/websocket/client.go b/websocket/client.go new file mode 100644 index 0000000..904fcd6 --- /dev/null +++ b/websocket/client.go @@ -0,0 +1,116 @@ +package websocket + +import ( + "io" + + log "github.com/sirupsen/logrus" + + "github.com/gorilla/websocket" +) + +const channelBufSize = 100 + +type Client struct { + server *Server + ws *websocket.Conn + out chan *Message + writeQuit chan bool + readQuit chan bool +} + +func NewClient(s *Server, ws *websocket.Conn) *Client { + + if ws == nil { + log.Panic("ws cannot be nil") + } + + return &Client{ + server: s, + ws: ws, + out: make(chan *Message, channelBufSize), + writeQuit: make(chan bool), + readQuit: make(chan bool), + } +} + +func (c *Client) GetID() string { + return c.ws.RemoteAddr().String() +} + +func (c *Client) Write(msg *Message) { + select { + case c.out <- msg: + default: + c.server.SessionManager.Remove(c) + c.server.DelClient(c) + c.Close() + } +} + +func (c *Client) Close() { + c.writeQuit <- true + c.readQuit <- true + log.Info("client disconnecting...", c.GetID()) +} + +// Listen Write and Read request via channel +func (c *Client) Listen() { + go c.listenWrite() + c.server.AddClient(c) + c.server.SessionManager.Init(c) + c.listenRead() +} + +func (c *Client) handleInput(msg *Message) { + msg.From = c + if c.server.SessionManager.HandleMessage(msg) { + return + } + if ok, err := msg.Validate(); ok { + c.server.msgChanIn <- msg + } else { + log.Println("no valid msg for:", c.GetID(), "error:", err, "\nmessage:", msg) + } +} + +// Listen write request via channel +func (c *Client) listenWrite() { + for { + select { + case msg := <-c.out: + websocket.WriteJSON(c.ws, msg) + + case <-c.writeQuit: + c.server.SessionManager.Remove(c) + c.server.DelClient(c) + close(c.out) + close(c.writeQuit) + return + } + } +} + +// Listen read request via channel +func (c *Client) listenRead() { + for { + select { + + case <-c.readQuit: + c.server.SessionManager.Remove(c) + c.server.DelClient(c) + close(c.readQuit) + return + + default: + var msg Message + err := websocket.ReadJSON(c.ws, &msg) + if err == io.EOF { + return + } else if err != nil { + log.Println(err, c.GetID()) + } else { + c.handleInput(&msg) + } + } + } +} diff --git a/websocket/msg.go b/websocket/msg.go new file mode 100644 index 0000000..272154d --- /dev/null +++ b/websocket/msg.go @@ -0,0 +1,35 @@ +package websocket + +import ( + "errors" + + "github.com/google/uuid" +) + +type Message struct { + ID uuid.UUID `json:"id,omitempty"` + Session uuid.UUID `json:"-"` + From *Client `json:"-"` + Subject string `json:"subject,omitempty"` + Body interface{} `json:"body,omitempty"` +} + +func (msg *Message) Validate() (bool, error) { + if msg.Subject == "" { + return false, errors.New("no subject definied") + } + if msg.From == nil { + return false, errors.New("no sender definied") + } + return true, nil +} + +func (msg *Message) Answer(subject string, body interface{}) { + msg.From.Write(&Message{ + ID: msg.ID, + Session: msg.Session, + From: msg.From, + Subject: subject, + Body: body, + }) +} diff --git a/websocket/msg_test.go b/websocket/msg_test.go new file mode 100644 index 0000000..943bcf2 --- /dev/null +++ b/websocket/msg_test.go @@ -0,0 +1,24 @@ +package websocket + + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMSGValidate(t *testing.T) { + assert := assert.New(t) + + msg := &Message{} + assert.False(msg.Validate()) + + msg.Subject = "login" + assert.False(msg.Validate()) + + msg.From = &Client{} + assert.True(msg.Validate()) + + msg.Subject = "" + assert.False(msg.Validate()) +} diff --git a/websocket/server.go b/websocket/server.go new file mode 100644 index 0000000..60262bf --- /dev/null +++ b/websocket/server.go @@ -0,0 +1,63 @@ +package websocket + +import ( + "net/http" + "sync" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" +) + +type Server struct { + msgChanIn chan *Message + clients map[string]*Client + clientsMutex sync.Mutex + SessionManager *SessionManager + upgrader websocket.Upgrader +} + +func NewServer(msgChanIn chan *Message) *Server { + return &Server{ + clients: make(map[string]*Client), + msgChanIn: msgChanIn, + SessionManager: NewSessionManager(), + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + } +} + +func (s *Server) Handler(w http.ResponseWriter, r *http.Request) { + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(err) + return + } + client := NewClient(s, conn) + defer client.Close() + client.Listen() +} + +func (s *Server) AddClient(c *Client) { + if c == nil { + return + } + if id := c.GetID(); id != "" { + s.clientsMutex.Lock() + defer s.clientsMutex.Unlock() + s.clients[id] = c + } +} + +func (s *Server) DelClient(c *Client) { + if c == nil { + return + } + if id := c.GetID(); id != "" { + s.clientsMutex.Lock() + delete(s.clients, id) + s.clientsMutex.Unlock() + s.SessionManager.Remove(c) + } +} diff --git a/websocket/session.go b/websocket/session.go new file mode 100644 index 0000000..94dde0e --- /dev/null +++ b/websocket/session.go @@ -0,0 +1,77 @@ +package websocket + +import ( + "sync" + + "github.com/google/uuid" +) + +const SessionMessageInit = "session_init" + +type SessionManager struct { + sessionToClient map[uuid.UUID]map[string]*Client + clientToSession map[string]uuid.UUID + sync.Mutex +} + +func NewSessionManager() *SessionManager { + return &SessionManager{ + sessionToClient: make(map[uuid.UUID]map[string]*Client), + clientToSession: make(map[string]uuid.UUID), + } +} + +func (s *SessionManager) Init(c *Client) { + c.Write(&Message{Subject: SessionMessageInit}) +} +func (s *SessionManager) HandleMessage(msg *Message) bool { + if msg == nil { + return false + } + if msg.ID != uuid.Nil && msg.Subject == SessionMessageInit && msg.From != nil { + s.Lock() + defer s.Unlock() + list := s.sessionToClient[msg.ID] + if list == nil { + list = make(map[string]*Client) + } + id := msg.From.GetID() + list[id] = msg.From + s.clientToSession[id] = msg.ID + s.sessionToClient[msg.ID] = list + return true + } else { + id := msg.From.GetID() + msg.Session = s.clientToSession[id] + } + + return false +} +func (s *SessionManager) Remove(c *Client) { + if c == nil { + return + } + if id := c.GetID(); id != "" { + session := s.clientToSession[id] + if session != uuid.Nil { + s.Lock() + defer s.Unlock() + list := s.sessionToClient[session] + delete(list, id) + if len(list) > 0 { + s.sessionToClient[session] = list + } else { + delete(s.sessionToClient, session) + } + } + delete(s.clientToSession, id) + } + +} + +func (s *SessionManager) Send(id uuid.UUID, msg *Message) { + session := s.sessionToClient[id] + for _, c := range session { + c.Write(msg) + } +} diff --git a/websocket/session_test.go b/websocket/session_test.go new file mode 100644 index 0000000..efbb0e2 --- /dev/null +++ b/websocket/session_test.go @@ -0,0 +1,43 @@ +package websocket + +import ( + "testing" + + "github.com/gorilla/websocket" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestSessionManager(t *testing.T) { + assert := assert.New(t) + + session := NewSessionManager() + assert.NotNil(session) + + out := make(chan *Message, channelBufSize) + client := &Client{ + out: out, + writeQuit: make(chan bool), + readQuit: make(chan bool), + ws: &websocket.Conn{}, + } + + session.Init(client) + msg := <-out + assert.Equal(SessionMessageInit, msg.Subject) + + result := session.HandleMessage(nil) + assert.False(result) + + msgFillSession := &Message{} + result = session.HandleMessage(msgFillSession) + assert.False(result) + + result = session.HandleMessage(&Message{ + ID: uuid.New(), + From: client, + Subject: SessionMessageInit, + }) + assert.True(result) +} diff --git a/worker/main.go b/worker/main.go index 991d2f1..af4de9b 100644 --- a/worker/main.go +++ b/worker/main.go @@ -1,13 +1,17 @@ // Package with a lib for cronjobs to run in background package worker -import "time" +import ( + "sync" + "time" +) // Struct which handles the job type Worker struct { every time.Duration run func() quit chan struct{} + wg sync.WaitGroup } // Function to create a new Worker with a timestamp, run, every and it's function @@ -23,6 +27,7 @@ func NewWorker(every time.Duration, f func()) (w *Worker) { // Function to start the Worker // (please us it as a go routine with go w.Start()) func (w *Worker) Start() { + w.wg.Add(1) ticker := time.NewTicker(w.every) for { select { @@ -30,6 +35,7 @@ func (w *Worker) Start() { w.run() case <-w.quit: ticker.Stop() + w.wg.Done() return } } @@ -38,4 +44,5 @@ func (w *Worker) Start() { // Function to stop the Worker func (w *Worker) Close() { close(w.quit) + w.wg.Wait() } diff --git a/worker/main_test.go b/worker/main_test.go index 041729e..6dedd08 100644 --- a/worker/main_test.go +++ b/worker/main_test.go @@ -22,5 +22,4 @@ func TestWorker(t *testing.T) { w.Close() assert.Equal(3, runtime) - time.Sleep(time.Duration(8) * time.Millisecond) }