package toclient import ( "crypto/tls" "encoding/base64" "encoding/xml" "fmt" "strings" "github.com/genofire/yaja/database" "github.com/genofire/yaja/messages" "github.com/genofire/yaja/server/extension" "github.com/genofire/yaja/server/state" "github.com/genofire/yaja/server/utils" "golang.org/x/crypto/acme/autocert" ) // ConnectionStartup return steps through TCP TLS state func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed) state.State { receiving := &ReceivingClient{} sending := &SendingClient{Next: receiving} authedstream := &AuthedStream{Next: sending} authedstart := &AuthedStart{Next: authedstream} tlsauth := &SASLAuth{ Next: authedstart, database: db, domainRegisterAllowed: registerAllowed, } tlsstream := &TLSStream{ Next: tlsauth, domainRegisterAllowed: registerAllowed, } return state.ConnectionStartup(tlsstream, tlsconfig, tlsmgmt) } // TLSStream state type TLSStream struct { Next state.State domainRegisterAllowed utils.DomainRegisterAllowed } // Process messages func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Client) { client.Log = client.Log.WithField("state", "tls stream") client.Log.Debug("running") defer client.Log.Debug("leave") element, err := client.Read() if err != nil { client.Log.Warn("unable to read: ", err) return nil, client } if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { client.Log.Warn("is no stream") return state, client } fmt.Fprintf(client.Conn, ` `, utils.CreateCookie(), messages.NSClient, messages.NSStream) if state.domainRegisterAllowed(client.JID) { fmt.Fprintf(client.Conn, ` PLAIN `, messages.NSSASL, messages.NSFeaturesIQRegister) } else { fmt.Fprintf(client.Conn, ` PLAIN `, messages.NSSASL) } return state.Next, client } // SASLAuth state type SASLAuth struct { Next state.State database *database.State domainRegisterAllowed utils.DomainRegisterAllowed } // Process messages func (state *SASLAuth) Process(client *utils.Client) (state.State, *utils.Client) { client.Log = client.Log.WithField("state", "sasl auth") client.Log.Debug("running") defer client.Log.Debug("leave") // read the full auth stanza element, err := client.Read() if err != nil { client.Log.Warn("unable to read: ", err) return nil, client } var auth messages.SASLAuth if err = client.In.DecodeElement(&auth, element); err != nil { client.Log.Info("start substate for registration") return &RegisterFormRequest{ element: element, domainRegisterAllowed: state.domainRegisterAllowed, Next: &RegisterRequest{ domainRegisterAllowed: state.domainRegisterAllowed, database: state.database, Next: state.Next, }, }, client } data, err := base64.StdEncoding.DecodeString(auth.Body) if err != nil { client.Log.Warn("body decode: ", err) return nil, client } info := strings.Split(string(data), "\x00") // should check that info[1] starts with client.JID client.JID.Local = info[1] client.Log = client.Log.WithField("jid", client.JID.Full()) success, err := state.database.Authenticate(client.JID, info[2]) if err != nil { client.Log.Warn("auth: ", err) return nil, client } if success { client.Log.Info("success auth") fmt.Fprintf(client.Conn, "", messages.NSSASL) return state.Next, client } client.Log.Warn("failed auth") fmt.Fprintf(client.Conn, "", messages.NSSASL) return nil, client } // AuthedStart state type AuthedStart struct { Next state.State } // Process messages func (state *AuthedStart) Process(client *utils.Client) (state.State, *utils.Client) { client.Log = client.Log.WithField("state", "authed started") client.Log.Debug("running") defer client.Log.Debug("leave") _, err := client.Read() if err != nil { client.Log.Warn("unable to read: ", err) return nil, client } fmt.Fprintf(client.Conn, ` `, utils.CreateCookie(), messages.NSClient, messages.NSStream) fmt.Fprintf(client.Conn, ` `, messages.NSBind) return state.Next, client } // AuthedStream state type AuthedStream struct { Next state.State } // Process messages func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Client) { client.Log = client.Log.WithField("state", "authed stream") client.Log.Debug("running") defer client.Log.Debug("leave") // check that it's a bind request // read bind request element, err := client.Read() if err != nil { client.Log.Warn("unable to read: ", err) return nil, client } var msg messages.IQ if err = client.In.DecodeElement(&msg, element); err != nil { client.Log.Warn("is no iq: ", err) return nil, client } if msg.Type != messages.IQTypeSet { client.Log.Warn("is no set iq") return nil, client } if msg.Error != nil { client.Log.Warn("iq with error: ", msg.Error.Code) return nil, client } type query struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` Resource string `xml:"resource"` } q := &query{} err = xml.Unmarshal(msg.Body, q) if err != nil { client.Log.Warn("is no iq bind: ", err) return nil, client } if q.Resource == "" { client.JID.Resource = makeResource() } else { client.JID.Resource = q.Resource } client.Log = client.Log.WithField("jid", client.JID.Full()) client.Out.Encode(&messages.IQ{ Type: messages.IQTypeResult, ID: msg.ID, Body: []byte(fmt.Sprintf( ` %s `, messages.NSBind, client.JID.Full())), }) return state.Next, client } // SendingClient state type SendingClient struct { Next state.State } // Process messages func (state *SendingClient) Process(client *utils.Client) (state.State, *utils.Client) { client.Log = client.Log.WithField("state", "normal") client.Log.Debug("sending") // sending go func() { select { case msg := <-client.Messages: err := client.Out.Encode(msg) client.Log.Info(err) case <-client.OnClose(): return } }() client.Log.Debug("receiving") return state.Next, client } // ReceivingClient state type ReceivingClient struct { Extensions []extension.Extension } // Process messages func (state *ReceivingClient) Process(client *utils.Client) (state.State, *utils.Client) { element, err := client.Read() if err != nil { client.Log.Warn("unable to read: ", err) return nil, client } count := 0 for _, extension := range state.Extensions { if extension.Process(element, client) { count++ } } if count != 1 { client.Log.WithField("extension", count).Debug(element) } return state, client }