From e474f460aa2bc4dbaf472803d9b1c2650d13f92c Mon Sep 17 00:00:00 2001 From: Martin Geno Date: Sun, 17 Dec 2017 17:50:51 +0100 Subject: [PATCH] move client to state attr + s2s idea --- cmd/server.go | 70 ++++---- config_example.conf | 3 +- model/config/struct.go | 1 + server/extension/iq_disco.go | 3 +- server/extension/iq_discovery.go | 3 +- server/extension/iq_last.go | 3 +- server/extension/iq_ping.go | 3 +- server/extension/iq_private.go | 3 +- server/extension/iq_private_bookmarks.go | 3 +- server/extension/iq_private_metacontacts.go | 3 +- server/extension/iq_private_roster.go | 3 +- server/extension/iq_roster.go | 3 +- server/server.go | 45 +++-- server/state/connect.go | 101 +++++------ server/state/normal.go | 49 ++++++ server/state/state.go | 24 ++- server/toclient/connect.go | 186 +++++++++++--------- server/toclient/normal.go | 48 ----- server/toclient/register.go | 106 +++++------ server/toserver/connect.go | 141 +++++++++++++++ server/utils/client.go | 2 +- 21 files changed, 498 insertions(+), 305 deletions(-) create mode 100644 server/state/normal.go delete mode 100644 server/toclient/normal.go create mode 100644 server/toserver/connect.go diff --git a/cmd/server.go b/cmd/server.go index 43dccb0..8bd56e4 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -25,12 +25,13 @@ import ( var configPath string var ( - configData = &config.Config{} - db = &database.State{} - statesaveWorker *worker.Worker - srv *server.Server - certs *tls.Config - extensions extension.Extensions + configData = &config.Config{} + db = &database.State{} + statesaveWorker *worker.Worker + srv *server.Server + certs *tls.Config + extensionsClient extension.Extensions + extensionsServer extension.Extensions ) // serverCmd represents the serve command @@ -39,16 +40,14 @@ var serverCmd = &cobra.Command{ Short: "Runs the yaja server", Example: "yaja serve -c /etc/yaja.conf", Run: func(cmd *cobra.Command, args []string) { - var err error - err = file.ReadTOML(configPath, configData) - if err != nil { + + if err := file.ReadTOML(configPath, configData); err != nil { log.Fatal("unable to load config file:", err) } log.SetLevel(configData.Logging.Level) - err = file.ReadJSON(configData.StatePath, db) - if err != nil { + if err := file.ReadJSON(configData.StatePath, db); err != nil { log.Warn("unable to load state file:", err) } @@ -76,14 +75,16 @@ var serverCmd = &cobra.Command{ } srv = &server.Server{ - TLSManager: &m, - Database: db, - ClientAddr: configData.Address.Client, - ServerAddr: configData.Address.Server, - LoggingClient: configData.Logging.LevelClient, - RegisterEnable: configData.Register.Enable, - RegisterDomains: configData.Register.Domains, - Extensions: extensions, + TLSManager: &m, + Database: db, + ClientAddr: configData.Address.Client, + ServerAddr: configData.Address.Server, + LoggingClient: configData.Logging.LevelClient, + LoggingServer: configData.Logging.LevelServer, + RegisterEnable: configData.Register.Enable, + RegisterDomains: configData.Register.Domains, + ExtensionsServer: extensionsServer, + ExtensionsClient: extensionsClient, } go statesaveWorker.Start() @@ -122,13 +123,14 @@ func quit() { func reload() { log.Info("start reloading...") var configNewData *config.Config - err := file.ReadTOML(configPath, configNewData) - if err != nil { + + if err := file.ReadTOML(configPath, configNewData); err != nil { log.Warn("unable to load config file:", err) return } log.SetLevel(configNewData.Logging.Level) srv.LoggingClient = configNewData.Logging.LevelClient + srv.LoggingServer = configNewData.Logging.LevelServer srv.RegisterEnable = configNewData.Register.Enable srv.RegisterDomains = configNewData.Register.Domains @@ -157,14 +159,15 @@ func reload() { } if restartServer { newServer := &server.Server{ - TLSConfig: certs, - Database: db, - ClientAddr: configNewData.Address.Client, - ServerAddr: configNewData.Address.Server, - LoggingClient: configNewData.Logging.LevelClient, - RegisterEnable: configNewData.Register.Enable, - RegisterDomains: configNewData.Register.Domains, - Extensions: extensions, + TLSConfig: certs, + Database: db, + ClientAddr: configNewData.Address.Client, + ServerAddr: configNewData.Address.Server, + LoggingClient: configNewData.Logging.LevelClient, + RegisterEnable: configNewData.Register.Enable, + RegisterDomains: configNewData.Register.Domains, + ExtensionsServer: extensionsServer, + ExtensionsClient: extensionsClient, } log.Warn("reloading need a restart:") go newServer.Start() @@ -178,7 +181,7 @@ func reload() { } func init() { - extensions = append(extensions, + extensionsClient = append(extensionsClient, &extension.Message{}, &extension.Presence{}, extension.IQExtensions{ @@ -188,10 +191,15 @@ func init() { &extension.IQDisco{Database: db}, &extension.IQRoster{Database: db}, &extension.IQExtensionDiscovery{GetSpaces: func() []string { - return extensions.Spaces() + return extensionsClient.Spaces() }}, }) + extensionsServer = append(extensionsServer, + extension.IQExtensions{ + &extension.IQPing{}, + }) + RootCmd.AddCommand(serverCmd) serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") diff --git a/config_example.conf b/config_example.conf index 7ea1df2..f6a5940 100644 --- a/config_example.conf +++ b/config_example.conf @@ -2,8 +2,9 @@ tlsdir = "tmp/ssl" state_path = "tmp/yaja.json" [logging] -level = 3 +level = 5 level_client = 6 +level_server = 6 [register] enable = true diff --git a/model/config/struct.go b/model/config/struct.go index 98aa393..872ccef 100644 --- a/model/config/struct.go +++ b/model/config/struct.go @@ -10,6 +10,7 @@ type Config struct { Logging struct { Level log.Level `toml:"level"` LevelClient log.Level `toml:"level_client"` + LevelServer log.Level `toml:"level_server"` } `toml:"logging"` Register struct { Enable bool `toml:"enable"` diff --git a/server/extension/iq_disco.go b/server/extension/iq_disco.go index 1c65d41..f5b44fd 100644 --- a/server/extension/iq_disco.go +++ b/server/extension/iq_disco.go @@ -24,8 +24,7 @@ func (ex *IQDisco) Get(msg *messages.IQ, client *utils.Client) bool { Body []byte `xml:",innerxml"` } q := &query{} - err := xml.Unmarshal(msg.Body, q) - if err != nil { + if err := xml.Unmarshal(msg.Body, q); err != nil { return false } diff --git a/server/extension/iq_discovery.go b/server/extension/iq_discovery.go index 1ad4a77..4233295 100644 --- a/server/extension/iq_discovery.go +++ b/server/extension/iq_discovery.go @@ -23,8 +23,7 @@ func (ex *IQExtensionDiscovery) Get(msg *messages.IQ, client *utils.Client) bool Body []byte `xml:",innerxml"` } q := &query{} - err := xml.Unmarshal(msg.Body, q) - if err != nil { + if err := xml.Unmarshal(msg.Body, q); err != nil { return false } diff --git a/server/extension/iq_last.go b/server/extension/iq_last.go index 6e575ef..259fd43 100644 --- a/server/extension/iq_last.go +++ b/server/extension/iq_last.go @@ -25,8 +25,7 @@ func (ex *IQLast) Get(msg *messages.IQ, client *utils.Client) bool { Body []byte `xml:",innerxml"` } q := &query{} - err := xml.Unmarshal(msg.Body, q) - if err != nil { + if err := xml.Unmarshal(msg.Body, q); err != nil { return false } diff --git a/server/extension/iq_ping.go b/server/extension/iq_ping.go index eddcf27..62be193 100644 --- a/server/extension/iq_ping.go +++ b/server/extension/iq_ping.go @@ -21,8 +21,7 @@ func (ex *IQPing) Get(msg *messages.IQ, client *utils.Client) bool { XMLName xml.Name `xml:"urn:xmpp:ping ping"` } pq := &ping{} - err := xml.Unmarshal(msg.Body, pq) - if err != nil { + if err := xml.Unmarshal(msg.Body, pq); err != nil { return false } diff --git a/server/extension/iq_private.go b/server/extension/iq_private.go index 81d8d8d..b969662 100644 --- a/server/extension/iq_private.go +++ b/server/extension/iq_private.go @@ -27,8 +27,7 @@ func (ex *IQPrivate) Get(msg *messages.IQ, client *utils.Client) bool { // query encode q := &iqPrivateQuery{} - err := xml.Unmarshal(msg.Body, q) - if err != nil { + if err := xml.Unmarshal(msg.Body, q); err != nil { return false } diff --git a/server/extension/iq_private_bookmarks.go b/server/extension/iq_private_bookmarks.go index a9efe4c..e6f9b86 100644 --- a/server/extension/iq_private_bookmarks.go +++ b/server/extension/iq_private_bookmarks.go @@ -19,8 +19,7 @@ func (ex *IQPrivateBookmark) Handle(msg *messages.IQ, q *iqPrivateQuery, client XMLName xml.Name `xml:"storage:bookmarks storage"` } s := &storage{} - err := xml.Unmarshal(q.Body, s) - if err != nil { + if err := xml.Unmarshal(q.Body, s); err != nil { return false } /* diff --git a/server/extension/iq_private_metacontacts.go b/server/extension/iq_private_metacontacts.go index 6e897ea..c525a62 100644 --- a/server/extension/iq_private_metacontacts.go +++ b/server/extension/iq_private_metacontacts.go @@ -19,8 +19,7 @@ func (ex *IQPrivateMetacontact) Handle(msg *messages.IQ, q *iqPrivateQuery, clie XMLName xml.Name `xml:"storage:metacontacts storage"` } s := &storage{} - err := xml.Unmarshal(q.Body, s) - if err != nil { + if err := xml.Unmarshal(q.Body, s); err != nil { return false } /* diff --git a/server/extension/iq_private_roster.go b/server/extension/iq_private_roster.go index 090dbc7..554e0b9 100644 --- a/server/extension/iq_private_roster.go +++ b/server/extension/iq_private_roster.go @@ -20,8 +20,7 @@ func (ex *IQPrivateRoster) Handle(msg *messages.IQ, q *iqPrivateQuery, client *u Body []byte `xml:",innerxml"` } r := &roster{} - err := xml.Unmarshal(q.Body, r) - if err != nil { + if err := xml.Unmarshal(q.Body, r); err != nil { return false } diff --git a/server/extension/iq_roster.go b/server/extension/iq_roster.go index 018d109..899ab4f 100644 --- a/server/extension/iq_roster.go +++ b/server/extension/iq_roster.go @@ -25,8 +25,7 @@ func (ex *IQRoster) Get(msg *messages.IQ, client *utils.Client) bool { Body []byte `xml:",innerxml"` } q := &query{} - err := xml.Unmarshal(msg.Body, q) - if err != nil { + if err := xml.Unmarshal(msg.Body, q); err != nil { return false } diff --git a/server/server.go b/server/server.go index 0580770..b8d77a8 100644 --- a/server/server.go +++ b/server/server.go @@ -8,21 +8,24 @@ import ( "github.com/genofire/yaja/model" "github.com/genofire/yaja/server/extension" "github.com/genofire/yaja/server/toclient" + "github.com/genofire/yaja/server/toserver" "github.com/genofire/yaja/server/utils" log "github.com/sirupsen/logrus" "golang.org/x/crypto/acme/autocert" ) type Server struct { - TLSConfig *tls.Config - TLSManager *autocert.Manager - ClientAddr []string - ServerAddr []string - Database *database.State - LoggingClient log.Level - RegisterEnable bool - RegisterDomains []string - Extensions extension.Extensions + TLSConfig *tls.Config + TLSManager *autocert.Manager + ClientAddr []string + ServerAddr []string + Database *database.State + LoggingClient log.Level + LoggingServer log.Level + RegisterEnable bool + RegisterDomains []string + ExtensionsClient extension.Extensions + ExtensionsServer extension.Extensions } func (srv *Server) Start() { @@ -69,15 +72,33 @@ func (srv *Server) listenClient(c2s net.Listener) { func (srv *Server) handleServer(conn net.Conn) { log.Info("new server connection:", conn.RemoteAddr()) + + client := utils.NewClient(conn, srv.LoggingClient) + client.Log = client.Log.WithField("c", "s2s") + + state := toserver.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.ExtensionsServer, client) + + for { + state = state.Process() + if state == nil { + client.Log.Info("disconnect") + client.Close() + return + } + // run next state + } } func (srv *Server) handleClient(conn net.Conn) { log.Info("new client connection:", conn.RemoteAddr()) - client := utils.NewClient(conn, srv.LoggingClient) - state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed, srv.Extensions) + + client := utils.NewClient(conn, srv.LoggingServer) + client.Log = client.Log.WithField("c", "c2s") + + state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed, srv.ExtensionsClient, client) for { - state, client = state.Process(client) + state = state.Process() if state == nil { client.Log.Info("disconnect") client.Close() diff --git a/server/state/connect.go b/server/state/connect.go index e4529fd..2f4baae 100644 --- a/server/state/connect.go +++ b/server/state/connect.go @@ -10,116 +10,107 @@ import ( "golang.org/x/crypto/acme/autocert" ) -// ConnectionStartup return steps through TCP TLS state -func ConnectionStartup(after State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager) State { - tlsupgrade := &TLSUpgrade{ - Next: after, - tlsconfig: tlsconfig, - tlsmgmt: tlsmgmt, - } - stream := &Start{Next: tlsupgrade} - return stream -} - // Start state type Start struct { - Next State + Next State + Client *utils.Client } // Process message -func (state *Start) Process(client *utils.Client) (State, *utils.Client) { - client.Log = client.Log.WithField("state", "stream") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *Start) Process() State { + state.Client.Log = state.Client.Log.WithField("state", "stream") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") - element, err := client.Read() + element, err := state.Client.Read() if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client + state.Client.Log.Warn("unable to read: ", err) + return nil } if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { - client.Log.Warn("is no stream") - return state, client + state.Client.Log.Warn("is no stream") + return state } for _, attr := range element.Attr { if attr.Name.Local == "to" { - client.JID = &model.JID{Domain: attr.Value} - client.Log = client.Log.WithField("jid", client.JID.Full()) + state.Client.JID = &model.JID{Domain: attr.Value} + state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full()) } } - if client.JID == nil { - client.Log.Warn("no 'to' domain readed") - return nil, client + if state.Client.JID == nil { + state.Client.Log.Warn("no 'to' domain readed") + return nil } - fmt.Fprintf(client.Conn, ` + fmt.Fprintf(state.Client.Conn, ` `, utils.CreateCookie(), messages.NSClient, messages.NSStream) - fmt.Fprintf(client.Conn, ` + fmt.Fprintf(state.Client.Conn, ` `, messages.NSStream) - return state.Next, client + return state.Next } // TLSUpgrade state type TLSUpgrade struct { - Next State - tlsconfig *tls.Config - tlsmgmt *autocert.Manager + Next State + Client *utils.Client + TLSConfig *tls.Config + TLSManager *autocert.Manager } // Process message -func (state *TLSUpgrade) Process(client *utils.Client) (State, *utils.Client) { - client.Log = client.Log.WithField("state", "tls upgrade") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *TLSUpgrade) Process() State { + state.Client.Log = state.Client.Log.WithField("state", "tls upgrade") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") - element, err := client.Read() + element, err := state.Client.Read() if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client + state.Client.Log.Warn("unable to read: ", err) + return nil } if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" { - client.Log.Warn("is no starttls") - return state, client + state.Client.Log.Warn("is no starttls", element) + return nil } - fmt.Fprintf(client.Conn, "", messages.NSTLS) + fmt.Fprintf(state.Client.Conn, "", messages.NSTLS) // perform the TLS handshake var tlsConfig *tls.Config - if m := state.tlsmgmt; m != nil { + if m := state.TLSManager; m != nil { var cert *tls.Certificate - cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.JID.Domain}) + cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: state.Client.JID.Domain}) if err != nil { - client.Log.Warn("no cert in tls manger found: ", err) - return nil, client + state.Client.Log.Warn("no cert in tls manger found: ", err) + return nil } tlsConfig = &tls.Config{ Certificates: []tls.Certificate{*cert}, } } if tlsConfig == nil { - tlsConfig = state.tlsconfig + tlsConfig = state.TLSConfig if tlsConfig != nil { - tlsConfig.ServerName = client.JID.Domain + tlsConfig.ServerName = state.Client.JID.Domain } else { - client.Log.Warn("no tls config found: ", err) - return nil, client + state.Client.Log.Warn("no tls config found: ", err) + return nil } } - tlsConn := tls.Server(client.Conn, tlsConfig) + tlsConn := tls.Server(state.Client.Conn, tlsConfig) err = tlsConn.Handshake() if err != nil { - client.Log.Warn("unable to tls handshake: ", err) - return nil, client + state.Client.Log.Warn("unable to tls handshake: ", err) + return nil } // restart the Connection - client.SetConnecting(tlsConn) + state.Client.SetConnecting(tlsConn) - return state.Next, client + return state.Next } diff --git a/server/state/normal.go b/server/state/normal.go new file mode 100644 index 0000000..a2e0245 --- /dev/null +++ b/server/state/normal.go @@ -0,0 +1,49 @@ +package state + +import ( + "github.com/genofire/yaja/server/extension" + "github.com/genofire/yaja/server/utils" +) + +// SendingClient state +type SendingClient struct { + Next State + Client *utils.Client +} + +// Process messages +func (state *SendingClient) Process() State { + state.Client.Log = state.Client.Log.WithField("state", "normal") + state.Client.Log.Debug("sending") + // sending + go func() { + select { + case msg := <-state.Client.Messages: + err := state.Client.Out.Encode(msg) + if err != nil { + state.Client.Log.Warn(err) + } + case <-state.Client.OnClose(): + return + } + }() + state.Client.Log.Debug("receiving") + return state.Next +} + +// ReceivingClient state +type ReceivingClient struct { + Extensions extension.Extensions + Client *utils.Client +} + +// Process messages +func (state *ReceivingClient) Process() State { + element, err := state.Client.Read() + if err != nil { + state.Client.Log.Warn("unable to read: ", err) + return nil + } + state.Extensions.Process(element, state.Client) + return state +} diff --git a/server/state/state.go b/server/state/state.go index ef7f4f5..5ab35a7 100644 --- a/server/state/state.go +++ b/server/state/state.go @@ -4,5 +4,27 @@ import "github.com/genofire/yaja/server/utils" // State processes the stream and moves to the next state type State interface { - Process(client *utils.Client) (State, *utils.Client) + Process() State +} + +// Start state +type Debug struct { + Next State + Client *utils.Client +} + +// Process message +func (state *Debug) Process() State { + state.Client.Log = state.Client.Log.WithField("state", "debug") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") + + element, err := state.Client.Read() + if err != nil { + state.Client.Log.Warn("unable to read: ", err) + return nil + } + state.Client.Log.Info(element) + + return state.Next } diff --git a/server/toclient/connect.go b/server/toclient/connect.go index f8deaa4..8c4a983 100644 --- a/server/toclient/connect.go +++ b/server/toclient/connect.go @@ -16,51 +16,60 @@ import ( ) // ConnectionStartup return steps through TCP TLS state -func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions []extension.Extension) state.State { - receiving := &ReceivingClient{Extensions: extensions} - sending := &SendingClient{Next: receiving} - authedstream := &AuthedStream{Next: sending} - authedstart := &AuthedStart{Next: authedstream} +func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions extension.Extensions, c *utils.Client) state.State { + receiving := &state.ReceivingClient{Extensions: extensions, Client: c} + sending := &state.SendingClient{Next: receiving, Client: c} + authedstream := &AuthedStream{Next: sending, Client: c} + authedstart := &AuthedStart{Next: authedstream, Client: c} tlsauth := &SASLAuth{ Next: authedstart, + Client: c, database: db, domainRegisterAllowed: registerAllowed, } tlsstream := &TLSStream{ - Next: tlsauth, + Next: tlsauth, + Client: c, domainRegisterAllowed: registerAllowed, } - return state.ConnectionStartup(tlsstream, tlsconfig, tlsmgmt) + tlsupgrade := &state.TLSUpgrade{ + Next: tlsstream, + Client: c, + TLSConfig: tlsconfig, + TLSManager: tlsmgmt, + } + return &state.Start{Next: tlsupgrade, Client: c} } // TLSStream state type TLSStream struct { Next state.State + Client *utils.Client domainRegisterAllowed utils.DomainRegisterAllowed } // Process messages -func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Client) { - client.Log = client.Log.WithField("state", "tls stream") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *TLSStream) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "tls stream") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") - element, err := client.Read() + element, err := state.Client.Read() if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client + state.Client.Log.Warn("unable to read: ", err) + return nil } if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { - client.Log.Warn("is no stream") - return state, client + state.Client.Log.Warn("is no stream") + return state } - fmt.Fprintf(client.Conn, ` + fmt.Fprintf(state.Client.Conn, ` `, utils.CreateCookie(), messages.NSClient, messages.NSStream) - if state.domainRegisterAllowed(client.JID) { - fmt.Fprintf(client.Conn, ` + if state.domainRegisterAllowed(state.Client.JID) { + fmt.Fprintf(state.Client.Conn, ` PLAIN @@ -68,7 +77,7 @@ func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Clien `, messages.NSSASL, messages.NSFeaturesIQRegister) } else { - fmt.Fprintf(client.Conn, ` + fmt.Fprintf(state.Client.Conn, ` PLAIN @@ -76,124 +85,129 @@ func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Clien messages.NSSASL) } - return state.Next, client + return state.Next } // SASLAuth state type SASLAuth struct { Next state.State + Client *utils.Client database *database.State domainRegisterAllowed utils.DomainRegisterAllowed } // Process messages -func (state *SASLAuth) Process(client *utils.Client) (state.State, *utils.Client) { - client.Log = client.Log.WithField("state", "sasl auth") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *SASLAuth) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "sasl auth") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") // read the full auth stanza - element, err := client.Read() + element, err := state.Client.Read() if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client + state.Client.Log.Warn("unable to read: ", err) + return nil } var auth messages.SASLAuth - if err = client.In.DecodeElement(&auth, element); err != nil { - client.Log.Info("start substate for registration") + if err = state.Client.In.DecodeElement(&auth, element); err != nil { + state.Client.Log.Info("start substate for registration") return &RegisterFormRequest{ + Next: &RegisterRequest{ + Next: state.Next, + Client: state.Client, + database: state.database, + domainRegisterAllowed: state.domainRegisterAllowed, + }, + Client: state.Client, element: element, domainRegisterAllowed: state.domainRegisterAllowed, - Next: &RegisterRequest{ - domainRegisterAllowed: state.domainRegisterAllowed, - database: state.database, - Next: state.Next, - }, - }, client + } } data, err := base64.StdEncoding.DecodeString(auth.Body) if err != nil { - client.Log.Warn("body decode: ", err) - return nil, client + state.Client.Log.Warn("body decode: ", err) + return nil } info := strings.Split(string(data), "\x00") - // should check that info[1] starts with client.JID - client.JID.Local = info[1] - client.Log = client.Log.WithField("jid", client.JID.Full()) - success, err := state.database.Authenticate(client.JID, info[2]) + // should check that info[1] starts with state.Client.JID + state.Client.JID.Local = info[1] + state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full()) + success, err := state.database.Authenticate(state.Client.JID, info[2]) if err != nil { - client.Log.Warn("auth: ", err) - return nil, client + state.Client.Log.Warn("auth: ", err) + return nil } if success { - client.Log.Info("success auth") - fmt.Fprintf(client.Conn, "", messages.NSSASL) - return state.Next, client + state.Client.Log.Info("success auth") + fmt.Fprintf(state.Client.Conn, "", messages.NSSASL) + return state.Next } - client.Log.Warn("failed auth") - fmt.Fprintf(client.Conn, "", messages.NSSASL) - return nil, client + state.Client.Log.Warn("failed auth") + fmt.Fprintf(state.Client.Conn, "", messages.NSSASL) + return nil } // AuthedStart state type AuthedStart struct { - Next state.State + Next state.State + Client *utils.Client } // Process messages -func (state *AuthedStart) Process(client *utils.Client) (state.State, *utils.Client) { - client.Log = client.Log.WithField("state", "authed started") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *AuthedStart) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "authed started") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") - _, err := client.Read() + _, err := state.Client.Read() if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client + state.Client.Log.Warn("unable to read: ", err) + return nil } - fmt.Fprintf(client.Conn, ` + fmt.Fprintf(state.Client.Conn, ` `, utils.CreateCookie(), messages.NSClient, messages.NSStream) - fmt.Fprintf(client.Conn, ` + fmt.Fprintf(state.Client.Conn, ` `, messages.NSBind) - return state.Next, client + return state.Next } // AuthedStream state type AuthedStream struct { - Next state.State + Next state.State + Client *utils.Client } // Process messages -func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Client) { - client.Log = client.Log.WithField("state", "authed stream") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *AuthedStream) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "authed stream") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") // check that it's a bind request // read bind request - element, err := client.Read() + element, err := state.Client.Read() if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client + state.Client.Log.Warn("unable to read: ", err) + return nil } var msg messages.IQ - if err = client.In.DecodeElement(&msg, element); err != nil { - client.Log.Warn("is no iq: ", err) - return nil, client + if err = state.Client.In.DecodeElement(&msg, element); err != nil { + state.Client.Log.Warn("is no iq: ", err) + return nil } if msg.Type != messages.IQTypeSet { - client.Log.Warn("is no set iq") - return nil, client + state.Client.Log.Warn("is no set iq") + return nil } if msg.Error != nil { - client.Log.Warn("iq with error: ", msg.Error.Code) - return nil, client + state.Client.Log.Warn("iq with error: ", msg.Error.Code) + return nil } type query struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` @@ -202,26 +216,26 @@ func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Cl q := &query{} err = xml.Unmarshal(msg.Body, q) if err != nil { - client.Log.Warn("is no iq bind: ", err) - return nil, client + state.Client.Log.Warn("is no iq bind: ", err) + return nil } if q.Resource == "" { - client.JID.Resource = makeResource() + state.Client.JID.Resource = makeResource() } else { - client.JID.Resource = q.Resource + state.Client.JID.Resource = q.Resource } - client.Log = client.Log.WithField("jid", client.JID.Full()) - client.Out.Encode(&messages.IQ{ + state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full()) + state.Client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, - To: client.JID.String(), - From: client.JID.Domain, + To: state.Client.JID.String(), + From: state.Client.JID.Domain, ID: msg.ID, Body: []byte(fmt.Sprintf( ` %s `, - messages.NSBind, client.JID.Full())), + messages.NSBind, state.Client.JID.Full())), }) - return state.Next, client + return state.Next } diff --git a/server/toclient/normal.go b/server/toclient/normal.go deleted file mode 100644 index b066781..0000000 --- a/server/toclient/normal.go +++ /dev/null @@ -1,48 +0,0 @@ -package toclient - -import ( - "github.com/genofire/yaja/server/extension" - "github.com/genofire/yaja/server/state" - "github.com/genofire/yaja/server/utils" -) - -// SendingClient state -type SendingClient struct { - Next state.State -} - -// Process messages -func (state *SendingClient) Process(client *utils.Client) (state.State, *utils.Client) { - client.Log = client.Log.WithField("state", "normal") - client.Log.Debug("sending") - // sending - go func() { - select { - case msg := <-client.Messages: - err := client.Out.Encode(msg) - if err != nil { - client.Log.Warn(err) - } - case <-client.OnClose(): - return - } - }() - client.Log.Debug("receiving") - return state.Next, client -} - -// ReceivingClient state -type ReceivingClient struct { - Extensions extension.Extensions -} - -// Process messages -func (state *ReceivingClient) Process(client *utils.Client) (state.State, *utils.Client) { - element, err := client.Read() - if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client - } - state.Extensions.Process(element, client) - return state, client -} diff --git a/server/toclient/register.go b/server/toclient/register.go index 585a6f6..f2694ac 100644 --- a/server/toclient/register.go +++ b/server/toclient/register.go @@ -13,33 +13,34 @@ import ( type RegisterFormRequest struct { Next state.State + Client *utils.Client domainRegisterAllowed utils.DomainRegisterAllowed element *xml.StartElement } // Process message -func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *utils.Client) { - client.Log = client.Log.WithField("state", "register form request") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *RegisterFormRequest) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "register form request") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") - if !state.domainRegisterAllowed(client.JID) { - client.Log.Error("unpossible to reach this state, register on this domain is not allowed") - return nil, client + if !state.domainRegisterAllowed(state.Client.JID) { + state.Client.Log.Error("unpossible to reach this state, register on this domain is not allowed") + return nil } var msg messages.IQ - if err := client.In.DecodeElement(&msg, state.element); err != nil { - client.Log.Warn("is no iq: ", err) - return state, client + if err := state.Client.In.DecodeElement(&msg, state.element); err != nil { + state.Client.Log.Warn("is no iq: ", err) + return state } if msg.Type != messages.IQTypeGet { - client.Log.Warn("is no get iq") - return state, client + state.Client.Log.Warn("is no get iq") + return state } if msg.Error != nil { - client.Log.Warn("iq with error: ", msg.Error.Code) - return state, client + state.Client.Log.Warn("iq with error: ", msg.Error.Code) + return state } type query struct { XMLName xml.Name `xml:"query"` @@ -48,13 +49,13 @@ func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *u err := xml.Unmarshal(msg.Body, q) if q.XMLName.Space != messages.NSIQRegister || err != nil { - client.Log.Warn("is no iq register: ", err) - return nil, client + state.Client.Log.Warn("is no iq register: ", err) + return nil } - client.Out.Encode(&messages.IQ{ + state.Client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, - To: client.JID.String(), - From: client.JID.Domain, + To: state.Client.JID.String(), + From: state.Client.JID.Domain, ID: msg.ID, Body: []byte(fmt.Sprintf(` Choose a username and password for use with this service. @@ -63,43 +64,44 @@ func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *u `, messages.NSIQRegister)), }) - return state.Next, client + return state.Next } type RegisterRequest struct { Next state.State + Client *utils.Client database *database.State domainRegisterAllowed utils.DomainRegisterAllowed } // Process message -func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils.Client) { - client.Log = client.Log.WithField("state", "register request") - client.Log.Debug("running") - defer client.Log.Debug("leave") +func (state *RegisterRequest) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "register request") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") - if !state.domainRegisterAllowed(client.JID) { - client.Log.Error("unpossible to reach this state, register on this domain is not allowed") - return nil, client + if !state.domainRegisterAllowed(state.Client.JID) { + state.Client.Log.Error("unpossible to reach this state, register on this domain is not allowed") + return nil } - element, err := client.Read() + element, err := state.Client.Read() if err != nil { - client.Log.Warn("unable to read: ", err) - return nil, client + state.Client.Log.Warn("unable to read: ", err) + return nil } var msg messages.IQ - if err = client.In.DecodeElement(&msg, element); err != nil { - client.Log.Warn("is no iq: ", err) - return state, client + if err = state.Client.In.DecodeElement(&msg, element); err != nil { + state.Client.Log.Warn("is no iq: ", err) + return state } if msg.Type != messages.IQTypeGet { - client.Log.Warn("is no get iq") - return state, client + state.Client.Log.Warn("is no get iq") + return state } if msg.Error != nil { - client.Log.Warn("iq with error: ", msg.Error.Code) - return state, client + state.Client.Log.Warn("iq with error: ", msg.Error.Code) + return state } type query struct { XMLName xml.Name `xml:"query"` @@ -109,19 +111,19 @@ func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils q := &query{} err = xml.Unmarshal(msg.Body, q) if err != nil { - client.Log.Warn("is no iq register: ", err) - return nil, client + state.Client.Log.Warn("is no iq register: ", err) + return nil } - client.JID.Local = q.Username - client.Log = client.Log.WithField("jid", client.JID.Full()) - account := model.NewAccount(client.JID, q.Password) + state.Client.JID.Local = q.Username + state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full()) + account := model.NewAccount(state.Client.JID, q.Password) err = state.database.AddAccount(account) if err != nil { - client.Out.Encode(&messages.IQ{ + state.Client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, - To: client.JID.String(), - From: client.JID.Domain, + To: state.Client.JID.String(), + From: state.Client.JID.Domain, ID: msg.ID, Body: []byte(fmt.Sprintf(` %s @@ -136,16 +138,16 @@ func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils }, }, }) - client.Log.Warn("database error: ", err) - return state, client + state.Client.Log.Warn("database error: ", err) + return state } - client.Out.Encode(&messages.IQ{ + state.Client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, - To: client.JID.String(), - From: client.JID.Domain, + To: state.Client.JID.String(), + From: state.Client.JID.Domain, ID: msg.ID, }) - client.Log.Infof("registered client %s", client.JID.Bare()) - return state.Next, client + state.Client.Log.Infof("registered client %s", state.Client.JID.Bare()) + return state.Next } diff --git a/server/toserver/connect.go b/server/toserver/connect.go new file mode 100644 index 0000000..ea28565 --- /dev/null +++ b/server/toserver/connect.go @@ -0,0 +1,141 @@ +package toserver + +import ( + "crypto/tls" + "encoding/base64" + "encoding/xml" + "fmt" + + "github.com/genofire/yaja/database" + "github.com/genofire/yaja/messages" + "github.com/genofire/yaja/server/extension" + "github.com/genofire/yaja/server/state" + "github.com/genofire/yaja/server/utils" + "golang.org/x/crypto/acme/autocert" +) + +// ConnectionStartup return steps through TCP TLS state +func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, extensions extension.Extensions, c *utils.Client) state.State { + receiving := &state.ReceivingClient{Extensions: extensions, Client: c} + sending := &state.SendingClient{Next: receiving, Client: c} + tlsstream := &TLSStream{ + Next: sending, + Client: c, + } + tlsupgrade := &state.TLSUpgrade{ + Next: tlsstream, + Client: c, + TLSConfig: tlsconfig, + TLSManager: tlsmgmt, + } + dail := &Dailback{ + Next: tlsupgrade, + Client: c, + } + return &state.Start{Next: dail, Client: c} +} + +// TLSStream state +type Dailback struct { + Next state.State + Client *utils.Client +} + +// Process messages +func (state *Dailback) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "dialback") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") + + element, err := state.Client.Read() + if err != nil { + state.Client.Log.Warn("unable to read: ", err) + return nil + } + + // dailback encode + type dailback struct { + XMLName xml.Name `xml:"urn:xmpp:ping ping"` + } + db := &dailback{} + if err = state.Client.In.DecodeElement(db, element); err != nil { + return state.Next + } + + state.Client.Log.Info(db) + return state.Next +} + +// TLSStream state +type TLSStream struct { + Next state.State + Client *utils.Client + domainRegisterAllowed utils.DomainRegisterAllowed +} + +// Process messages +func (state *TLSStream) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "tls stream") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") + + element, err := state.Client.Read() + if err != nil { + state.Client.Log.Warn("unable to read: ", err) + return nil + } + if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { + state.Client.Log.Warn("is no stream") + return state + } + + fmt.Fprintf(state.Client.Conn, ` + `, + utils.CreateCookie(), messages.NSClient, messages.NSStream) + + fmt.Fprintf(state.Client.Conn, ` + + EXTERNAL + + `, + messages.NSSASL) + + return state.Next +} + +// SASLAuth state +type SASLAuth struct { + Next state.State + Client *utils.Client + database *database.State + domainRegisterAllowed utils.DomainRegisterAllowed +} + +// Process messages +func (state *SASLAuth) Process() state.State { + state.Client.Log = state.Client.Log.WithField("state", "sasl auth") + state.Client.Log.Debug("running") + defer state.Client.Log.Debug("leave") + + // read the full auth stanza + element, err := state.Client.Read() + if err != nil { + state.Client.Log.Warn("unable to read: ", err) + return nil + } + var auth messages.SASLAuth + if err = state.Client.In.DecodeElement(&auth, element); err != nil { + return nil + } + data, err := base64.StdEncoding.DecodeString(auth.Body) + if err != nil { + state.Client.Log.Warn("body decode: ", err) + return nil + } + + state.Client.Log.Debug(auth.Mechanism, string(data)) + + state.Client.Log.Info("success auth") + fmt.Fprintf(state.Client.Conn, "", messages.NSSASL) + return state.Next +} diff --git a/server/utils/client.go b/server/utils/client.go index 1240cc5..eba19e4 100644 --- a/server/utils/client.go +++ b/server/utils/client.go @@ -30,7 +30,7 @@ func NewClient(conn net.Conn, level log.Level) *Client { Log: log.NewEntry(logger), In: xml.NewDecoder(conn), Out: xml.NewEncoder(conn), - Messages: make(chan interface{}, 1000), + Messages: make(chan interface{}), close: make(chan interface{}), } return client