sum7
/
yaja
Archived
1
0
Fork 0
This repository has been archived on 2020-09-27. You can view files and clone it, but cannot push or open issues or pull requests.
yaja/server/toclient/connect.go

235 lines
6.7 KiB
Go
Raw Normal View History

2017-12-16 23:20:46 +01:00
package toclient
import (
"crypto/tls"
"encoding/base64"
"fmt"
"strings"
2018-02-07 15:34:18 +01:00
"dev.sum7.eu/genofire/yaja/database"
"dev.sum7.eu/genofire/yaja/messages"
"dev.sum7.eu/genofire/yaja/server/extension"
"dev.sum7.eu/genofire/yaja/server/state"
"dev.sum7.eu/genofire/yaja/server/utils"
2017-12-16 23:20:46 +01:00
"golang.org/x/crypto/acme/autocert"
)
// ConnectionStartup return steps through TCP TLS state
2017-12-17 17:50:51 +01:00
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions extension.Extensions, c *utils.Client) state.State {
receiving := &state.ReceivingClient{Extensions: extensions, Client: c}
sending := &state.SendingClient{Next: receiving, Client: c}
authedstream := &AuthedStream{Next: sending, Client: c}
authedstart := &AuthedStart{Next: authedstream, Client: c}
2017-12-16 23:20:46 +01:00
tlsauth := &SASLAuth{
Next: authedstart,
2017-12-17 17:50:51 +01:00
Client: c,
2017-12-16 23:20:46 +01:00
database: db,
domainRegisterAllowed: registerAllowed,
}
tlsstream := &TLSStream{
2017-12-17 17:50:51 +01:00
Next: tlsauth,
Client: c,
2017-12-16 23:20:46 +01:00
domainRegisterAllowed: registerAllowed,
}
2017-12-17 17:50:51 +01:00
tlsupgrade := &state.TLSUpgrade{
Next: tlsstream,
Client: c,
TLSConfig: tlsconfig,
TLSManager: tlsmgmt,
}
return &state.Start{Next: tlsupgrade, Client: c}
2017-12-16 23:20:46 +01:00
}
// TLSStream state
type TLSStream struct {
Next state.State
2017-12-17 17:50:51 +01:00
Client *utils.Client
2017-12-16 23:20:46 +01:00
domainRegisterAllowed utils.DomainRegisterAllowed
}
// Process messages
2017-12-17 17:50:51 +01:00
func (state *TLSStream) Process() state.State {
state.Client.Log = state.Client.Log.WithField("state", "tls stream")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
2017-12-16 23:20:46 +01:00
2017-12-17 17:50:51 +01:00
element, err := state.Client.Read()
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("unable to read: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("is no stream")
return state
2017-12-16 23:20:46 +01:00
}
2017-12-17 17:50:51 +01:00
if state.domainRegisterAllowed(state.Client.JID) {
2018-02-07 15:34:18 +01:00
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>
<stream:features>
<register xmlns='%s'/>
<mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism>
</mechanisms>
</stream:features>`,
2018-02-11 19:35:32 +01:00
messages.CreateCookie(), messages.NSClient, messages.NSStream,
2017-12-16 23:20:46 +01:00
messages.NSSASL, messages.NSFeaturesIQRegister)
} else {
2018-02-07 15:34:18 +01:00
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>
<stream:features>
<mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism>
</mechanisms>
</stream:features>`,
2018-02-11 19:35:32 +01:00
messages.CreateCookie(), messages.NSClient, messages.NSStream,
2017-12-16 23:20:46 +01:00
messages.NSSASL)
}
2017-12-17 17:50:51 +01:00
return state.Next
2017-12-16 23:20:46 +01:00
}
// SASLAuth state
type SASLAuth struct {
Next state.State
2017-12-17 17:50:51 +01:00
Client *utils.Client
2017-12-16 23:20:46 +01:00
database *database.State
domainRegisterAllowed utils.DomainRegisterAllowed
}
// Process messages
2017-12-17 17:50:51 +01:00
func (state *SASLAuth) Process() state.State {
state.Client.Log = state.Client.Log.WithField("state", "sasl auth")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
2017-12-16 23:20:46 +01:00
// read the full auth stanza
2017-12-17 17:50:51 +01:00
element, err := state.Client.Read()
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("unable to read: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
var auth messages.SASLAuth
2017-12-17 17:50:51 +01:00
if err = state.Client.In.DecodeElement(&auth, element); err != nil {
state.Client.Log.Info("start substate for registration")
2017-12-16 23:20:46 +01:00
return &RegisterFormRequest{
Next: &RegisterRequest{
Next: state.Next,
2017-12-17 17:50:51 +01:00
Client: state.Client,
database: state.database,
domainRegisterAllowed: state.domainRegisterAllowed,
2017-12-16 23:20:46 +01:00
},
2017-12-17 17:50:51 +01:00
Client: state.Client,
element: element,
domainRegisterAllowed: state.domainRegisterAllowed,
}
2017-12-16 23:20:46 +01:00
}
data, err := base64.StdEncoding.DecodeString(auth.Body)
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("body decode: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
info := strings.Split(string(data), "\x00")
2017-12-17 17:50:51 +01:00
// should check that info[1] starts with state.Client.JID
state.Client.JID.Local = info[1]
state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
success, err := state.database.Authenticate(state.Client.JID, info[2])
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("auth: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
if success {
2017-12-17 17:50:51 +01:00
state.Client.Log.Info("success auth")
fmt.Fprintf(state.Client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
return state.Next
2017-12-16 23:20:46 +01:00
}
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("failed auth")
fmt.Fprintf(state.Client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
return nil
2017-12-16 23:20:46 +01:00
}
// AuthedStart state
type AuthedStart struct {
2017-12-17 17:50:51 +01:00
Next state.State
Client *utils.Client
2017-12-16 23:20:46 +01:00
}
// Process messages
2017-12-17 17:50:51 +01:00
func (state *AuthedStart) Process() state.State {
state.Client.Log = state.Client.Log.WithField("state", "authed started")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
2017-12-16 23:20:46 +01:00
2017-12-17 17:50:51 +01:00
_, err := state.Client.Read()
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("unable to read: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
2017-12-17 17:50:51 +01:00
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
2018-02-07 15:34:18 +01:00
<stream:stream xmlns:stream='%s' xml:lang='en' from='%s' id='%x' version='1.0' xmlns='%s'>
<stream:features>
<bind xmlns='%s'>
<required/>
</bind>
</stream:features>`,
2018-02-11 19:35:32 +01:00
messages.NSStream, state.Client.JID.Domain, messages.CreateCookie(), messages.NSClient,
2017-12-16 23:20:46 +01:00
messages.NSBind)
2017-12-17 17:50:51 +01:00
return state.Next
2017-12-16 23:20:46 +01:00
}
// AuthedStream state
type AuthedStream struct {
2017-12-17 17:50:51 +01:00
Next state.State
Client *utils.Client
2017-12-16 23:20:46 +01:00
}
// Process messages
2017-12-17 17:50:51 +01:00
func (state *AuthedStream) Process() state.State {
state.Client.Log = state.Client.Log.WithField("state", "authed stream")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
2017-12-16 23:20:46 +01:00
// check that it's a bind request
// read bind request
2017-12-17 17:50:51 +01:00
element, err := state.Client.Read()
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("unable to read: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
var msg messages.IQClient
2017-12-17 17:50:51 +01:00
if err = state.Client.In.DecodeElement(&msg, element); err != nil {
state.Client.Log.Warn("is no iq: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
if msg.Type != messages.IQTypeSet {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("is no set iq")
return nil
2017-12-16 23:20:46 +01:00
}
if msg.Error != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("iq with error: ", msg.Error.Code)
return nil
2017-12-16 23:20:46 +01:00
}
2018-02-10 13:34:42 +01:00
if msg.Bind == nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("is no iq bind: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
2018-02-10 13:34:42 +01:00
if msg.Bind.Resource == "" {
2017-12-17 17:50:51 +01:00
state.Client.JID.Resource = makeResource()
2017-12-16 23:20:46 +01:00
} else {
2018-02-10 13:34:42 +01:00
state.Client.JID.Resource = msg.Bind.Resource
2017-12-16 23:20:46 +01:00
}
2017-12-17 17:50:51 +01:00
state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
state.Client.Out.Encode(&messages.IQClient{
2017-12-16 23:20:46 +01:00
Type: messages.IQTypeResult,
2018-02-10 13:34:42 +01:00
To: state.Client.JID,
2018-02-13 20:05:18 +01:00
From: messages.NewJID(state.Client.JID.Domain),
2017-12-16 23:20:46 +01:00
ID: msg.ID,
2018-02-10 13:34:42 +01:00
Bind: &messages.Bind{JID: state.Client.JID},
2017-12-16 23:20:46 +01:00
})
2017-12-17 17:50:51 +01:00
return state.Next
2017-12-16 23:20:46 +01:00
}