From 1aceea71335e3793c4123a8358e5d0d2ea430808 Mon Sep 17 00:00:00 2001 From: Martin Geno Date: Sat, 16 Dec 2017 23:20:46 +0100 Subject: [PATCH] restructur code in packages --- cmd/server.go | 6 + model/account.go | 7 +- model/buddy.go | 17 + server/client.go | 76 ---- server/extension/main.go | 11 + server/extension/message.go | 15 + server/extension/roster.go | 27 ++ server/server.go | 28 +- server/state.go | 6 - server/state/connect.go | 125 +++++++ server/state/state.go | 8 + server/state_connect.go | 353 ------------------ server/toclient/connect.go | 272 ++++++++++++++ .../register.go} | 79 ++-- server/toclient/utils.go | 14 + server/utils/client.go | 65 ++++ server/{utils.go => utils/main.go} | 18 +- 17 files changed, 636 insertions(+), 491 deletions(-) create mode 100644 model/buddy.go delete mode 100644 server/client.go create mode 100644 server/extension/main.go create mode 100644 server/extension/message.go create mode 100644 server/extension/roster.go delete mode 100644 server/state.go create mode 100644 server/state/connect.go create mode 100644 server/state/state.go delete mode 100644 server/state_connect.go create mode 100644 server/toclient/connect.go rename server/{state_register.go => toclient/register.go} (50%) create mode 100644 server/toclient/utils.go create mode 100644 server/utils/client.go rename server/{utils.go => utils/main.go} (50%) diff --git a/cmd/server.go b/cmd/server.go index a786540..4ffeb45 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -12,6 +12,7 @@ import ( "github.com/genofire/yaja/database" "github.com/genofire/yaja/model/config" + "github.com/genofire/yaja/server/extension" "github.com/genofire/golang-lib/file" "github.com/genofire/golang-lib/worker" @@ -29,6 +30,7 @@ var ( statesaveWorker *worker.Worker srv *server.Server certs *tls.Config + extensions []extension.Extension ) // serverCmd represents the serve command @@ -81,6 +83,7 @@ var serverCmd = &cobra.Command{ LoggingClient: configData.Logging.LevelClient, RegisterEnable: configData.Register.Enable, RegisterDomains: configData.Register.Domains, + Extensions: extensions, } go statesaveWorker.Start() @@ -161,6 +164,7 @@ func reload() { LoggingClient: configNewData.Logging.LevelClient, RegisterEnable: configNewData.Register.Enable, RegisterDomains: configNewData.Register.Domains, + Extensions: extensions, } log.Warn("reloading need a restart:") go newServer.Start() @@ -176,4 +180,6 @@ func reload() { func init() { RootCmd.AddCommand(serverCmd) serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") + + extensions = append(extensions, &extension.Message{}, &extension.Roster{Database: db}) } diff --git a/model/account.go b/model/account.go index d59a432..46117ce 100644 --- a/model/account.go +++ b/model/account.go @@ -29,9 +29,10 @@ func (d *Domain) UpdateAccount(a *Account) error { } type Account struct { - Local string `json:"-"` - Domain *Domain `json:"-"` - Password string `json:"password"` + Local string `json:"-"` + Domain *Domain `json:"-"` + Password string `json:"password"` + Roster map[string]*Buddy `json:"roster"` } func NewAccount(jid *JID, password string) *Account { diff --git a/model/buddy.go b/model/buddy.go new file mode 100644 index 0000000..e321171 --- /dev/null +++ b/model/buddy.go @@ -0,0 +1,17 @@ +package model + +const ( + SubscriptionNone = iota + SubscriptionTo + SubscriptionFrom + SubscriptionBoth + AskNone = iota + AskSubscribe +) + +type Buddy struct { + Name string `json:"name"` + Groups []string `json:"groups"` + Subscription int `json:"subscription"` + Ask int `json:"ask"` +} diff --git a/server/client.go b/server/client.go deleted file mode 100644 index 717d97b..0000000 --- a/server/client.go +++ /dev/null @@ -1,76 +0,0 @@ -package server - -import ( - "encoding/xml" - "net" - - "github.com/genofire/yaja/model" - log "github.com/sirupsen/logrus" -) - -type Client struct { - log *log.Entry - - Conn net.Conn - out *xml.Encoder - in *xml.Decoder - - Server *Server - jid *model.JID - account *model.Account - - messages chan interface{} - close chan interface{} -} - -func NewClient(conn net.Conn, srv *Server) *Client { - logger := log.New() - logger.SetLevel(srv.LoggingClient) - client := &Client{ - Conn: conn, - Server: srv, - log: log.NewEntry(logger), - in: xml.NewDecoder(conn), - out: xml.NewEncoder(conn), - } - return client -} - -func (client *Client) NewConnecting(conn net.Conn) { - client.Conn = conn - client.in = xml.NewDecoder(conn) - client.out = xml.NewEncoder(conn) -} - -func (client *Client) Read() (*xml.StartElement, error) { - for { - nextToken, err := client.in.Token() - if err != nil { - return nil, err - } - switch nextToken.(type) { - case xml.StartElement: - element := nextToken.(xml.StartElement) - return &element, nil - } - } -} - -func (client *Client) DomainRegisterAllowed() bool { - if client.jid.Domain == "" { - return false - } - - for _, domain := range client.Server.RegisterDomains { - if domain == client.jid.Domain { - - return !client.Server.RegisterEnable - } - } - return client.Server.RegisterEnable -} - -func (client *Client) Close() { - client.close <- true - client.Conn.Close() -} diff --git a/server/extension/main.go b/server/extension/main.go new file mode 100644 index 0000000..d6e42a1 --- /dev/null +++ b/server/extension/main.go @@ -0,0 +1,11 @@ +package extension + +import ( + "encoding/xml" + + "github.com/genofire/yaja/server/utils" +) + +type Extension interface { + Process(*xml.StartElement, *utils.Client) bool +} diff --git a/server/extension/message.go b/server/extension/message.go new file mode 100644 index 0000000..bc8a7ce --- /dev/null +++ b/server/extension/message.go @@ -0,0 +1,15 @@ +package extension + +import ( + "encoding/xml" + + "github.com/genofire/yaja/server/utils" +) + +type Message struct { + Extension +} + +func (m *Message) Process(element *xml.StartElement, client *utils.Client) bool { + return false +} diff --git a/server/extension/roster.go b/server/extension/roster.go new file mode 100644 index 0000000..593f7eb --- /dev/null +++ b/server/extension/roster.go @@ -0,0 +1,27 @@ +package extension + +import ( + "encoding/xml" + + "github.com/genofire/yaja/database" + "github.com/genofire/yaja/messages" + "github.com/genofire/yaja/server/utils" +) + +type Roster struct { + Extension + Database *database.State +} + +func (r *Roster) Process(element *xml.StartElement, client *utils.Client) bool { + var msg messages.IQ + if err := client.In.DecodeElement(&msg, element); err != nil { + client.Log.Warn("is no iq: ", err) + return false + } + if msg.Type != messages.IQTypeGet { + client.Log.Warn("is no get iq") + return false + } + return true +} diff --git a/server/server.go b/server/server.go index 805b33a..5c5d686 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,10 @@ import ( "net" "github.com/genofire/yaja/database" + "github.com/genofire/yaja/model" + "github.com/genofire/yaja/server/extension" + "github.com/genofire/yaja/server/toclient" + "github.com/genofire/yaja/server/utils" log "github.com/sirupsen/logrus" "golang.org/x/crypto/acme/autocert" ) @@ -16,8 +20,9 @@ type Server struct { ServerAddr []string Database *database.State LoggingClient log.Level - RegisterEnable bool `toml:"enable"` - RegisterDomains []string `toml:"domains"` + RegisterEnable bool + RegisterDomains []string + Extensions []extension.Extension } func (srv *Server) Start() { @@ -68,13 +73,13 @@ func (srv *Server) handleServer(conn net.Conn) { func (srv *Server) handleClient(conn net.Conn) { log.Info("new client connection:", conn.RemoteAddr()) - client := NewClient(conn, srv) - state := ConnectionStartup() + client := utils.NewClient(conn, srv.LoggingClient) + state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed) for { state, client = state.Process(client) if state == nil { - client.log.Info("disconnect") + client.Log.Info("disconnect") client.Close() //s.DisconnectBus <- Disconnect{Jid: client.jid} return @@ -83,6 +88,19 @@ func (srv *Server) handleClient(conn net.Conn) { } } +func (srv *Server) DomainRegisterAllowed(jid *model.JID) bool { + if jid.Domain == "" { + return false + } + + for _, domain := range srv.RegisterDomains { + if domain == jid.Domain { + return !srv.RegisterEnable + } + } + return srv.RegisterEnable +} + func (srv *Server) Close() { } diff --git a/server/state.go b/server/state.go deleted file mode 100644 index 22b2f13..0000000 --- a/server/state.go +++ /dev/null @@ -1,6 +0,0 @@ -package server - -// State processes the stream and moves to the next state -type State interface { - Process(client *Client) (State, *Client) -} diff --git a/server/state/connect.go b/server/state/connect.go new file mode 100644 index 0000000..e4529fd --- /dev/null +++ b/server/state/connect.go @@ -0,0 +1,125 @@ +package state + +import ( + "crypto/tls" + "fmt" + + "github.com/genofire/yaja/messages" + "github.com/genofire/yaja/model" + "github.com/genofire/yaja/server/utils" + "golang.org/x/crypto/acme/autocert" +) + +// ConnectionStartup return steps through TCP TLS state +func ConnectionStartup(after State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager) State { + tlsupgrade := &TLSUpgrade{ + Next: after, + tlsconfig: tlsconfig, + tlsmgmt: tlsmgmt, + } + stream := &Start{Next: tlsupgrade} + return stream +} + +// Start state +type Start struct { + Next State +} + +// Process message +func (state *Start) Process(client *utils.Client) (State, *utils.Client) { + client.Log = client.Log.WithField("state", "stream") + client.Log.Debug("running") + defer client.Log.Debug("leave") + + element, err := client.Read() + if err != nil { + client.Log.Warn("unable to read: ", err) + return nil, client + } + if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { + client.Log.Warn("is no stream") + return state, client + } + for _, attr := range element.Attr { + if attr.Name.Local == "to" { + client.JID = &model.JID{Domain: attr.Value} + client.Log = client.Log.WithField("jid", client.JID.Full()) + } + } + if client.JID == nil { + client.Log.Warn("no 'to' domain readed") + return nil, client + } + + fmt.Fprintf(client.Conn, ` + `, + utils.CreateCookie(), messages.NSClient, messages.NSStream) + + fmt.Fprintf(client.Conn, ` + + + + `, + messages.NSStream) + + return state.Next, client +} + +// TLSUpgrade state +type TLSUpgrade struct { + Next State + tlsconfig *tls.Config + tlsmgmt *autocert.Manager +} + +// Process message +func (state *TLSUpgrade) Process(client *utils.Client) (State, *utils.Client) { + client.Log = client.Log.WithField("state", "tls upgrade") + client.Log.Debug("running") + defer client.Log.Debug("leave") + + element, err := client.Read() + if err != nil { + client.Log.Warn("unable to read: ", err) + return nil, client + } + if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" { + client.Log.Warn("is no starttls") + return state, client + } + fmt.Fprintf(client.Conn, "", messages.NSTLS) + // perform the TLS handshake + var tlsConfig *tls.Config + if m := state.tlsmgmt; m != nil { + var cert *tls.Certificate + cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.JID.Domain}) + if err != nil { + client.Log.Warn("no cert in tls manger found: ", err) + return nil, client + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + } + if tlsConfig == nil { + tlsConfig = state.tlsconfig + if tlsConfig != nil { + tlsConfig.ServerName = client.JID.Domain + } else { + client.Log.Warn("no tls config found: ", err) + return nil, client + } + } + + tlsConn := tls.Server(client.Conn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + client.Log.Warn("unable to tls handshake: ", err) + return nil, client + } + // restart the Connection + client.SetConnecting(tlsConn) + + return state.Next, client +} diff --git a/server/state/state.go b/server/state/state.go new file mode 100644 index 0000000..ef7f4f5 --- /dev/null +++ b/server/state/state.go @@ -0,0 +1,8 @@ +package state + +import "github.com/genofire/yaja/server/utils" + +// State processes the stream and moves to the next state +type State interface { + Process(client *utils.Client) (State, *utils.Client) +} diff --git a/server/state_connect.go b/server/state_connect.go deleted file mode 100644 index b54bc4d..0000000 --- a/server/state_connect.go +++ /dev/null @@ -1,353 +0,0 @@ -package server - -import ( - "crypto/tls" - "encoding/base64" - "encoding/xml" - "fmt" - "strings" - - "github.com/genofire/yaja/messages" - "github.com/genofire/yaja/model" -) - -// ConnectionStartup return steps through TCP TLS state -func ConnectionStartup() State { - receiving := &ReceivingClient{} - sending := &SendingClient{Next: receiving} - authedstream := &AuthedStream{Next: sending} - authedstart := &AuthedStart{Next: authedstream} - tlsauth := &SASLAuth{Next: authedstart} - tlsstream := &TLSStream{Next: tlsauth} - tlsupgrade := &TLSUpgrade{Next: tlsstream} - stream := &Start{Next: tlsupgrade} - return stream -} - -// Start state -type Start struct { - Next State -} - -// Process message -func (state *Start) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "stream") - client.log.Debug("running") - defer client.log.Debug("leave") - - element, err := client.Read() - if err != nil { - client.log.Warn("unable to read: ", err) - return nil, client - } - if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { - client.log.Warn("is no stream") - return state, client - } - for _, attr := range element.Attr { - if attr.Name.Local == "to" { - client.jid = &model.JID{Domain: attr.Value} - client.log = client.log.WithField("jid", client.jid.Full()) - } - } - if client.jid == nil { - client.log.Warn("no 'to' domain readed") - return nil, client - } - - fmt.Fprintf(client.Conn, ` - `, - createCookie(), messages.NSClient, messages.NSStream) - - fmt.Fprintf(client.Conn, ` - - - - `, - messages.NSStream) - - return state.Next, client -} - -// TLSUpgrade state -type TLSUpgrade struct { - Next State -} - -// Process message -func (state *TLSUpgrade) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "tls upgrade") - client.log.Debug("running") - defer client.log.Debug("leave") - - element, err := client.Read() - if err != nil { - client.log.Warn("unable to read: ", err) - return nil, client - } - if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" { - client.log.Warn("is no starttls") - return state, client - } - fmt.Fprintf(client.Conn, "", messages.NSTLS) - // perform the TLS handshake - var tlsConfig *tls.Config - if m := client.Server.TLSManager; m != nil { - var cert *tls.Certificate - cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.jid.Domain}) - if err != nil { - client.log.Warn("no cert in tls manger found: ", err) - return nil, client - } - tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{*cert}, - } - } - if tlsConfig == nil { - tlsConfig = client.Server.TLSConfig - if tlsConfig != nil { - tlsConfig.ServerName = client.jid.Domain - } else { - client.log.Warn("no tls config found: ", err) - return nil, client - } - } - - tlsConn := tls.Server(client.Conn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - client.log.Warn("unable to tls handshake: ", err) - return nil, client - } - // restart the Connection - client.NewConnecting(tlsConn) - - return state.Next, client -} - -// TLSStream state -type TLSStream struct { - Next State -} - -// Process messages -func (state *TLSStream) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "tls stream") - client.log.Debug("running") - defer client.log.Debug("leave") - - element, err := client.Read() - if err != nil { - client.log.Warn("unable to read: ", err) - return nil, client - } - if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { - client.log.Warn("is no stream") - return state, client - } - - fmt.Fprintf(client.Conn, ` - `, - createCookie(), messages.NSClient, messages.NSStream) - - if client.DomainRegisterAllowed() { - fmt.Fprintf(client.Conn, ` - - PLAIN - - - `, - messages.NSSASL, messages.NSFeaturesIQRegister) - } else { - fmt.Fprintf(client.Conn, ` - - PLAIN - - `, - messages.NSSASL) - } - - return state.Next, client -} - -// SASLAuth state -type SASLAuth struct { - Next State -} - -// Process messages -func (state *SASLAuth) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "sasl auth") - client.log.Debug("running") - defer client.log.Debug("leave") - - // read the full auth stanza - element, err := client.Read() - if err != nil { - client.log.Warn("unable to read: ", err) - return nil, client - } - var auth messages.SASLAuth - if err = client.in.DecodeElement(&auth, element); err != nil { - client.log.Info("start substate for registration") - return &RegisterFormRequest{ - element: element, - Next: &RegisterRequest{ - Next: state.Next, - }, - }, client - } - data, err := base64.StdEncoding.DecodeString(auth.Body) - if err != nil { - client.log.Warn("body decode: ", err) - return nil, client - } - info := strings.Split(string(data), "\x00") - // should check that info[1] starts with client.jid - client.jid.Local = info[1] - client.log = client.log.WithField("jid", client.jid.Full()) - success, err := client.Server.Database.Authenticate(client.jid, info[2]) - if err != nil { - client.log.Warn("auth: ", err) - return nil, client - } - if success { - client.log.Info("success auth") - fmt.Fprintf(client.Conn, "", messages.NSSASL) - return state.Next, client - } - client.log.Warn("failed auth") - fmt.Fprintf(client.Conn, "", messages.NSSASL) - return nil, client - -} - -// AuthedStart state -type AuthedStart struct { - Next State -} - -// Process messages -func (state *AuthedStart) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "authed started") - client.log.Debug("running") - defer client.log.Debug("leave") - - _, err := client.Read() - if err != nil { - client.log.Warn("unable to read: ", err) - return nil, client - } - fmt.Fprintf(client.Conn, ` - `, - createCookie(), messages.NSClient, messages.NSStream) - - fmt.Fprintf(client.Conn, ` - - `, - messages.NSBind) - - return state.Next, client -} - -// AuthedStream state -type AuthedStream struct { - Next State -} - -// Process messages -func (state *AuthedStream) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "authed stream") - client.log.Debug("running") - defer client.log.Debug("leave") - - // check that it's a bind request - // read bind request - element, err := client.Read() - if err != nil { - client.log.Warn("unable to read: ", err) - return nil, client - } - var msg messages.IQ - if err = client.in.DecodeElement(&msg, element); err != nil { - client.log.Warn("is no iq: ", err) - return nil, client - } - if msg.Type != messages.IQTypeSet { - client.log.Warn("is no set iq") - return nil, client - } - if msg.Error != nil { - client.log.Warn("iq with error: ", msg.Error.Code) - return nil, client - } - type query struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` - Resource string `xml:"resource"` - } - q := &query{} - err = xml.Unmarshal(msg.Body, q) - if err != nil { - client.log.Warn("is no iq bind: ", err) - return nil, client - } - if q.Resource == "" { - client.jid.Resource = makeResource() - } else { - client.jid.Resource = q.Resource - } - client.log = client.log.WithField("jid", client.jid.Full()) - client.out.Encode(&messages.IQ{ - Type: messages.IQTypeResult, - ID: msg.ID, - Body: []byte(fmt.Sprintf( - ` - %s - `, - messages.NSBind, client.jid.Full())), - }) - - return state.Next, client -} - -// SendingClient state -type SendingClient struct { - Next State -} - -// Process messages -func (state *SendingClient) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "normal") - client.log.Debug("sending") - // sending - go func() { - select { - case msg := <-client.messages: - err := client.out.Encode(msg) - client.log.Info(err) - case <-client.close: - return - } - }() - client.log.Debug("receiving") - return state.Next, client -} - -// ReceivingClient state -type ReceivingClient struct { -} - -// Process messages -func (state *ReceivingClient) Process(client *Client) (State, *Client) { - element, err := client.Read() - if err != nil { - client.log.Warn("unable to read: ", err) - return nil, client - } - /* - for _, extension := range client.Server.Extensions { - extension.Process(element, client) - }*/ - client.log.Debug(element) - return state, client -} diff --git a/server/toclient/connect.go b/server/toclient/connect.go new file mode 100644 index 0000000..2be6ec3 --- /dev/null +++ b/server/toclient/connect.go @@ -0,0 +1,272 @@ +package toclient + +import ( + "crypto/tls" + "encoding/base64" + "encoding/xml" + "fmt" + "strings" + + "github.com/genofire/yaja/database" + "github.com/genofire/yaja/messages" + "github.com/genofire/yaja/server/extension" + "github.com/genofire/yaja/server/state" + "github.com/genofire/yaja/server/utils" + "golang.org/x/crypto/acme/autocert" +) + +// ConnectionStartup return steps through TCP TLS state +func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed) state.State { + receiving := &ReceivingClient{} + sending := &SendingClient{Next: receiving} + authedstream := &AuthedStream{Next: sending} + authedstart := &AuthedStart{Next: authedstream} + tlsauth := &SASLAuth{ + Next: authedstart, + database: db, + domainRegisterAllowed: registerAllowed, + } + tlsstream := &TLSStream{ + Next: tlsauth, + domainRegisterAllowed: registerAllowed, + } + return state.ConnectionStartup(tlsstream, tlsconfig, tlsmgmt) +} + +// TLSStream state +type TLSStream struct { + Next state.State + domainRegisterAllowed utils.DomainRegisterAllowed +} + +// Process messages +func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Client) { + client.Log = client.Log.WithField("state", "tls stream") + client.Log.Debug("running") + defer client.Log.Debug("leave") + + element, err := client.Read() + if err != nil { + client.Log.Warn("unable to read: ", err) + return nil, client + } + if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { + client.Log.Warn("is no stream") + return state, client + } + + fmt.Fprintf(client.Conn, ` + `, + utils.CreateCookie(), messages.NSClient, messages.NSStream) + + if state.domainRegisterAllowed(client.JID) { + fmt.Fprintf(client.Conn, ` + + PLAIN + + + `, + messages.NSSASL, messages.NSFeaturesIQRegister) + } else { + fmt.Fprintf(client.Conn, ` + + PLAIN + + `, + messages.NSSASL) + } + + return state.Next, client +} + +// SASLAuth state +type SASLAuth struct { + Next state.State + database *database.State + domainRegisterAllowed utils.DomainRegisterAllowed +} + +// Process messages +func (state *SASLAuth) Process(client *utils.Client) (state.State, *utils.Client) { + client.Log = client.Log.WithField("state", "sasl auth") + client.Log.Debug("running") + defer client.Log.Debug("leave") + + // read the full auth stanza + element, err := client.Read() + if err != nil { + client.Log.Warn("unable to read: ", err) + return nil, client + } + var auth messages.SASLAuth + if err = client.In.DecodeElement(&auth, element); err != nil { + client.Log.Info("start substate for registration") + return &RegisterFormRequest{ + element: element, + domainRegisterAllowed: state.domainRegisterAllowed, + Next: &RegisterRequest{ + domainRegisterAllowed: state.domainRegisterAllowed, + database: state.database, + Next: state.Next, + }, + }, client + } + data, err := base64.StdEncoding.DecodeString(auth.Body) + if err != nil { + client.Log.Warn("body decode: ", err) + return nil, client + } + info := strings.Split(string(data), "\x00") + // should check that info[1] starts with client.JID + client.JID.Local = info[1] + client.Log = client.Log.WithField("jid", client.JID.Full()) + success, err := state.database.Authenticate(client.JID, info[2]) + if err != nil { + client.Log.Warn("auth: ", err) + return nil, client + } + if success { + client.Log.Info("success auth") + fmt.Fprintf(client.Conn, "", messages.NSSASL) + return state.Next, client + } + client.Log.Warn("failed auth") + fmt.Fprintf(client.Conn, "", messages.NSSASL) + return nil, client + +} + +// AuthedStart state +type AuthedStart struct { + Next state.State +} + +// Process messages +func (state *AuthedStart) Process(client *utils.Client) (state.State, *utils.Client) { + client.Log = client.Log.WithField("state", "authed started") + client.Log.Debug("running") + defer client.Log.Debug("leave") + + _, err := client.Read() + if err != nil { + client.Log.Warn("unable to read: ", err) + return nil, client + } + fmt.Fprintf(client.Conn, ` + `, + utils.CreateCookie(), messages.NSClient, messages.NSStream) + + fmt.Fprintf(client.Conn, ` + + `, + messages.NSBind) + + return state.Next, client +} + +// AuthedStream state +type AuthedStream struct { + Next state.State +} + +// Process messages +func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Client) { + client.Log = client.Log.WithField("state", "authed stream") + client.Log.Debug("running") + defer client.Log.Debug("leave") + + // check that it's a bind request + // read bind request + element, err := client.Read() + if err != nil { + client.Log.Warn("unable to read: ", err) + return nil, client + } + var msg messages.IQ + if err = client.In.DecodeElement(&msg, element); err != nil { + client.Log.Warn("is no iq: ", err) + return nil, client + } + if msg.Type != messages.IQTypeSet { + client.Log.Warn("is no set iq") + return nil, client + } + if msg.Error != nil { + client.Log.Warn("iq with error: ", msg.Error.Code) + return nil, client + } + type query struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` + Resource string `xml:"resource"` + } + q := &query{} + err = xml.Unmarshal(msg.Body, q) + if err != nil { + client.Log.Warn("is no iq bind: ", err) + return nil, client + } + if q.Resource == "" { + client.JID.Resource = makeResource() + } else { + client.JID.Resource = q.Resource + } + client.Log = client.Log.WithField("jid", client.JID.Full()) + client.Out.Encode(&messages.IQ{ + Type: messages.IQTypeResult, + ID: msg.ID, + Body: []byte(fmt.Sprintf( + ` + %s + `, + messages.NSBind, client.JID.Full())), + }) + + return state.Next, client +} + +// SendingClient state +type SendingClient struct { + Next state.State +} + +// Process messages +func (state *SendingClient) Process(client *utils.Client) (state.State, *utils.Client) { + client.Log = client.Log.WithField("state", "normal") + client.Log.Debug("sending") + // sending + go func() { + select { + case msg := <-client.Messages: + err := client.Out.Encode(msg) + client.Log.Info(err) + case <-client.OnClose(): + return + } + }() + client.Log.Debug("receiving") + return state.Next, client +} + +// ReceivingClient state +type ReceivingClient struct { + Extensions []extension.Extension +} + +// Process messages +func (state *ReceivingClient) Process(client *utils.Client) (state.State, *utils.Client) { + element, err := client.Read() + if err != nil { + client.Log.Warn("unable to read: ", err) + return nil, client + } + count := 0 + for _, extension := range state.Extensions { + if extension.Process(element, client) { + count++ + } + } + if count != 1 { + client.Log.WithField("extension", count).Debug(element) + } + return state, client +} diff --git a/server/state_register.go b/server/toclient/register.go similarity index 50% rename from server/state_register.go rename to server/toclient/register.go index 1d71835..4f381ed 100644 --- a/server/state_register.go +++ b/server/toclient/register.go @@ -1,40 +1,44 @@ -package server +package toclient import ( "encoding/xml" "fmt" + "github.com/genofire/yaja/database" "github.com/genofire/yaja/messages" "github.com/genofire/yaja/model" + "github.com/genofire/yaja/server/state" + "github.com/genofire/yaja/server/utils" ) type RegisterFormRequest struct { - Next State - element *xml.StartElement + Next state.State + domainRegisterAllowed utils.DomainRegisterAllowed + element *xml.StartElement } // Process message -func (state *RegisterFormRequest) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "register form request") - client.log.Debug("running") - defer client.log.Debug("leave") +func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *utils.Client) { + client.Log = client.Log.WithField("state", "register form request") + client.Log.Debug("running") + defer client.Log.Debug("leave") - if !client.DomainRegisterAllowed() { - client.log.Error("unpossible to reach this state, register on this domain is not allowed") + if !state.domainRegisterAllowed(client.JID) { + client.Log.Error("unpossible to reach this state, register on this domain is not allowed") return nil, client } var msg messages.IQ - if err := client.in.DecodeElement(&msg, state.element); err != nil { - client.log.Warn("is no iq: ", err) + if err := client.In.DecodeElement(&msg, state.element); err != nil { + client.Log.Warn("is no iq: ", err) return state, client } if msg.Type != messages.IQTypeGet { - client.log.Warn("is no get iq") + client.Log.Warn("is no get iq") return state, client } if msg.Error != nil { - client.log.Warn("iq with error: ", msg.Error.Code) + client.Log.Warn("iq with error: ", msg.Error.Code) return state, client } type query struct { @@ -44,10 +48,10 @@ func (state *RegisterFormRequest) Process(client *Client) (State, *Client) { err := xml.Unmarshal(msg.Body, q) if q.XMLName.Space != messages.NSIQRegister || err != nil { - client.log.Warn("is no iq register: ", err) + client.Log.Warn("is no iq register: ", err) return nil, client } - client.out.Encode(&messages.IQ{ + client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, ID: msg.ID, Body: []byte(fmt.Sprintf(` @@ -61,36 +65,38 @@ func (state *RegisterFormRequest) Process(client *Client) (State, *Client) { } type RegisterRequest struct { - Next State + Next state.State + database *database.State + domainRegisterAllowed utils.DomainRegisterAllowed } // Process message -func (state *RegisterRequest) Process(client *Client) (State, *Client) { - client.log = client.log.WithField("state", "register request") - client.log.Debug("running") - defer client.log.Debug("leave") +func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils.Client) { + client.Log = client.Log.WithField("state", "register request") + client.Log.Debug("running") + defer client.Log.Debug("leave") - if !client.DomainRegisterAllowed() { - client.log.Error("unpossible to reach this state, register on this domain is not allowed") + if !state.domainRegisterAllowed(client.JID) { + client.Log.Error("unpossible to reach this state, register on this domain is not allowed") return nil, client } element, err := client.Read() if err != nil { - client.log.Warn("unable to read: ", err) + client.Log.Warn("unable to read: ", err) return nil, client } var msg messages.IQ - if err = client.in.DecodeElement(&msg, element); err != nil { - client.log.Warn("is no iq: ", err) + if err = client.In.DecodeElement(&msg, element); err != nil { + client.Log.Warn("is no iq: ", err) return state, client } if msg.Type != messages.IQTypeGet { - client.log.Warn("is no get iq") + client.Log.Warn("is no get iq") return state, client } if msg.Error != nil { - client.log.Warn("iq with error: ", msg.Error.Code) + client.Log.Warn("iq with error: ", msg.Error.Code) return state, client } type query struct { @@ -101,16 +107,16 @@ func (state *RegisterRequest) Process(client *Client) (State, *Client) { q := &query{} err = xml.Unmarshal(msg.Body, q) if err != nil { - client.log.Warn("is no iq register: ", err) + client.Log.Warn("is no iq register: ", err) return nil, client } - client.jid.Local = q.Username - client.log = client.log.WithField("jid", client.jid.Full()) - account := model.NewAccount(client.jid, q.Password) - err = client.Server.Database.AddAccount(account) + client.JID.Local = q.Username + client.Log = client.Log.WithField("jid", client.JID.Full()) + account := model.NewAccount(client.JID, q.Password) + err = state.database.AddAccount(account) if err != nil { - client.out.Encode(&messages.IQ{ + client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, ID: msg.ID, Body: []byte(fmt.Sprintf(` @@ -126,15 +132,14 @@ func (state *RegisterRequest) Process(client *Client) (State, *Client) { }, }, }) - client.log.Warn("database error: ", err) + client.Log.Warn("database error: ", err) return state, client } - client.account = account - client.out.Encode(&messages.IQ{ + client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, ID: msg.ID, }) - client.log.Infof("registered client %s", client.jid.Bare()) + client.Log.Infof("registered client %s", client.JID.Bare()) return state.Next, client } diff --git a/server/toclient/utils.go b/server/toclient/utils.go new file mode 100644 index 0000000..aec4fa4 --- /dev/null +++ b/server/toclient/utils.go @@ -0,0 +1,14 @@ +package toclient + +import ( + "crypto/rand" + "fmt" +) + +func makeResource() string { + var buf [16]byte + if _, err := rand.Reader.Read(buf[:]); err != nil { + panic("Failed to read random bytes: " + err.Error()) + } + return fmt.Sprintf("%x", buf) +} diff --git a/server/utils/client.go b/server/utils/client.go new file mode 100644 index 0000000..d175309 --- /dev/null +++ b/server/utils/client.go @@ -0,0 +1,65 @@ +package utils + +import ( + "encoding/xml" + "net" + + "github.com/genofire/yaja/model" + log "github.com/sirupsen/logrus" +) + +type Client struct { + Log *log.Entry + + Conn net.Conn + Out *xml.Encoder + In *xml.Decoder + + JID *model.JID + account *model.Account + + Messages chan interface{} + close chan interface{} +} + +func NewClient(conn net.Conn, level log.Level) *Client { + logger := log.New() + logger.SetLevel(level) + client := &Client{ + Conn: conn, + Log: log.NewEntry(logger), + In: xml.NewDecoder(conn), + Out: xml.NewEncoder(conn), + close: make(chan interface{}), + } + return client +} + +func (client *Client) SetConnecting(conn net.Conn) { + client.Conn = conn + client.In = xml.NewDecoder(conn) + client.Out = xml.NewEncoder(conn) +} + +func (client *Client) Read() (*xml.StartElement, error) { + for { + nextToken, err := client.In.Token() + if err != nil { + return nil, err + } + switch nextToken.(type) { + case xml.StartElement: + element := nextToken.(xml.StartElement) + return &element, nil + } + } +} + +func (client *Client) OnClose() <-chan interface{} { + return client.close +} + +func (client *Client) Close() { + client.close <- true + client.Conn.Close() +} diff --git a/server/utils.go b/server/utils/main.go similarity index 50% rename from server/utils.go rename to server/utils/main.go index 561161e..b5959f1 100644 --- a/server/utils.go +++ b/server/utils/main.go @@ -1,29 +1,25 @@ -package server +package utils import ( "crypto/rand" "encoding/binary" "fmt" + + "github.com/genofire/yaja/model" ) // Cookie is used to give a unique identifier to each request. type Cookie uint64 -func createCookie() Cookie { +func CreateCookie() Cookie { var buf [8]byte if _, err := rand.Reader.Read(buf[:]); err != nil { panic("Failed to read random bytes: " + err.Error()) } return Cookie(binary.LittleEndian.Uint64(buf[:])) } -func createCookieString() string { - return fmt.Sprintf("%x", createCookie()) +func CreateCookieString() string { + return fmt.Sprintf("%x", CreateCookie()) } -func makeResource() string { - var buf [16]byte - if _, err := rand.Reader.Read(buf[:]); err != nil { - panic("Failed to read random bytes: " + err.Error()) - } - return fmt.Sprintf("%x", buf) -} +type DomainRegisterAllowed func(*model.JID) bool