sum7
/
yaja
Archived
1
0
Fork 0

move client to state attr + s2s idea

This commit is contained in:
Martin Geno 2017-12-17 17:50:51 +01:00
parent 1e2e578076
commit e474f460aa
No known key found for this signature in database
GPG Key ID: F0D39A37E925E941
21 changed files with 498 additions and 305 deletions

View File

@ -25,12 +25,13 @@ import (
var configPath string var configPath string
var ( var (
configData = &config.Config{} configData = &config.Config{}
db = &database.State{} db = &database.State{}
statesaveWorker *worker.Worker statesaveWorker *worker.Worker
srv *server.Server srv *server.Server
certs *tls.Config certs *tls.Config
extensions extension.Extensions extensionsClient extension.Extensions
extensionsServer extension.Extensions
) )
// serverCmd represents the serve command // serverCmd represents the serve command
@ -39,16 +40,14 @@ var serverCmd = &cobra.Command{
Short: "Runs the yaja server", Short: "Runs the yaja server",
Example: "yaja serve -c /etc/yaja.conf", Example: "yaja serve -c /etc/yaja.conf",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
var err error
err = file.ReadTOML(configPath, configData) if err := file.ReadTOML(configPath, configData); err != nil {
if err != nil {
log.Fatal("unable to load config file:", err) log.Fatal("unable to load config file:", err)
} }
log.SetLevel(configData.Logging.Level) log.SetLevel(configData.Logging.Level)
err = file.ReadJSON(configData.StatePath, db) if err := file.ReadJSON(configData.StatePath, db); err != nil {
if err != nil {
log.Warn("unable to load state file:", err) log.Warn("unable to load state file:", err)
} }
@ -76,14 +75,16 @@ var serverCmd = &cobra.Command{
} }
srv = &server.Server{ srv = &server.Server{
TLSManager: &m, TLSManager: &m,
Database: db, Database: db,
ClientAddr: configData.Address.Client, ClientAddr: configData.Address.Client,
ServerAddr: configData.Address.Server, ServerAddr: configData.Address.Server,
LoggingClient: configData.Logging.LevelClient, LoggingClient: configData.Logging.LevelClient,
RegisterEnable: configData.Register.Enable, LoggingServer: configData.Logging.LevelServer,
RegisterDomains: configData.Register.Domains, RegisterEnable: configData.Register.Enable,
Extensions: extensions, RegisterDomains: configData.Register.Domains,
ExtensionsServer: extensionsServer,
ExtensionsClient: extensionsClient,
} }
go statesaveWorker.Start() go statesaveWorker.Start()
@ -122,13 +123,14 @@ func quit() {
func reload() { func reload() {
log.Info("start reloading...") log.Info("start reloading...")
var configNewData *config.Config var configNewData *config.Config
err := file.ReadTOML(configPath, configNewData)
if err != nil { if err := file.ReadTOML(configPath, configNewData); err != nil {
log.Warn("unable to load config file:", err) log.Warn("unable to load config file:", err)
return return
} }
log.SetLevel(configNewData.Logging.Level) log.SetLevel(configNewData.Logging.Level)
srv.LoggingClient = configNewData.Logging.LevelClient srv.LoggingClient = configNewData.Logging.LevelClient
srv.LoggingServer = configNewData.Logging.LevelServer
srv.RegisterEnable = configNewData.Register.Enable srv.RegisterEnable = configNewData.Register.Enable
srv.RegisterDomains = configNewData.Register.Domains srv.RegisterDomains = configNewData.Register.Domains
@ -157,14 +159,15 @@ func reload() {
} }
if restartServer { if restartServer {
newServer := &server.Server{ newServer := &server.Server{
TLSConfig: certs, TLSConfig: certs,
Database: db, Database: db,
ClientAddr: configNewData.Address.Client, ClientAddr: configNewData.Address.Client,
ServerAddr: configNewData.Address.Server, ServerAddr: configNewData.Address.Server,
LoggingClient: configNewData.Logging.LevelClient, LoggingClient: configNewData.Logging.LevelClient,
RegisterEnable: configNewData.Register.Enable, RegisterEnable: configNewData.Register.Enable,
RegisterDomains: configNewData.Register.Domains, RegisterDomains: configNewData.Register.Domains,
Extensions: extensions, ExtensionsServer: extensionsServer,
ExtensionsClient: extensionsClient,
} }
log.Warn("reloading need a restart:") log.Warn("reloading need a restart:")
go newServer.Start() go newServer.Start()
@ -178,7 +181,7 @@ func reload() {
} }
func init() { func init() {
extensions = append(extensions, extensionsClient = append(extensionsClient,
&extension.Message{}, &extension.Message{},
&extension.Presence{}, &extension.Presence{},
extension.IQExtensions{ extension.IQExtensions{
@ -188,10 +191,15 @@ func init() {
&extension.IQDisco{Database: db}, &extension.IQDisco{Database: db},
&extension.IQRoster{Database: db}, &extension.IQRoster{Database: db},
&extension.IQExtensionDiscovery{GetSpaces: func() []string { &extension.IQExtensionDiscovery{GetSpaces: func() []string {
return extensions.Spaces() return extensionsClient.Spaces()
}}, }},
}) })
extensionsServer = append(extensionsServer,
extension.IQExtensions{
&extension.IQPing{},
})
RootCmd.AddCommand(serverCmd) RootCmd.AddCommand(serverCmd)
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")

View File

@ -2,8 +2,9 @@ tlsdir = "tmp/ssl"
state_path = "tmp/yaja.json" state_path = "tmp/yaja.json"
[logging] [logging]
level = 3 level = 5
level_client = 6 level_client = 6
level_server = 6
[register] [register]
enable = true enable = true

View File

@ -10,6 +10,7 @@ type Config struct {
Logging struct { Logging struct {
Level log.Level `toml:"level"` Level log.Level `toml:"level"`
LevelClient log.Level `toml:"level_client"` LevelClient log.Level `toml:"level_client"`
LevelServer log.Level `toml:"level_server"`
} `toml:"logging"` } `toml:"logging"`
Register struct { Register struct {
Enable bool `toml:"enable"` Enable bool `toml:"enable"`

View File

@ -24,8 +24,7 @@ func (ex *IQDisco) Get(msg *messages.IQ, client *utils.Client) bool {
Body []byte `xml:",innerxml"` Body []byte `xml:",innerxml"`
} }
q := &query{} q := &query{}
err := xml.Unmarshal(msg.Body, q) if err := xml.Unmarshal(msg.Body, q); err != nil {
if err != nil {
return false return false
} }

View File

@ -23,8 +23,7 @@ func (ex *IQExtensionDiscovery) Get(msg *messages.IQ, client *utils.Client) bool
Body []byte `xml:",innerxml"` Body []byte `xml:",innerxml"`
} }
q := &query{} q := &query{}
err := xml.Unmarshal(msg.Body, q) if err := xml.Unmarshal(msg.Body, q); err != nil {
if err != nil {
return false return false
} }

View File

@ -25,8 +25,7 @@ func (ex *IQLast) Get(msg *messages.IQ, client *utils.Client) bool {
Body []byte `xml:",innerxml"` Body []byte `xml:",innerxml"`
} }
q := &query{} q := &query{}
err := xml.Unmarshal(msg.Body, q) if err := xml.Unmarshal(msg.Body, q); err != nil {
if err != nil {
return false return false
} }

View File

@ -21,8 +21,7 @@ func (ex *IQPing) Get(msg *messages.IQ, client *utils.Client) bool {
XMLName xml.Name `xml:"urn:xmpp:ping ping"` XMLName xml.Name `xml:"urn:xmpp:ping ping"`
} }
pq := &ping{} pq := &ping{}
err := xml.Unmarshal(msg.Body, pq) if err := xml.Unmarshal(msg.Body, pq); err != nil {
if err != nil {
return false return false
} }

View File

@ -27,8 +27,7 @@ func (ex *IQPrivate) Get(msg *messages.IQ, client *utils.Client) bool {
// query encode // query encode
q := &iqPrivateQuery{} q := &iqPrivateQuery{}
err := xml.Unmarshal(msg.Body, q) if err := xml.Unmarshal(msg.Body, q); err != nil {
if err != nil {
return false return false
} }

View File

@ -19,8 +19,7 @@ func (ex *IQPrivateBookmark) Handle(msg *messages.IQ, q *iqPrivateQuery, client
XMLName xml.Name `xml:"storage:bookmarks storage"` XMLName xml.Name `xml:"storage:bookmarks storage"`
} }
s := &storage{} s := &storage{}
err := xml.Unmarshal(q.Body, s) if err := xml.Unmarshal(q.Body, s); err != nil {
if err != nil {
return false return false
} }
/* /*

View File

@ -19,8 +19,7 @@ func (ex *IQPrivateMetacontact) Handle(msg *messages.IQ, q *iqPrivateQuery, clie
XMLName xml.Name `xml:"storage:metacontacts storage"` XMLName xml.Name `xml:"storage:metacontacts storage"`
} }
s := &storage{} s := &storage{}
err := xml.Unmarshal(q.Body, s) if err := xml.Unmarshal(q.Body, s); err != nil {
if err != nil {
return false return false
} }
/* /*

View File

@ -20,8 +20,7 @@ func (ex *IQPrivateRoster) Handle(msg *messages.IQ, q *iqPrivateQuery, client *u
Body []byte `xml:",innerxml"` Body []byte `xml:",innerxml"`
} }
r := &roster{} r := &roster{}
err := xml.Unmarshal(q.Body, r) if err := xml.Unmarshal(q.Body, r); err != nil {
if err != nil {
return false return false
} }

View File

@ -25,8 +25,7 @@ func (ex *IQRoster) Get(msg *messages.IQ, client *utils.Client) bool {
Body []byte `xml:",innerxml"` Body []byte `xml:",innerxml"`
} }
q := &query{} q := &query{}
err := xml.Unmarshal(msg.Body, q) if err := xml.Unmarshal(msg.Body, q); err != nil {
if err != nil {
return false return false
} }

View File

@ -8,21 +8,24 @@ import (
"github.com/genofire/yaja/model" "github.com/genofire/yaja/model"
"github.com/genofire/yaja/server/extension" "github.com/genofire/yaja/server/extension"
"github.com/genofire/yaja/server/toclient" "github.com/genofire/yaja/server/toclient"
"github.com/genofire/yaja/server/toserver"
"github.com/genofire/yaja/server/utils" "github.com/genofire/yaja/server/utils"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
) )
type Server struct { type Server struct {
TLSConfig *tls.Config TLSConfig *tls.Config
TLSManager *autocert.Manager TLSManager *autocert.Manager
ClientAddr []string ClientAddr []string
ServerAddr []string ServerAddr []string
Database *database.State Database *database.State
LoggingClient log.Level LoggingClient log.Level
RegisterEnable bool LoggingServer log.Level
RegisterDomains []string RegisterEnable bool
Extensions extension.Extensions RegisterDomains []string
ExtensionsClient extension.Extensions
ExtensionsServer extension.Extensions
} }
func (srv *Server) Start() { func (srv *Server) Start() {
@ -69,15 +72,33 @@ func (srv *Server) listenClient(c2s net.Listener) {
func (srv *Server) handleServer(conn net.Conn) { func (srv *Server) handleServer(conn net.Conn) {
log.Info("new server connection:", conn.RemoteAddr()) log.Info("new server connection:", conn.RemoteAddr())
client := utils.NewClient(conn, srv.LoggingClient)
client.Log = client.Log.WithField("c", "s2s")
state := toserver.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.ExtensionsServer, client)
for {
state = state.Process()
if state == nil {
client.Log.Info("disconnect")
client.Close()
return
}
// run next state
}
} }
func (srv *Server) handleClient(conn net.Conn) { func (srv *Server) handleClient(conn net.Conn) {
log.Info("new client connection:", conn.RemoteAddr()) log.Info("new client connection:", conn.RemoteAddr())
client := utils.NewClient(conn, srv.LoggingClient)
state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed, srv.Extensions) client := utils.NewClient(conn, srv.LoggingServer)
client.Log = client.Log.WithField("c", "c2s")
state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed, srv.ExtensionsClient, client)
for { for {
state, client = state.Process(client) state = state.Process()
if state == nil { if state == nil {
client.Log.Info("disconnect") client.Log.Info("disconnect")
client.Close() client.Close()

View File

@ -10,116 +10,107 @@ import (
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
) )
// ConnectionStartup return steps through TCP TLS state
func ConnectionStartup(after State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager) State {
tlsupgrade := &TLSUpgrade{
Next: after,
tlsconfig: tlsconfig,
tlsmgmt: tlsmgmt,
}
stream := &Start{Next: tlsupgrade}
return stream
}
// Start state // Start state
type Start struct { type Start struct {
Next State Next State
Client *utils.Client
} }
// Process message // Process message
func (state *Start) Process(client *utils.Client) (State, *utils.Client) { func (state *Start) Process() State {
client.Log = client.Log.WithField("state", "stream") state.Client.Log = state.Client.Log.WithField("state", "stream")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
element, err := client.Read() element, err := state.Client.Read()
if err != nil { if err != nil {
client.Log.Warn("unable to read: ", err) state.Client.Log.Warn("unable to read: ", err)
return nil, client return nil
} }
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.Log.Warn("is no stream") state.Client.Log.Warn("is no stream")
return state, client return state
} }
for _, attr := range element.Attr { for _, attr := range element.Attr {
if attr.Name.Local == "to" { if attr.Name.Local == "to" {
client.JID = &model.JID{Domain: attr.Value} state.Client.JID = &model.JID{Domain: attr.Value}
client.Log = client.Log.WithField("jid", client.JID.Full()) state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
} }
} }
if client.JID == nil { if state.Client.JID == nil {
client.Log.Warn("no 'to' domain readed") state.Client.Log.Warn("no 'to' domain readed")
return nil, client return nil
} }
fmt.Fprintf(client.Conn, `<?xml version='1.0'?> fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`, <stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream) utils.CreateCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features> fmt.Fprintf(state.Client.Conn, `<stream:features>
<starttls xmlns='%s'> <starttls xmlns='%s'>
<required/> <required/>
</starttls> </starttls>
</stream:features>`, </stream:features>`,
messages.NSStream) messages.NSStream)
return state.Next, client return state.Next
} }
// TLSUpgrade state // TLSUpgrade state
type TLSUpgrade struct { type TLSUpgrade struct {
Next State Next State
tlsconfig *tls.Config Client *utils.Client
tlsmgmt *autocert.Manager TLSConfig *tls.Config
TLSManager *autocert.Manager
} }
// Process message // Process message
func (state *TLSUpgrade) Process(client *utils.Client) (State, *utils.Client) { func (state *TLSUpgrade) Process() State {
client.Log = client.Log.WithField("state", "tls upgrade") state.Client.Log = state.Client.Log.WithField("state", "tls upgrade")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
element, err := client.Read() element, err := state.Client.Read()
if err != nil { if err != nil {
client.Log.Warn("unable to read: ", err) state.Client.Log.Warn("unable to read: ", err)
return nil, client return nil
} }
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" { if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
client.Log.Warn("is no starttls") state.Client.Log.Warn("is no starttls", element)
return state, client return nil
} }
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS) fmt.Fprintf(state.Client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
// perform the TLS handshake // perform the TLS handshake
var tlsConfig *tls.Config var tlsConfig *tls.Config
if m := state.tlsmgmt; m != nil { if m := state.TLSManager; m != nil {
var cert *tls.Certificate var cert *tls.Certificate
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.JID.Domain}) cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: state.Client.JID.Domain})
if err != nil { if err != nil {
client.Log.Warn("no cert in tls manger found: ", err) state.Client.Log.Warn("no cert in tls manger found: ", err)
return nil, client return nil
} }
tlsConfig = &tls.Config{ tlsConfig = &tls.Config{
Certificates: []tls.Certificate{*cert}, Certificates: []tls.Certificate{*cert},
} }
} }
if tlsConfig == nil { if tlsConfig == nil {
tlsConfig = state.tlsconfig tlsConfig = state.TLSConfig
if tlsConfig != nil { if tlsConfig != nil {
tlsConfig.ServerName = client.JID.Domain tlsConfig.ServerName = state.Client.JID.Domain
} else { } else {
client.Log.Warn("no tls config found: ", err) state.Client.Log.Warn("no tls config found: ", err)
return nil, client return nil
} }
} }
tlsConn := tls.Server(client.Conn, tlsConfig) tlsConn := tls.Server(state.Client.Conn, tlsConfig)
err = tlsConn.Handshake() err = tlsConn.Handshake()
if err != nil { if err != nil {
client.Log.Warn("unable to tls handshake: ", err) state.Client.Log.Warn("unable to tls handshake: ", err)
return nil, client return nil
} }
// restart the Connection // restart the Connection
client.SetConnecting(tlsConn) state.Client.SetConnecting(tlsConn)
return state.Next, client return state.Next
} }

49
server/state/normal.go Normal file
View File

@ -0,0 +1,49 @@
package state
import (
"github.com/genofire/yaja/server/extension"
"github.com/genofire/yaja/server/utils"
)
// SendingClient state
type SendingClient struct {
Next State
Client *utils.Client
}
// Process messages
func (state *SendingClient) Process() State {
state.Client.Log = state.Client.Log.WithField("state", "normal")
state.Client.Log.Debug("sending")
// sending
go func() {
select {
case msg := <-state.Client.Messages:
err := state.Client.Out.Encode(msg)
if err != nil {
state.Client.Log.Warn(err)
}
case <-state.Client.OnClose():
return
}
}()
state.Client.Log.Debug("receiving")
return state.Next
}
// ReceivingClient state
type ReceivingClient struct {
Extensions extension.Extensions
Client *utils.Client
}
// Process messages
func (state *ReceivingClient) Process() State {
element, err := state.Client.Read()
if err != nil {
state.Client.Log.Warn("unable to read: ", err)
return nil
}
state.Extensions.Process(element, state.Client)
return state
}

View File

@ -4,5 +4,27 @@ import "github.com/genofire/yaja/server/utils"
// State processes the stream and moves to the next state // State processes the stream and moves to the next state
type State interface { type State interface {
Process(client *utils.Client) (State, *utils.Client) Process() State
}
// Start state
type Debug struct {
Next State
Client *utils.Client
}
// Process message
func (state *Debug) Process() State {
state.Client.Log = state.Client.Log.WithField("state", "debug")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
element, err := state.Client.Read()
if err != nil {
state.Client.Log.Warn("unable to read: ", err)
return nil
}
state.Client.Log.Info(element)
return state.Next
} }

View File

@ -16,51 +16,60 @@ import (
) )
// ConnectionStartup return steps through TCP TLS state // ConnectionStartup return steps through TCP TLS state
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions []extension.Extension) state.State { func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions extension.Extensions, c *utils.Client) state.State {
receiving := &ReceivingClient{Extensions: extensions} receiving := &state.ReceivingClient{Extensions: extensions, Client: c}
sending := &SendingClient{Next: receiving} sending := &state.SendingClient{Next: receiving, Client: c}
authedstream := &AuthedStream{Next: sending} authedstream := &AuthedStream{Next: sending, Client: c}
authedstart := &AuthedStart{Next: authedstream} authedstart := &AuthedStart{Next: authedstream, Client: c}
tlsauth := &SASLAuth{ tlsauth := &SASLAuth{
Next: authedstart, Next: authedstart,
Client: c,
database: db, database: db,
domainRegisterAllowed: registerAllowed, domainRegisterAllowed: registerAllowed,
} }
tlsstream := &TLSStream{ tlsstream := &TLSStream{
Next: tlsauth, Next: tlsauth,
Client: c,
domainRegisterAllowed: registerAllowed, domainRegisterAllowed: registerAllowed,
} }
return state.ConnectionStartup(tlsstream, tlsconfig, tlsmgmt) tlsupgrade := &state.TLSUpgrade{
Next: tlsstream,
Client: c,
TLSConfig: tlsconfig,
TLSManager: tlsmgmt,
}
return &state.Start{Next: tlsupgrade, Client: c}
} }
// TLSStream state // TLSStream state
type TLSStream struct { type TLSStream struct {
Next state.State Next state.State
Client *utils.Client
domainRegisterAllowed utils.DomainRegisterAllowed domainRegisterAllowed utils.DomainRegisterAllowed
} }
// Process messages // Process messages
func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Client) { func (state *TLSStream) Process() state.State {
client.Log = client.Log.WithField("state", "tls stream") state.Client.Log = state.Client.Log.WithField("state", "tls stream")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
element, err := client.Read() element, err := state.Client.Read()
if err != nil { if err != nil {
client.Log.Warn("unable to read: ", err) state.Client.Log.Warn("unable to read: ", err)
return nil, client return nil
} }
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.Log.Warn("is no stream") state.Client.Log.Warn("is no stream")
return state, client return state
} }
fmt.Fprintf(client.Conn, `<?xml version='1.0'?> fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`, <stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream) utils.CreateCookie(), messages.NSClient, messages.NSStream)
if state.domainRegisterAllowed(client.JID) { if state.domainRegisterAllowed(state.Client.JID) {
fmt.Fprintf(client.Conn, `<stream:features> fmt.Fprintf(state.Client.Conn, `<stream:features>
<mechanisms xmlns='%s'> <mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism> <mechanism>PLAIN</mechanism>
</mechanisms> </mechanisms>
@ -68,7 +77,7 @@ func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Clien
</stream:features>`, </stream:features>`,
messages.NSSASL, messages.NSFeaturesIQRegister) messages.NSSASL, messages.NSFeaturesIQRegister)
} else { } else {
fmt.Fprintf(client.Conn, `<stream:features> fmt.Fprintf(state.Client.Conn, `<stream:features>
<mechanisms xmlns='%s'> <mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism> <mechanism>PLAIN</mechanism>
</mechanisms> </mechanisms>
@ -76,124 +85,129 @@ func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Clien
messages.NSSASL) messages.NSSASL)
} }
return state.Next, client return state.Next
} }
// SASLAuth state // SASLAuth state
type SASLAuth struct { type SASLAuth struct {
Next state.State Next state.State
Client *utils.Client
database *database.State database *database.State
domainRegisterAllowed utils.DomainRegisterAllowed domainRegisterAllowed utils.DomainRegisterAllowed
} }
// Process messages // Process messages
func (state *SASLAuth) Process(client *utils.Client) (state.State, *utils.Client) { func (state *SASLAuth) Process() state.State {
client.Log = client.Log.WithField("state", "sasl auth") state.Client.Log = state.Client.Log.WithField("state", "sasl auth")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
// read the full auth stanza // read the full auth stanza
element, err := client.Read() element, err := state.Client.Read()
if err != nil { if err != nil {
client.Log.Warn("unable to read: ", err) state.Client.Log.Warn("unable to read: ", err)
return nil, client return nil
} }
var auth messages.SASLAuth var auth messages.SASLAuth
if err = client.In.DecodeElement(&auth, element); err != nil { if err = state.Client.In.DecodeElement(&auth, element); err != nil {
client.Log.Info("start substate for registration") state.Client.Log.Info("start substate for registration")
return &RegisterFormRequest{ return &RegisterFormRequest{
Next: &RegisterRequest{
Next: state.Next,
Client: state.Client,
database: state.database,
domainRegisterAllowed: state.domainRegisterAllowed,
},
Client: state.Client,
element: element, element: element,
domainRegisterAllowed: state.domainRegisterAllowed, domainRegisterAllowed: state.domainRegisterAllowed,
Next: &RegisterRequest{ }
domainRegisterAllowed: state.domainRegisterAllowed,
database: state.database,
Next: state.Next,
},
}, client
} }
data, err := base64.StdEncoding.DecodeString(auth.Body) data, err := base64.StdEncoding.DecodeString(auth.Body)
if err != nil { if err != nil {
client.Log.Warn("body decode: ", err) state.Client.Log.Warn("body decode: ", err)
return nil, client return nil
} }
info := strings.Split(string(data), "\x00") info := strings.Split(string(data), "\x00")
// should check that info[1] starts with client.JID // should check that info[1] starts with state.Client.JID
client.JID.Local = info[1] state.Client.JID.Local = info[1]
client.Log = client.Log.WithField("jid", client.JID.Full()) state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
success, err := state.database.Authenticate(client.JID, info[2]) success, err := state.database.Authenticate(state.Client.JID, info[2])
if err != nil { if err != nil {
client.Log.Warn("auth: ", err) state.Client.Log.Warn("auth: ", err)
return nil, client return nil
} }
if success { if success {
client.Log.Info("success auth") state.Client.Log.Info("success auth")
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL) fmt.Fprintf(state.Client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
return state.Next, client return state.Next
} }
client.Log.Warn("failed auth") state.Client.Log.Warn("failed auth")
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL) fmt.Fprintf(state.Client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
return nil, client return nil
} }
// AuthedStart state // AuthedStart state
type AuthedStart struct { type AuthedStart struct {
Next state.State Next state.State
Client *utils.Client
} }
// Process messages // Process messages
func (state *AuthedStart) Process(client *utils.Client) (state.State, *utils.Client) { func (state *AuthedStart) Process() state.State {
client.Log = client.Log.WithField("state", "authed started") state.Client.Log = state.Client.Log.WithField("state", "authed started")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
_, err := client.Read() _, err := state.Client.Read()
if err != nil { if err != nil {
client.Log.Warn("unable to read: ", err) state.Client.Log.Warn("unable to read: ", err)
return nil, client return nil
} }
fmt.Fprintf(client.Conn, `<?xml version='1.0'?> fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`, <stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream) utils.CreateCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features> fmt.Fprintf(state.Client.Conn, `<stream:features>
<bind xmlns='%s'/> <bind xmlns='%s'/>
</stream:features>`, </stream:features>`,
messages.NSBind) messages.NSBind)
return state.Next, client return state.Next
} }
// AuthedStream state // AuthedStream state
type AuthedStream struct { type AuthedStream struct {
Next state.State Next state.State
Client *utils.Client
} }
// Process messages // Process messages
func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Client) { func (state *AuthedStream) Process() state.State {
client.Log = client.Log.WithField("state", "authed stream") state.Client.Log = state.Client.Log.WithField("state", "authed stream")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
// check that it's a bind request // check that it's a bind request
// read bind request // read bind request
element, err := client.Read() element, err := state.Client.Read()
if err != nil { if err != nil {
client.Log.Warn("unable to read: ", err) state.Client.Log.Warn("unable to read: ", err)
return nil, client return nil
} }
var msg messages.IQ var msg messages.IQ
if err = client.In.DecodeElement(&msg, element); err != nil { if err = state.Client.In.DecodeElement(&msg, element); err != nil {
client.Log.Warn("is no iq: ", err) state.Client.Log.Warn("is no iq: ", err)
return nil, client return nil
} }
if msg.Type != messages.IQTypeSet { if msg.Type != messages.IQTypeSet {
client.Log.Warn("is no set iq") state.Client.Log.Warn("is no set iq")
return nil, client return nil
} }
if msg.Error != nil { if msg.Error != nil {
client.Log.Warn("iq with error: ", msg.Error.Code) state.Client.Log.Warn("iq with error: ", msg.Error.Code)
return nil, client return nil
} }
type query struct { type query struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
@ -202,26 +216,26 @@ func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Cl
q := &query{} q := &query{}
err = xml.Unmarshal(msg.Body, q) err = xml.Unmarshal(msg.Body, q)
if err != nil { if err != nil {
client.Log.Warn("is no iq bind: ", err) state.Client.Log.Warn("is no iq bind: ", err)
return nil, client return nil
} }
if q.Resource == "" { if q.Resource == "" {
client.JID.Resource = makeResource() state.Client.JID.Resource = makeResource()
} else { } else {
client.JID.Resource = q.Resource state.Client.JID.Resource = q.Resource
} }
client.Log = client.Log.WithField("jid", client.JID.Full()) state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
client.Out.Encode(&messages.IQ{ state.Client.Out.Encode(&messages.IQ{
Type: messages.IQTypeResult, Type: messages.IQTypeResult,
To: client.JID.String(), To: state.Client.JID.String(),
From: client.JID.Domain, From: state.Client.JID.Domain,
ID: msg.ID, ID: msg.ID,
Body: []byte(fmt.Sprintf( Body: []byte(fmt.Sprintf(
`<bind xmlns='%s'> `<bind xmlns='%s'>
<jid>%s</jid> <jid>%s</jid>
</bind>`, </bind>`,
messages.NSBind, client.JID.Full())), messages.NSBind, state.Client.JID.Full())),
}) })
return state.Next, client return state.Next
} }

View File

@ -1,48 +0,0 @@
package toclient
import (
"github.com/genofire/yaja/server/extension"
"github.com/genofire/yaja/server/state"
"github.com/genofire/yaja/server/utils"
)
// SendingClient state
type SendingClient struct {
Next state.State
}
// Process messages
func (state *SendingClient) Process(client *utils.Client) (state.State, *utils.Client) {
client.Log = client.Log.WithField("state", "normal")
client.Log.Debug("sending")
// sending
go func() {
select {
case msg := <-client.Messages:
err := client.Out.Encode(msg)
if err != nil {
client.Log.Warn(err)
}
case <-client.OnClose():
return
}
}()
client.Log.Debug("receiving")
return state.Next, client
}
// ReceivingClient state
type ReceivingClient struct {
Extensions extension.Extensions
}
// Process messages
func (state *ReceivingClient) Process(client *utils.Client) (state.State, *utils.Client) {
element, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
state.Extensions.Process(element, client)
return state, client
}

View File

@ -13,33 +13,34 @@ import (
type RegisterFormRequest struct { type RegisterFormRequest struct {
Next state.State Next state.State
Client *utils.Client
domainRegisterAllowed utils.DomainRegisterAllowed domainRegisterAllowed utils.DomainRegisterAllowed
element *xml.StartElement element *xml.StartElement
} }
// Process message // Process message
func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *utils.Client) { func (state *RegisterFormRequest) Process() state.State {
client.Log = client.Log.WithField("state", "register form request") state.Client.Log = state.Client.Log.WithField("state", "register form request")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
if !state.domainRegisterAllowed(client.JID) { if !state.domainRegisterAllowed(state.Client.JID) {
client.Log.Error("unpossible to reach this state, register on this domain is not allowed") state.Client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
return nil, client return nil
} }
var msg messages.IQ var msg messages.IQ
if err := client.In.DecodeElement(&msg, state.element); err != nil { if err := state.Client.In.DecodeElement(&msg, state.element); err != nil {
client.Log.Warn("is no iq: ", err) state.Client.Log.Warn("is no iq: ", err)
return state, client return state
} }
if msg.Type != messages.IQTypeGet { if msg.Type != messages.IQTypeGet {
client.Log.Warn("is no get iq") state.Client.Log.Warn("is no get iq")
return state, client return state
} }
if msg.Error != nil { if msg.Error != nil {
client.Log.Warn("iq with error: ", msg.Error.Code) state.Client.Log.Warn("iq with error: ", msg.Error.Code)
return state, client return state
} }
type query struct { type query struct {
XMLName xml.Name `xml:"query"` XMLName xml.Name `xml:"query"`
@ -48,13 +49,13 @@ func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *u
err := xml.Unmarshal(msg.Body, q) err := xml.Unmarshal(msg.Body, q)
if q.XMLName.Space != messages.NSIQRegister || err != nil { if q.XMLName.Space != messages.NSIQRegister || err != nil {
client.Log.Warn("is no iq register: ", err) state.Client.Log.Warn("is no iq register: ", err)
return nil, client return nil
} }
client.Out.Encode(&messages.IQ{ state.Client.Out.Encode(&messages.IQ{
Type: messages.IQTypeResult, Type: messages.IQTypeResult,
To: client.JID.String(), To: state.Client.JID.String(),
From: client.JID.Domain, From: state.Client.JID.Domain,
ID: msg.ID, ID: msg.ID,
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions> Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
Choose a username and password for use with this service. Choose a username and password for use with this service.
@ -63,43 +64,44 @@ func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *u
<password/> <password/>
</query>`, messages.NSIQRegister)), </query>`, messages.NSIQRegister)),
}) })
return state.Next, client return state.Next
} }
type RegisterRequest struct { type RegisterRequest struct {
Next state.State Next state.State
Client *utils.Client
database *database.State database *database.State
domainRegisterAllowed utils.DomainRegisterAllowed domainRegisterAllowed utils.DomainRegisterAllowed
} }
// Process message // Process message
func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils.Client) { func (state *RegisterRequest) Process() state.State {
client.Log = client.Log.WithField("state", "register request") state.Client.Log = state.Client.Log.WithField("state", "register request")
client.Log.Debug("running") state.Client.Log.Debug("running")
defer client.Log.Debug("leave") defer state.Client.Log.Debug("leave")
if !state.domainRegisterAllowed(client.JID) { if !state.domainRegisterAllowed(state.Client.JID) {
client.Log.Error("unpossible to reach this state, register on this domain is not allowed") state.Client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
return nil, client return nil
} }
element, err := client.Read() element, err := state.Client.Read()
if err != nil { if err != nil {
client.Log.Warn("unable to read: ", err) state.Client.Log.Warn("unable to read: ", err)
return nil, client return nil
} }
var msg messages.IQ var msg messages.IQ
if err = client.In.DecodeElement(&msg, element); err != nil { if err = state.Client.In.DecodeElement(&msg, element); err != nil {
client.Log.Warn("is no iq: ", err) state.Client.Log.Warn("is no iq: ", err)
return state, client return state
} }
if msg.Type != messages.IQTypeGet { if msg.Type != messages.IQTypeGet {
client.Log.Warn("is no get iq") state.Client.Log.Warn("is no get iq")
return state, client return state
} }
if msg.Error != nil { if msg.Error != nil {
client.Log.Warn("iq with error: ", msg.Error.Code) state.Client.Log.Warn("iq with error: ", msg.Error.Code)
return state, client return state
} }
type query struct { type query struct {
XMLName xml.Name `xml:"query"` XMLName xml.Name `xml:"query"`
@ -109,19 +111,19 @@ func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils
q := &query{} q := &query{}
err = xml.Unmarshal(msg.Body, q) err = xml.Unmarshal(msg.Body, q)
if err != nil { if err != nil {
client.Log.Warn("is no iq register: ", err) state.Client.Log.Warn("is no iq register: ", err)
return nil, client return nil
} }
client.JID.Local = q.Username state.Client.JID.Local = q.Username
client.Log = client.Log.WithField("jid", client.JID.Full()) state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
account := model.NewAccount(client.JID, q.Password) account := model.NewAccount(state.Client.JID, q.Password)
err = state.database.AddAccount(account) err = state.database.AddAccount(account)
if err != nil { if err != nil {
client.Out.Encode(&messages.IQ{ state.Client.Out.Encode(&messages.IQ{
Type: messages.IQTypeResult, Type: messages.IQTypeResult,
To: client.JID.String(), To: state.Client.JID.String(),
From: client.JID.Domain, From: state.Client.JID.Domain,
ID: msg.ID, ID: msg.ID,
Body: []byte(fmt.Sprintf(`<query xmlns='%s'> Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
<username>%s</username> <username>%s</username>
@ -136,16 +138,16 @@ func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils
}, },
}, },
}) })
client.Log.Warn("database error: ", err) state.Client.Log.Warn("database error: ", err)
return state, client return state
} }
client.Out.Encode(&messages.IQ{ state.Client.Out.Encode(&messages.IQ{
Type: messages.IQTypeResult, Type: messages.IQTypeResult,
To: client.JID.String(), To: state.Client.JID.String(),
From: client.JID.Domain, From: state.Client.JID.Domain,
ID: msg.ID, ID: msg.ID,
}) })
client.Log.Infof("registered client %s", client.JID.Bare()) state.Client.Log.Infof("registered client %s", state.Client.JID.Bare())
return state.Next, client return state.Next
} }

141
server/toserver/connect.go Normal file
View File

@ -0,0 +1,141 @@
package toserver
import (
"crypto/tls"
"encoding/base64"
"encoding/xml"
"fmt"
"github.com/genofire/yaja/database"
"github.com/genofire/yaja/messages"
"github.com/genofire/yaja/server/extension"
"github.com/genofire/yaja/server/state"
"github.com/genofire/yaja/server/utils"
"golang.org/x/crypto/acme/autocert"
)
// ConnectionStartup return steps through TCP TLS state
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, extensions extension.Extensions, c *utils.Client) state.State {
receiving := &state.ReceivingClient{Extensions: extensions, Client: c}
sending := &state.SendingClient{Next: receiving, Client: c}
tlsstream := &TLSStream{
Next: sending,
Client: c,
}
tlsupgrade := &state.TLSUpgrade{
Next: tlsstream,
Client: c,
TLSConfig: tlsconfig,
TLSManager: tlsmgmt,
}
dail := &Dailback{
Next: tlsupgrade,
Client: c,
}
return &state.Start{Next: dail, Client: c}
}
// TLSStream state
type Dailback struct {
Next state.State
Client *utils.Client
}
// Process messages
func (state *Dailback) Process() state.State {
state.Client.Log = state.Client.Log.WithField("state", "dialback")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
element, err := state.Client.Read()
if err != nil {
state.Client.Log.Warn("unable to read: ", err)
return nil
}
// dailback encode
type dailback struct {
XMLName xml.Name `xml:"urn:xmpp:ping ping"`
}
db := &dailback{}
if err = state.Client.In.DecodeElement(db, element); err != nil {
return state.Next
}
state.Client.Log.Info(db)
return state.Next
}
// TLSStream state
type TLSStream struct {
Next state.State
Client *utils.Client
domainRegisterAllowed utils.DomainRegisterAllowed
}
// Process messages
func (state *TLSStream) Process() state.State {
state.Client.Log = state.Client.Log.WithField("state", "tls stream")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
element, err := state.Client.Read()
if err != nil {
state.Client.Log.Warn("unable to read: ", err)
return nil
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
state.Client.Log.Warn("is no stream")
return state
}
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(state.Client.Conn, `<stream:features>
<mechanisms xmlns='%s'>
<mechanism>EXTERNAL</mechanism>
</mechanisms>
</stream:features>`,
messages.NSSASL)
return state.Next
}
// SASLAuth state
type SASLAuth struct {
Next state.State
Client *utils.Client
database *database.State
domainRegisterAllowed utils.DomainRegisterAllowed
}
// Process messages
func (state *SASLAuth) Process() state.State {
state.Client.Log = state.Client.Log.WithField("state", "sasl auth")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
// read the full auth stanza
element, err := state.Client.Read()
if err != nil {
state.Client.Log.Warn("unable to read: ", err)
return nil
}
var auth messages.SASLAuth
if err = state.Client.In.DecodeElement(&auth, element); err != nil {
return nil
}
data, err := base64.StdEncoding.DecodeString(auth.Body)
if err != nil {
state.Client.Log.Warn("body decode: ", err)
return nil
}
state.Client.Log.Debug(auth.Mechanism, string(data))
state.Client.Log.Info("success auth")
fmt.Fprintf(state.Client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
return state.Next
}

View File

@ -30,7 +30,7 @@ func NewClient(conn net.Conn, level log.Level) *Client {
Log: log.NewEntry(logger), Log: log.NewEntry(logger),
In: xml.NewDecoder(conn), In: xml.NewDecoder(conn),
Out: xml.NewEncoder(conn), Out: xml.NewEncoder(conn),
Messages: make(chan interface{}, 1000), Messages: make(chan interface{}),
close: make(chan interface{}), close: make(chan interface{}),
} }
return client return client