package toclient

import (
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"strings"

	"dev.sum7.eu/genofire/yaja/database"
	"dev.sum7.eu/genofire/yaja/messages"
	"dev.sum7.eu/genofire/yaja/model"
	"dev.sum7.eu/genofire/yaja/server/extension"
	"dev.sum7.eu/genofire/yaja/server/state"
	"dev.sum7.eu/genofire/yaja/server/utils"
	"golang.org/x/crypto/acme/autocert"
)

// ConnectionStartup builds the chain of states a client connection walks
// through and returns its first state. The chain is, in order:
//
//	Start -> TLSUpgrade -> TLSStream -> SASLAuth -> AuthedStart ->
//	AuthedStream -> SendingClient -> ReceivingClient
//
// db is used by SASLAuth to check credentials and is handed on to the
// registration substates; tlsconfig and tlsmgmt are given to the TLS upgrade
// state; registerAllowed is consulted with the client's JID to decide whether
// in-band registration is offered; extensions are given to the receiving
// state.
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions extension.Extensions, c *utils.Client) state.State {
	receiving := &state.ReceivingClient{Extensions: extensions, Client: c}
	sending := &state.SendingClient{Next: receiving, Client: c}
	authedstream := &AuthedStream{Next: sending, Client: c}
	authedstart := &AuthedStart{Next: authedstream, Client: c}
	tlsauth := &SASLAuth{
		Next:                  authedstart,
		Client:                c,
		database:              db,
		domainRegisterAllowed: registerAllowed,
	}
	tlsstream := &TLSStream{
		Next:                  tlsauth,
		Client:                c,
		domainRegisterAllowed: registerAllowed,
	}
	tlsupgrade := &state.TLSUpgrade{
		Next:       tlsstream,
		Client:     c,
		TLSConfig:  tlsconfig,
		TLSManager: tlsmgmt,
	}
	return &state.Start{Next: tlsupgrade, Client: c}
}

// TLSStream is the state that follows the TLS upgrade: it waits for the
// client's stream header and answers with the stream features (SASL PLAIN,
// plus in-band registration when allowed for the client's JID).
type TLSStream struct {
	Next                  state.State
	Client                *utils.Client
	domainRegisterAllowed utils.DomainRegisterAllowed
}

// Process reads one element from the client.
//
// A read error returns nil, which the caller presumably treats as the end of
// the connection. An element that is not a stream header is ignored and this
// same state is returned so it runs again. Otherwise the stream features are
// written and Next (SASL authentication) is returned.
func (s *TLSStream) Process() state.State {
	s.Client.Log = s.Client.Log.WithField("state", "tls stream")
	s.Client.Log.Debug("running")
	defer s.Client.Log.Debug("leave")

	element, err := s.Client.Read()
	if err != nil {
		s.Client.Log.Warn("unable to read: ", err)
		return nil
	}
	if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
		s.Client.Log.Warn("is no stream")
		return s
	}

	// NOTE(review): as visible here these templates contain no format verbs
	// although several arguments are passed (go vet flags this) — confirm the
	// XML stream/features template is intact.
	if s.domainRegisterAllowed(s.Client.JID) {
		fmt.Fprintf(s.Client.Conn, ` PLAIN `, messages.CreateCookie(), messages.NSClient, messages.NSStream, messages.NSSASL, messages.NSFeaturesIQRegister)
	} else {
		fmt.Fprintf(s.Client.Conn, ` PLAIN `, messages.CreateCookie(), messages.NSClient, messages.NSStream, messages.NSSASL)
	}
	return s.Next
}

// SASLAuth is the state that authenticates the client with SASL PLAIN or,
// when the received element is not a SASL auth stanza, hands over to the
// in-band registration substates.
type SASLAuth struct {
	Next                  state.State
	Client                *utils.Client
	database              *database.State
	domainRegisterAllowed utils.DomainRegisterAllowed
}

// Process reads one stanza from the client.
//
// If it does not decode as messages.SASLAuth, a RegisterFormRequest ->
// RegisterRequest substate chain is returned; it continues with Next
// afterwards. Otherwise the base64 PLAIN body is decoded. The authcid becomes
// the local part of the client's JID and the password is checked against the
// database. On success a success reply is written and Next is returned. Read
// errors, malformed bodies, database errors and failed authentication all
// return nil.
func (s *SASLAuth) Process() state.State {
	s.Client.Log = s.Client.Log.WithField("state", "sasl auth")
	s.Client.Log.Debug("running")
	defer s.Client.Log.Debug("leave")

	// read the full auth stanza
	element, err := s.Client.Read()
	if err != nil {
		s.Client.Log.Warn("unable to read: ", err)
		return nil
	}
	var auth messages.SASLAuth
	if err = s.Client.In.DecodeElement(&auth, element); err != nil {
		s.Client.Log.Info("start substate for registration")
		return &RegisterFormRequest{
			Next: &RegisterRequest{
				Next:                  s.Next,
				Client:                s.Client,
				database:              s.database,
				domainRegisterAllowed: s.domainRegisterAllowed,
			},
			Client:                s.Client,
			element:               element,
			domainRegisterAllowed: s.domainRegisterAllowed,
		}
	}

	data, err := base64.StdEncoding.DecodeString(auth.Body)
	if err != nil {
		s.Client.Log.Warn("body decode: ", err)
		return nil
	}
	// The body comes straight from the client: validate its shape before
	// using any field, otherwise a short body would panic on indexing.
	authcid, password, ok := splitSASLPlain(data)
	if !ok {
		s.Client.Log.Warn("body decode: malformed PLAIN message, expected 3 NUL-separated fields")
		return nil
	}

	// TODO: should check that authcid matches s.Client.JID; the authzid
	// (first field) is currently ignored.
	s.Client.JID.Local = authcid
	s.Client.Log = s.Client.Log.WithField("jid", s.Client.JID.Full())

	success, err := s.database.Authenticate(s.Client.JID, password)
	if err != nil {
		s.Client.Log.Warn("auth: ", err)
		return nil
	}

	// NOTE(review): as visible here the success/failure templates are empty
	// while messages.NSSASL is passed as an argument — confirm the reply XML
	// is intact.
	if success {
		s.Client.Log.Info("success auth")
		fmt.Fprintf(s.Client.Conn, "", messages.NSSASL)
		return s.Next
	}
	s.Client.Log.Warn("failed auth")
	fmt.Fprintf(s.Client.Conn, "", messages.NSSASL)
	return nil
}

// splitSASLPlain splits a decoded SASL PLAIN message (RFC 4616:
// [authzid] NUL authcid NUL passwd). It returns the authentication identity
// and the password. ok is false unless the message has exactly three
// NUL-separated fields; NUL may not appear inside a field.
func splitSASLPlain(data []byte) (authcid, password string, ok bool) {
	fields := strings.Split(string(data), "\x00")
	if len(fields) != 3 {
		return "", "", false
	}
	return fields[1], fields[2], true
}

// AuthedStart is the state entered after successful SASL authentication. It
// reads the client's next element and replies with a new stream header and
// features.
type AuthedStart struct {
	Next   state.State
	Client *utils.Client
}

// Process reads one element and does not inspect its content, presumably the
// stream restart that follows SASL. It then writes the stream header/features
// template, filled with the client's domain and messages.NSBind, and returns
// Next. A read error returns nil.
func (s *AuthedStart) Process() state.State {
	s.Client.Log = s.Client.Log.WithField("state", "authed started")
	s.Client.Log.Debug("running")
	defer s.Client.Log.Debug("leave")

	if _, err := s.Client.Read(); err != nil {
		s.Client.Log.Warn("unable to read: ", err)
		return nil
	}

	// NOTE(review): as visible here this template has no format verbs although
	// five arguments are passed — confirm the XML template is intact.
	fmt.Fprintf(s.Client.Conn, ` `, messages.NSStream, s.Client.JID.Domain, messages.CreateCookie(), messages.NSClient, messages.NSBind)
	return s.Next
}

// AuthedStream is the state that handles the client's resource-binding IQ and
// answers it with the resulting full JID.
type AuthedStream struct {
	Next   state.State
	Client *utils.Client
}

// Process reads one element. The element must be an IQ of type set, carry no
// error, and contain a bind payload; anything else returns nil.
//
// The requested resource is used as-is, or generated with makeResource when
// empty. A result IQ carrying the full JID is then sent. Next is returned
// only if that write succeeds.
func (s *AuthedStream) Process() state.State {
	s.Client.Log = s.Client.Log.WithField("state", "authed stream")
	s.Client.Log.Debug("running")
	defer s.Client.Log.Debug("leave")

	// read bind request
	element, err := s.Client.Read()
	if err != nil {
		s.Client.Log.Warn("unable to read: ", err)
		return nil
	}
	var msg messages.IQClient
	if err = s.Client.In.DecodeElement(&msg, element); err != nil {
		s.Client.Log.Warn("is no iq: ", err)
		return nil
	}
	if msg.Type != messages.IQTypeSet {
		s.Client.Log.Warn("is no set iq")
		return nil
	}
	if msg.Error != nil {
		s.Client.Log.Warn("iq with error: ", msg.Error.Code)
		return nil
	}
	if msg.Bind == nil {
		// err is always nil here, so it is not logged.
		s.Client.Log.Warn("is no iq bind")
		return nil
	}

	if msg.Bind.Resource == "" {
		s.Client.JID.Resource = makeResource()
	} else {
		s.Client.JID.Resource = msg.Bind.Resource
	}
	s.Client.Log = s.Client.Log.WithField("jid", s.Client.JID.Full())

	err = s.Client.Out.Encode(&messages.IQClient{
		Type: messages.IQTypeResult,
		To:   s.Client.JID,
		From: model.NewJID(s.Client.JID.Domain),
		ID:   msg.ID,
		Bind: &messages.Bind{JID: s.Client.JID},
	})
	if err != nil {
		// Do not advance to the session states over a connection that can no
		// longer be written to.
		s.Client.Log.Warn("unable to send bind result: ", err)
		return nil
	}
	return s.Next
}