golang-lib/web/ws/main.go

204 lines
4.9 KiB
Go
Raw Permalink Normal View History

2021-06-01 10:51:35 +02:00
package ws
import (
"context"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
2021-06-01 18:08:29 +02:00
"github.com/google/uuid"
2021-09-29 13:08:43 +02:00
"go.uber.org/zap"
2021-06-01 10:51:35 +02:00
"golang.org/x/time/rate"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
)
2021-06-01 18:08:29 +02:00
// WebsocketEndpoint to handle Request
2021-06-01 10:51:35 +02:00
type WebsocketEndpoint struct {
2021-09-29 13:08:43 +02:00
log *zap.Logger
2021-06-01 10:51:35 +02:00
// publishLimiter controls the rate limit applied to the publish endpoint.
//
// Defaults to one publish every 100ms with a burst of 8.
publishLimiter *rate.Limiter
subscribersMu sync.Mutex
Subscribers map[*Subscriber]struct{}
// Message Handler
2021-06-01 18:08:29 +02:00
handlers map[string]MessageHandleFunc
// DefaultMessageHandler if no other handler for MessageType found
2021-06-01 10:51:35 +02:00
DefaultMessageHandler MessageHandleFunc
2021-06-01 18:08:29 +02:00
// Run Function on open connection by subscriper
OnOpen SubscriberEventFunc
// Run Function on close connection to subscriper
OnClose SubscriberEventFunc
2021-06-01 10:51:35 +02:00
}
2021-06-01 18:08:29 +02:00
// Subscriber is a single connected websocket client of a WebsocketEndpoint.
type Subscriber struct {
	// out is the outgoing queue; the connection's write loop drains it.
	out chan *Message
	// closeSlow closes the connection of a client that cannot keep up.
	closeSlow func()
}
2021-06-01 10:51:35 +02:00
2021-06-01 18:08:29 +02:00
// SubscriberEventFunc for handling connection state of Subsriber
2021-06-01 10:51:35 +02:00
type SubscriberEventFunc func(s *Subscriber, msg chan<- *Message)
// Message on websocket
type Message struct {
Type string `json:"type"`
2021-06-01 18:08:29 +02:00
ID *uuid.UUID `json:"id,omitempty"`
ReplyID *uuid.UUID `json:"reply_id,omitempty"`
2021-06-01 10:51:35 +02:00
Body map[string]interface{} `json:"body"`
Subscriber *Subscriber `json:"-"`
}
2021-06-01 18:08:29 +02:00
// Reply queues msg for delivery to the subscriber that sent m.
//
// If m carries an ID, msg.ReplyID is set to it and msg gets a freshly
// generated ID when it has none yet.
//
// Reply never blocks. If the subscriber's outgoing queue is full, the
// subscriber is treated as too slow and its connection is closed, the same
// policy Broadcast applies. Reply is typically called from a message handler
// running on the connection's read loop. A blocking send there would stall
// reading, and it would hang forever once the write loop has stopped
// draining the queue.
//
// Reply is a no-op when m, m.Subscriber or msg is nil.
func (m *Message) Reply(msg *Message) {
	if m == nil || m.Subscriber == nil || msg == nil {
		return
	}
	if m.ID != nil {
		msg.ReplyID = m.ID
		if msg.ID == nil {
			id := uuid.New()
			msg.ID = &id
		}
	}
	select {
	case m.Subscriber.out <- msg:
	default:
		// Close in the background, as Broadcast does, so the calling handler
		// is not held up by the close.
		if closeSlow := m.Subscriber.closeSlow; closeSlow != nil {
			go closeSlow()
		}
	}
}
// MessageHandleFunc processes one incoming Message. ctx is the context the
// connection's read loop runs with; msg.Subscriber identifies the sender, so
// msg.Reply can answer it.
type MessageHandleFunc func(ctx context.Context, msg *Message)
// NewEndpoint - create an empty websocket
2021-09-29 13:08:43 +02:00
func NewEndpoint(log *zap.Logger) *WebsocketEndpoint {
2021-06-01 10:51:35 +02:00
return &WebsocketEndpoint{
2021-09-29 13:08:43 +02:00
log: log,
2021-06-01 10:51:35 +02:00
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
Subscribers: make(map[*Subscriber]struct{}),
handlers: make(map[string]MessageHandleFunc),
}
}
2021-06-01 18:08:29 +02:00
// Broadcast Message to all subscriber (exclude sender of Message)
2021-06-01 18:09:40 +02:00
func (we *WebsocketEndpoint) Broadcast(msg *Message) {
we.subscribersMu.Lock()
defer we.subscribersMu.Unlock()
2021-06-01 18:08:29 +02:00
2021-06-01 18:09:40 +02:00
we.publishLimiter.Wait(context.Background())
2021-06-01 18:08:29 +02:00
2021-06-01 18:09:40 +02:00
for s := range we.Subscribers {
2021-06-01 18:08:29 +02:00
if s == msg.Subscriber {
continue
}
select {
case s.out <- msg:
default:
go s.closeSlow()
}
}
}
// AddMessageHandler registers f for incoming messages whose Type equals typ,
// replacing any handler previously registered for typ.
//
// Passing a nil f removes the registration, so such messages fall back to
// DefaultMessageHandler. Storing nil instead would make the read loop call a
// nil function and panic.
//
// The handler table is not guarded by a lock: register handlers before the
// endpoint starts serving connections.
func (we *WebsocketEndpoint) AddMessageHandler(typ string, f MessageHandleFunc) {
	if f == nil {
		delete(we.handlers, typ) // no-op on a nil map
		return
	}
	if we.handlers == nil {
		we.handlers = make(map[string]MessageHandleFunc)
	}
	we.handlers[typ] = f
}
// Handler - to register in gin webservice
func (we *WebsocketEndpoint) Handler(ctx *gin.Context) {
c, err := websocket.Accept(ctx.Writer, ctx.Request, nil)
2021-06-01 10:51:35 +02:00
if err != nil {
ctx.JSON(http.StatusBadRequest, false)
return
}
defer c.Close(websocket.StatusInternalError, "")
2021-06-01 18:08:29 +02:00
err = we.addSubscriber(ctx, c)
2021-06-01 10:51:35 +02:00
if websocket.CloseStatus(err) == websocket.StatusNormalClosure ||
websocket.CloseStatus(err) == websocket.StatusGoingAway {
return
}
2021-09-29 13:08:43 +02:00
we.log.Error("subscriber stopped", zap.Error(err))
2021-06-01 10:51:35 +02:00
}
2021-06-01 18:08:29 +02:00
// addSubscriber and startup of websocket endpoint
func (we *WebsocketEndpoint) addSubscriber(ctxGin *gin.Context, c *websocket.Conn) error {
2021-06-01 10:51:35 +02:00
s := &Subscriber{
out: make(chan *Message, 10),
closeSlow: func() {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
},
}
2021-06-01 18:08:29 +02:00
we.subscribersMu.Lock()
we.Subscribers[s] = struct{}{}
we.subscribersMu.Unlock()
2021-06-01 10:51:35 +02:00
defer func() {
2021-06-01 18:08:29 +02:00
we.subscribersMu.Lock()
delete(we.Subscribers, s)
we.subscribersMu.Unlock()
if we.OnClose != nil {
we.OnClose(s, s.out)
2021-06-01 10:51:35 +02:00
}
2021-09-29 13:08:43 +02:00
we.log.Debug("websocket closed")
2021-06-01 10:51:35 +02:00
}()
2021-06-01 18:08:29 +02:00
if we.OnOpen != nil {
we.OnOpen(s, s.out)
2021-06-01 10:51:35 +02:00
}
ctx := ctxGin.Request.Context()
go func() {
2021-06-01 18:08:29 +02:00
err := we.readWorker(ctx, c, s)
2021-06-01 10:51:35 +02:00
if websocket.CloseStatus(err) == websocket.StatusNormalClosure ||
websocket.CloseStatus(err) == websocket.StatusGoingAway {
return
}
2021-09-29 13:08:43 +02:00
we.log.Error("websocket reading error", zap.Error(err))
2021-06-01 10:51:35 +02:00
}()
2021-06-01 18:08:29 +02:00
2021-09-29 13:08:43 +02:00
we.log.Debug("websocket started")
2021-06-01 10:51:35 +02:00
for {
select {
case msg := <-s.out:
err := writeTimeout(ctx, time.Second*5, c, msg)
if err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
2021-06-01 18:08:29 +02:00
// readWorker of subscriber
func (we *WebsocketEndpoint) readWorker(ctx context.Context, c *websocket.Conn, s *Subscriber) error {
for {
var msg Message
err := wsjson.Read(ctx, c, &msg)
if err != nil {
return err
}
2021-09-29 13:08:43 +02:00
we.log.Debug("receive", zap.String("type", msg.Type))
2021-06-01 18:08:29 +02:00
msg.Subscriber = s
if handler, ok := we.handlers[msg.Type]; ok {
handler(ctx, &msg)
} else if we.DefaultMessageHandler != nil {
we.DefaultMessageHandler(ctx, &msg)
}
}
}
// writeTimeout send message to subscriber with timeout
2021-06-01 10:51:35 +02:00
func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg *Message) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return wsjson.Write(ctx, c, msg)
}