move client to state attr + s2s idea
This commit is contained in:
parent
1e2e578076
commit
e474f460aa
|
@ -25,12 +25,13 @@ import (
|
||||||
var configPath string
|
var configPath string
|
||||||
|
|
||||||
var (
|
var (
|
||||||
configData = &config.Config{}
|
configData = &config.Config{}
|
||||||
db = &database.State{}
|
db = &database.State{}
|
||||||
statesaveWorker *worker.Worker
|
statesaveWorker *worker.Worker
|
||||||
srv *server.Server
|
srv *server.Server
|
||||||
certs *tls.Config
|
certs *tls.Config
|
||||||
extensions extension.Extensions
|
extensionsClient extension.Extensions
|
||||||
|
extensionsServer extension.Extensions
|
||||||
)
|
)
|
||||||
|
|
||||||
// serverCmd represents the serve command
|
// serverCmd represents the serve command
|
||||||
|
@ -39,16 +40,14 @@ var serverCmd = &cobra.Command{
|
||||||
Short: "Runs the yaja server",
|
Short: "Runs the yaja server",
|
||||||
Example: "yaja serve -c /etc/yaja.conf",
|
Example: "yaja serve -c /etc/yaja.conf",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
var err error
|
|
||||||
err = file.ReadTOML(configPath, configData)
|
if err := file.ReadTOML(configPath, configData); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Fatal("unable to load config file:", err)
|
log.Fatal("unable to load config file:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetLevel(configData.Logging.Level)
|
log.SetLevel(configData.Logging.Level)
|
||||||
|
|
||||||
err = file.ReadJSON(configData.StatePath, db)
|
if err := file.ReadJSON(configData.StatePath, db); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Warn("unable to load state file:", err)
|
log.Warn("unable to load state file:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,14 +75,16 @@ var serverCmd = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
srv = &server.Server{
|
srv = &server.Server{
|
||||||
TLSManager: &m,
|
TLSManager: &m,
|
||||||
Database: db,
|
Database: db,
|
||||||
ClientAddr: configData.Address.Client,
|
ClientAddr: configData.Address.Client,
|
||||||
ServerAddr: configData.Address.Server,
|
ServerAddr: configData.Address.Server,
|
||||||
LoggingClient: configData.Logging.LevelClient,
|
LoggingClient: configData.Logging.LevelClient,
|
||||||
RegisterEnable: configData.Register.Enable,
|
LoggingServer: configData.Logging.LevelServer,
|
||||||
RegisterDomains: configData.Register.Domains,
|
RegisterEnable: configData.Register.Enable,
|
||||||
Extensions: extensions,
|
RegisterDomains: configData.Register.Domains,
|
||||||
|
ExtensionsServer: extensionsServer,
|
||||||
|
ExtensionsClient: extensionsClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
go statesaveWorker.Start()
|
go statesaveWorker.Start()
|
||||||
|
@ -122,13 +123,14 @@ func quit() {
|
||||||
func reload() {
|
func reload() {
|
||||||
log.Info("start reloading...")
|
log.Info("start reloading...")
|
||||||
var configNewData *config.Config
|
var configNewData *config.Config
|
||||||
err := file.ReadTOML(configPath, configNewData)
|
|
||||||
if err != nil {
|
if err := file.ReadTOML(configPath, configNewData); err != nil {
|
||||||
log.Warn("unable to load config file:", err)
|
log.Warn("unable to load config file:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.SetLevel(configNewData.Logging.Level)
|
log.SetLevel(configNewData.Logging.Level)
|
||||||
srv.LoggingClient = configNewData.Logging.LevelClient
|
srv.LoggingClient = configNewData.Logging.LevelClient
|
||||||
|
srv.LoggingServer = configNewData.Logging.LevelServer
|
||||||
srv.RegisterEnable = configNewData.Register.Enable
|
srv.RegisterEnable = configNewData.Register.Enable
|
||||||
srv.RegisterDomains = configNewData.Register.Domains
|
srv.RegisterDomains = configNewData.Register.Domains
|
||||||
|
|
||||||
|
@ -157,14 +159,15 @@ func reload() {
|
||||||
}
|
}
|
||||||
if restartServer {
|
if restartServer {
|
||||||
newServer := &server.Server{
|
newServer := &server.Server{
|
||||||
TLSConfig: certs,
|
TLSConfig: certs,
|
||||||
Database: db,
|
Database: db,
|
||||||
ClientAddr: configNewData.Address.Client,
|
ClientAddr: configNewData.Address.Client,
|
||||||
ServerAddr: configNewData.Address.Server,
|
ServerAddr: configNewData.Address.Server,
|
||||||
LoggingClient: configNewData.Logging.LevelClient,
|
LoggingClient: configNewData.Logging.LevelClient,
|
||||||
RegisterEnable: configNewData.Register.Enable,
|
RegisterEnable: configNewData.Register.Enable,
|
||||||
RegisterDomains: configNewData.Register.Domains,
|
RegisterDomains: configNewData.Register.Domains,
|
||||||
Extensions: extensions,
|
ExtensionsServer: extensionsServer,
|
||||||
|
ExtensionsClient: extensionsClient,
|
||||||
}
|
}
|
||||||
log.Warn("reloading need a restart:")
|
log.Warn("reloading need a restart:")
|
||||||
go newServer.Start()
|
go newServer.Start()
|
||||||
|
@ -178,7 +181,7 @@ func reload() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
extensions = append(extensions,
|
extensionsClient = append(extensionsClient,
|
||||||
&extension.Message{},
|
&extension.Message{},
|
||||||
&extension.Presence{},
|
&extension.Presence{},
|
||||||
extension.IQExtensions{
|
extension.IQExtensions{
|
||||||
|
@ -188,10 +191,15 @@ func init() {
|
||||||
&extension.IQDisco{Database: db},
|
&extension.IQDisco{Database: db},
|
||||||
&extension.IQRoster{Database: db},
|
&extension.IQRoster{Database: db},
|
||||||
&extension.IQExtensionDiscovery{GetSpaces: func() []string {
|
&extension.IQExtensionDiscovery{GetSpaces: func() []string {
|
||||||
return extensions.Spaces()
|
return extensionsClient.Spaces()
|
||||||
}},
|
}},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
extensionsServer = append(extensionsServer,
|
||||||
|
extension.IQExtensions{
|
||||||
|
&extension.IQPing{},
|
||||||
|
})
|
||||||
|
|
||||||
RootCmd.AddCommand(serverCmd)
|
RootCmd.AddCommand(serverCmd)
|
||||||
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,9 @@ tlsdir = "tmp/ssl"
|
||||||
state_path = "tmp/yaja.json"
|
state_path = "tmp/yaja.json"
|
||||||
|
|
||||||
[logging]
|
[logging]
|
||||||
level = 3
|
level = 5
|
||||||
level_client = 6
|
level_client = 6
|
||||||
|
level_server = 6
|
||||||
|
|
||||||
[register]
|
[register]
|
||||||
enable = true
|
enable = true
|
||||||
|
|
|
@ -10,6 +10,7 @@ type Config struct {
|
||||||
Logging struct {
|
Logging struct {
|
||||||
Level log.Level `toml:"level"`
|
Level log.Level `toml:"level"`
|
||||||
LevelClient log.Level `toml:"level_client"`
|
LevelClient log.Level `toml:"level_client"`
|
||||||
|
LevelServer log.Level `toml:"level_server"`
|
||||||
} `toml:"logging"`
|
} `toml:"logging"`
|
||||||
Register struct {
|
Register struct {
|
||||||
Enable bool `toml:"enable"`
|
Enable bool `toml:"enable"`
|
||||||
|
|
|
@ -24,8 +24,7 @@ func (ex *IQDisco) Get(msg *messages.IQ, client *utils.Client) bool {
|
||||||
Body []byte `xml:",innerxml"`
|
Body []byte `xml:",innerxml"`
|
||||||
}
|
}
|
||||||
q := &query{}
|
q := &query{}
|
||||||
err := xml.Unmarshal(msg.Body, q)
|
if err := xml.Unmarshal(msg.Body, q); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,8 +23,7 @@ func (ex *IQExtensionDiscovery) Get(msg *messages.IQ, client *utils.Client) bool
|
||||||
Body []byte `xml:",innerxml"`
|
Body []byte `xml:",innerxml"`
|
||||||
}
|
}
|
||||||
q := &query{}
|
q := &query{}
|
||||||
err := xml.Unmarshal(msg.Body, q)
|
if err := xml.Unmarshal(msg.Body, q); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,8 +25,7 @@ func (ex *IQLast) Get(msg *messages.IQ, client *utils.Client) bool {
|
||||||
Body []byte `xml:",innerxml"`
|
Body []byte `xml:",innerxml"`
|
||||||
}
|
}
|
||||||
q := &query{}
|
q := &query{}
|
||||||
err := xml.Unmarshal(msg.Body, q)
|
if err := xml.Unmarshal(msg.Body, q); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,7 @@ func (ex *IQPing) Get(msg *messages.IQ, client *utils.Client) bool {
|
||||||
XMLName xml.Name `xml:"urn:xmpp:ping ping"`
|
XMLName xml.Name `xml:"urn:xmpp:ping ping"`
|
||||||
}
|
}
|
||||||
pq := &ping{}
|
pq := &ping{}
|
||||||
err := xml.Unmarshal(msg.Body, pq)
|
if err := xml.Unmarshal(msg.Body, pq); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,8 +27,7 @@ func (ex *IQPrivate) Get(msg *messages.IQ, client *utils.Client) bool {
|
||||||
|
|
||||||
// query encode
|
// query encode
|
||||||
q := &iqPrivateQuery{}
|
q := &iqPrivateQuery{}
|
||||||
err := xml.Unmarshal(msg.Body, q)
|
if err := xml.Unmarshal(msg.Body, q); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,8 +19,7 @@ func (ex *IQPrivateBookmark) Handle(msg *messages.IQ, q *iqPrivateQuery, client
|
||||||
XMLName xml.Name `xml:"storage:bookmarks storage"`
|
XMLName xml.Name `xml:"storage:bookmarks storage"`
|
||||||
}
|
}
|
||||||
s := &storage{}
|
s := &storage{}
|
||||||
err := xml.Unmarshal(q.Body, s)
|
if err := xml.Unmarshal(q.Body, s); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -19,8 +19,7 @@ func (ex *IQPrivateMetacontact) Handle(msg *messages.IQ, q *iqPrivateQuery, clie
|
||||||
XMLName xml.Name `xml:"storage:metacontacts storage"`
|
XMLName xml.Name `xml:"storage:metacontacts storage"`
|
||||||
}
|
}
|
||||||
s := &storage{}
|
s := &storage{}
|
||||||
err := xml.Unmarshal(q.Body, s)
|
if err := xml.Unmarshal(q.Body, s); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -20,8 +20,7 @@ func (ex *IQPrivateRoster) Handle(msg *messages.IQ, q *iqPrivateQuery, client *u
|
||||||
Body []byte `xml:",innerxml"`
|
Body []byte `xml:",innerxml"`
|
||||||
}
|
}
|
||||||
r := &roster{}
|
r := &roster{}
|
||||||
err := xml.Unmarshal(q.Body, r)
|
if err := xml.Unmarshal(q.Body, r); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,8 +25,7 @@ func (ex *IQRoster) Get(msg *messages.IQ, client *utils.Client) bool {
|
||||||
Body []byte `xml:",innerxml"`
|
Body []byte `xml:",innerxml"`
|
||||||
}
|
}
|
||||||
q := &query{}
|
q := &query{}
|
||||||
err := xml.Unmarshal(msg.Body, q)
|
if err := xml.Unmarshal(msg.Body, q); err != nil {
|
||||||
if err != nil {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,21 +8,24 @@ import (
|
||||||
"github.com/genofire/yaja/model"
|
"github.com/genofire/yaja/model"
|
||||||
"github.com/genofire/yaja/server/extension"
|
"github.com/genofire/yaja/server/extension"
|
||||||
"github.com/genofire/yaja/server/toclient"
|
"github.com/genofire/yaja/server/toclient"
|
||||||
|
"github.com/genofire/yaja/server/toserver"
|
||||||
"github.com/genofire/yaja/server/utils"
|
"github.com/genofire/yaja/server/utils"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
TLSManager *autocert.Manager
|
TLSManager *autocert.Manager
|
||||||
ClientAddr []string
|
ClientAddr []string
|
||||||
ServerAddr []string
|
ServerAddr []string
|
||||||
Database *database.State
|
Database *database.State
|
||||||
LoggingClient log.Level
|
LoggingClient log.Level
|
||||||
RegisterEnable bool
|
LoggingServer log.Level
|
||||||
RegisterDomains []string
|
RegisterEnable bool
|
||||||
Extensions extension.Extensions
|
RegisterDomains []string
|
||||||
|
ExtensionsClient extension.Extensions
|
||||||
|
ExtensionsServer extension.Extensions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) Start() {
|
func (srv *Server) Start() {
|
||||||
|
@ -69,15 +72,33 @@ func (srv *Server) listenClient(c2s net.Listener) {
|
||||||
|
|
||||||
func (srv *Server) handleServer(conn net.Conn) {
|
func (srv *Server) handleServer(conn net.Conn) {
|
||||||
log.Info("new server connection:", conn.RemoteAddr())
|
log.Info("new server connection:", conn.RemoteAddr())
|
||||||
|
|
||||||
|
client := utils.NewClient(conn, srv.LoggingClient)
|
||||||
|
client.Log = client.Log.WithField("c", "s2s")
|
||||||
|
|
||||||
|
state := toserver.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.ExtensionsServer, client)
|
||||||
|
|
||||||
|
for {
|
||||||
|
state = state.Process()
|
||||||
|
if state == nil {
|
||||||
|
client.Log.Info("disconnect")
|
||||||
|
client.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// run next state
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) handleClient(conn net.Conn) {
|
func (srv *Server) handleClient(conn net.Conn) {
|
||||||
log.Info("new client connection:", conn.RemoteAddr())
|
log.Info("new client connection:", conn.RemoteAddr())
|
||||||
client := utils.NewClient(conn, srv.LoggingClient)
|
|
||||||
state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed, srv.Extensions)
|
client := utils.NewClient(conn, srv.LoggingServer)
|
||||||
|
client.Log = client.Log.WithField("c", "c2s")
|
||||||
|
|
||||||
|
state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed, srv.ExtensionsClient, client)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
state, client = state.Process(client)
|
state = state.Process()
|
||||||
if state == nil {
|
if state == nil {
|
||||||
client.Log.Info("disconnect")
|
client.Log.Info("disconnect")
|
||||||
client.Close()
|
client.Close()
|
||||||
|
|
|
@ -10,116 +10,107 @@ import (
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionStartup return steps through TCP TLS state
|
|
||||||
func ConnectionStartup(after State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager) State {
|
|
||||||
tlsupgrade := &TLSUpgrade{
|
|
||||||
Next: after,
|
|
||||||
tlsconfig: tlsconfig,
|
|
||||||
tlsmgmt: tlsmgmt,
|
|
||||||
}
|
|
||||||
stream := &Start{Next: tlsupgrade}
|
|
||||||
return stream
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start state
|
// Start state
|
||||||
type Start struct {
|
type Start struct {
|
||||||
Next State
|
Next State
|
||||||
|
Client *utils.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message
|
// Process message
|
||||||
func (state *Start) Process(client *utils.Client) (State, *utils.Client) {
|
func (state *Start) Process() State {
|
||||||
client.Log = client.Log.WithField("state", "stream")
|
state.Client.Log = state.Client.Log.WithField("state", "stream")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
element, err := client.Read()
|
element, err := state.Client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to read: ", err)
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||||
client.Log.Warn("is no stream")
|
state.Client.Log.Warn("is no stream")
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
for _, attr := range element.Attr {
|
for _, attr := range element.Attr {
|
||||||
if attr.Name.Local == "to" {
|
if attr.Name.Local == "to" {
|
||||||
client.JID = &model.JID{Domain: attr.Value}
|
state.Client.JID = &model.JID{Domain: attr.Value}
|
||||||
client.Log = client.Log.WithField("jid", client.JID.Full())
|
state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if client.JID == nil {
|
if state.Client.JID == nil {
|
||||||
client.Log.Warn("no 'to' domain readed")
|
state.Client.Log.Warn("no 'to' domain readed")
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
|
||||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
fmt.Fprintf(state.Client.Conn, `<stream:features>
|
||||||
<starttls xmlns='%s'>
|
<starttls xmlns='%s'>
|
||||||
<required/>
|
<required/>
|
||||||
</starttls>
|
</starttls>
|
||||||
</stream:features>`,
|
</stream:features>`,
|
||||||
messages.NSStream)
|
messages.NSStream)
|
||||||
|
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSUpgrade state
|
// TLSUpgrade state
|
||||||
type TLSUpgrade struct {
|
type TLSUpgrade struct {
|
||||||
Next State
|
Next State
|
||||||
tlsconfig *tls.Config
|
Client *utils.Client
|
||||||
tlsmgmt *autocert.Manager
|
TLSConfig *tls.Config
|
||||||
|
TLSManager *autocert.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message
|
// Process message
|
||||||
func (state *TLSUpgrade) Process(client *utils.Client) (State, *utils.Client) {
|
func (state *TLSUpgrade) Process() State {
|
||||||
client.Log = client.Log.WithField("state", "tls upgrade")
|
state.Client.Log = state.Client.Log.WithField("state", "tls upgrade")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
element, err := client.Read()
|
element, err := state.Client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to read: ", err)
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
|
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
|
||||||
client.Log.Warn("is no starttls")
|
state.Client.Log.Warn("is no starttls", element)
|
||||||
return state, client
|
return nil
|
||||||
}
|
}
|
||||||
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
|
fmt.Fprintf(state.Client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
|
||||||
// perform the TLS handshake
|
// perform the TLS handshake
|
||||||
var tlsConfig *tls.Config
|
var tlsConfig *tls.Config
|
||||||
if m := state.tlsmgmt; m != nil {
|
if m := state.TLSManager; m != nil {
|
||||||
var cert *tls.Certificate
|
var cert *tls.Certificate
|
||||||
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.JID.Domain})
|
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: state.Client.JID.Domain})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("no cert in tls manger found: ", err)
|
state.Client.Log.Warn("no cert in tls manger found: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
tlsConfig = &tls.Config{
|
tlsConfig = &tls.Config{
|
||||||
Certificates: []tls.Certificate{*cert},
|
Certificates: []tls.Certificate{*cert},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if tlsConfig == nil {
|
if tlsConfig == nil {
|
||||||
tlsConfig = state.tlsconfig
|
tlsConfig = state.TLSConfig
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
tlsConfig.ServerName = client.JID.Domain
|
tlsConfig.ServerName = state.Client.JID.Domain
|
||||||
} else {
|
} else {
|
||||||
client.Log.Warn("no tls config found: ", err)
|
state.Client.Log.Warn("no tls config found: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConn := tls.Server(client.Conn, tlsConfig)
|
tlsConn := tls.Server(state.Client.Conn, tlsConfig)
|
||||||
err = tlsConn.Handshake()
|
err = tlsConn.Handshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to tls handshake: ", err)
|
state.Client.Log.Warn("unable to tls handshake: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
// restart the Connection
|
// restart the Connection
|
||||||
client.SetConnecting(tlsConn)
|
state.Client.SetConnecting(tlsConn)
|
||||||
|
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/genofire/yaja/server/extension"
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SendingClient state
|
||||||
|
type SendingClient struct {
|
||||||
|
Next State
|
||||||
|
Client *utils.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *SendingClient) Process() State {
|
||||||
|
state.Client.Log = state.Client.Log.WithField("state", "normal")
|
||||||
|
state.Client.Log.Debug("sending")
|
||||||
|
// sending
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case msg := <-state.Client.Messages:
|
||||||
|
err := state.Client.Out.Encode(msg)
|
||||||
|
if err != nil {
|
||||||
|
state.Client.Log.Warn(err)
|
||||||
|
}
|
||||||
|
case <-state.Client.OnClose():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
state.Client.Log.Debug("receiving")
|
||||||
|
return state.Next
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceivingClient state
|
||||||
|
type ReceivingClient struct {
|
||||||
|
Extensions extension.Extensions
|
||||||
|
Client *utils.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *ReceivingClient) Process() State {
|
||||||
|
element, err := state.Client.Read()
|
||||||
|
if err != nil {
|
||||||
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.Extensions.Process(element, state.Client)
|
||||||
|
return state
|
||||||
|
}
|
|
@ -4,5 +4,27 @@ import "github.com/genofire/yaja/server/utils"
|
||||||
|
|
||||||
// State processes the stream and moves to the next state
|
// State processes the stream and moves to the next state
|
||||||
type State interface {
|
type State interface {
|
||||||
Process(client *utils.Client) (State, *utils.Client)
|
Process() State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start state
|
||||||
|
type Debug struct {
|
||||||
|
Next State
|
||||||
|
Client *utils.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process message
|
||||||
|
func (state *Debug) Process() State {
|
||||||
|
state.Client.Log = state.Client.Log.WithField("state", "debug")
|
||||||
|
state.Client.Log.Debug("running")
|
||||||
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := state.Client.Read()
|
||||||
|
if err != nil {
|
||||||
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state.Client.Log.Info(element)
|
||||||
|
|
||||||
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,51 +16,60 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionStartup return steps through TCP TLS state
|
// ConnectionStartup return steps through TCP TLS state
|
||||||
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions []extension.Extension) state.State {
|
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed, extensions extension.Extensions, c *utils.Client) state.State {
|
||||||
receiving := &ReceivingClient{Extensions: extensions}
|
receiving := &state.ReceivingClient{Extensions: extensions, Client: c}
|
||||||
sending := &SendingClient{Next: receiving}
|
sending := &state.SendingClient{Next: receiving, Client: c}
|
||||||
authedstream := &AuthedStream{Next: sending}
|
authedstream := &AuthedStream{Next: sending, Client: c}
|
||||||
authedstart := &AuthedStart{Next: authedstream}
|
authedstart := &AuthedStart{Next: authedstream, Client: c}
|
||||||
tlsauth := &SASLAuth{
|
tlsauth := &SASLAuth{
|
||||||
Next: authedstart,
|
Next: authedstart,
|
||||||
|
Client: c,
|
||||||
database: db,
|
database: db,
|
||||||
domainRegisterAllowed: registerAllowed,
|
domainRegisterAllowed: registerAllowed,
|
||||||
}
|
}
|
||||||
tlsstream := &TLSStream{
|
tlsstream := &TLSStream{
|
||||||
Next: tlsauth,
|
Next: tlsauth,
|
||||||
|
Client: c,
|
||||||
domainRegisterAllowed: registerAllowed,
|
domainRegisterAllowed: registerAllowed,
|
||||||
}
|
}
|
||||||
return state.ConnectionStartup(tlsstream, tlsconfig, tlsmgmt)
|
tlsupgrade := &state.TLSUpgrade{
|
||||||
|
Next: tlsstream,
|
||||||
|
Client: c,
|
||||||
|
TLSConfig: tlsconfig,
|
||||||
|
TLSManager: tlsmgmt,
|
||||||
|
}
|
||||||
|
return &state.Start{Next: tlsupgrade, Client: c}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSStream state
|
// TLSStream state
|
||||||
type TLSStream struct {
|
type TLSStream struct {
|
||||||
Next state.State
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
domainRegisterAllowed utils.DomainRegisterAllowed
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process messages
|
// Process messages
|
||||||
func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Client) {
|
func (state *TLSStream) Process() state.State {
|
||||||
client.Log = client.Log.WithField("state", "tls stream")
|
state.Client.Log = state.Client.Log.WithField("state", "tls stream")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
element, err := client.Read()
|
element, err := state.Client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to read: ", err)
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||||
client.Log.Warn("is no stream")
|
state.Client.Log.Warn("is no stream")
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
|
||||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
if state.domainRegisterAllowed(client.JID) {
|
if state.domainRegisterAllowed(state.Client.JID) {
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
fmt.Fprintf(state.Client.Conn, `<stream:features>
|
||||||
<mechanisms xmlns='%s'>
|
<mechanisms xmlns='%s'>
|
||||||
<mechanism>PLAIN</mechanism>
|
<mechanism>PLAIN</mechanism>
|
||||||
</mechanisms>
|
</mechanisms>
|
||||||
|
@ -68,7 +77,7 @@ func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Clien
|
||||||
</stream:features>`,
|
</stream:features>`,
|
||||||
messages.NSSASL, messages.NSFeaturesIQRegister)
|
messages.NSSASL, messages.NSFeaturesIQRegister)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
fmt.Fprintf(state.Client.Conn, `<stream:features>
|
||||||
<mechanisms xmlns='%s'>
|
<mechanisms xmlns='%s'>
|
||||||
<mechanism>PLAIN</mechanism>
|
<mechanism>PLAIN</mechanism>
|
||||||
</mechanisms>
|
</mechanisms>
|
||||||
|
@ -76,124 +85,129 @@ func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Clien
|
||||||
messages.NSSASL)
|
messages.NSSASL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
||||||
// SASLAuth state
|
// SASLAuth state
|
||||||
type SASLAuth struct {
|
type SASLAuth struct {
|
||||||
Next state.State
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
database *database.State
|
database *database.State
|
||||||
domainRegisterAllowed utils.DomainRegisterAllowed
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process messages
|
// Process messages
|
||||||
func (state *SASLAuth) Process(client *utils.Client) (state.State, *utils.Client) {
|
func (state *SASLAuth) Process() state.State {
|
||||||
client.Log = client.Log.WithField("state", "sasl auth")
|
state.Client.Log = state.Client.Log.WithField("state", "sasl auth")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
// read the full auth stanza
|
// read the full auth stanza
|
||||||
element, err := client.Read()
|
element, err := state.Client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to read: ", err)
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
var auth messages.SASLAuth
|
var auth messages.SASLAuth
|
||||||
if err = client.In.DecodeElement(&auth, element); err != nil {
|
if err = state.Client.In.DecodeElement(&auth, element); err != nil {
|
||||||
client.Log.Info("start substate for registration")
|
state.Client.Log.Info("start substate for registration")
|
||||||
return &RegisterFormRequest{
|
return &RegisterFormRequest{
|
||||||
|
Next: &RegisterRequest{
|
||||||
|
Next: state.Next,
|
||||||
|
Client: state.Client,
|
||||||
|
database: state.database,
|
||||||
|
domainRegisterAllowed: state.domainRegisterAllowed,
|
||||||
|
},
|
||||||
|
Client: state.Client,
|
||||||
element: element,
|
element: element,
|
||||||
domainRegisterAllowed: state.domainRegisterAllowed,
|
domainRegisterAllowed: state.domainRegisterAllowed,
|
||||||
Next: &RegisterRequest{
|
}
|
||||||
domainRegisterAllowed: state.domainRegisterAllowed,
|
|
||||||
database: state.database,
|
|
||||||
Next: state.Next,
|
|
||||||
},
|
|
||||||
}, client
|
|
||||||
}
|
}
|
||||||
data, err := base64.StdEncoding.DecodeString(auth.Body)
|
data, err := base64.StdEncoding.DecodeString(auth.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("body decode: ", err)
|
state.Client.Log.Warn("body decode: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
info := strings.Split(string(data), "\x00")
|
info := strings.Split(string(data), "\x00")
|
||||||
// should check that info[1] starts with client.JID
|
// should check that info[1] starts with state.Client.JID
|
||||||
client.JID.Local = info[1]
|
state.Client.JID.Local = info[1]
|
||||||
client.Log = client.Log.WithField("jid", client.JID.Full())
|
state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
|
||||||
success, err := state.database.Authenticate(client.JID, info[2])
|
success, err := state.database.Authenticate(state.Client.JID, info[2])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("auth: ", err)
|
state.Client.Log.Warn("auth: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
if success {
|
if success {
|
||||||
client.Log.Info("success auth")
|
state.Client.Log.Info("success auth")
|
||||||
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
|
fmt.Fprintf(state.Client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
client.Log.Warn("failed auth")
|
state.Client.Log.Warn("failed auth")
|
||||||
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
|
fmt.Fprintf(state.Client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
|
||||||
return nil, client
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthedStart state
|
// AuthedStart state
|
||||||
type AuthedStart struct {
|
type AuthedStart struct {
|
||||||
Next state.State
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process messages
|
// Process messages
|
||||||
func (state *AuthedStart) Process(client *utils.Client) (state.State, *utils.Client) {
|
func (state *AuthedStart) Process() state.State {
|
||||||
client.Log = client.Log.WithField("state", "authed started")
|
state.Client.Log = state.Client.Log.WithField("state", "authed started")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
_, err := client.Read()
|
_, err := state.Client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to read: ", err)
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
|
||||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
fmt.Fprintf(state.Client.Conn, `<stream:features>
|
||||||
<bind xmlns='%s'/>
|
<bind xmlns='%s'/>
|
||||||
</stream:features>`,
|
</stream:features>`,
|
||||||
messages.NSBind)
|
messages.NSBind)
|
||||||
|
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthedStream state
|
// AuthedStream state
|
||||||
type AuthedStream struct {
|
type AuthedStream struct {
|
||||||
Next state.State
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process messages
|
// Process messages
|
||||||
func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Client) {
|
func (state *AuthedStream) Process() state.State {
|
||||||
client.Log = client.Log.WithField("state", "authed stream")
|
state.Client.Log = state.Client.Log.WithField("state", "authed stream")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
// check that it's a bind request
|
// check that it's a bind request
|
||||||
// read bind request
|
// read bind request
|
||||||
element, err := client.Read()
|
element, err := state.Client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to read: ", err)
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
var msg messages.IQ
|
var msg messages.IQ
|
||||||
if err = client.In.DecodeElement(&msg, element); err != nil {
|
if err = state.Client.In.DecodeElement(&msg, element); err != nil {
|
||||||
client.Log.Warn("is no iq: ", err)
|
state.Client.Log.Warn("is no iq: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
if msg.Type != messages.IQTypeSet {
|
if msg.Type != messages.IQTypeSet {
|
||||||
client.Log.Warn("is no set iq")
|
state.Client.Log.Warn("is no set iq")
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
if msg.Error != nil {
|
if msg.Error != nil {
|
||||||
client.Log.Warn("iq with error: ", msg.Error.Code)
|
state.Client.Log.Warn("iq with error: ", msg.Error.Code)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
type query struct {
|
type query struct {
|
||||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
|
||||||
|
@ -202,26 +216,26 @@ func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Cl
|
||||||
q := &query{}
|
q := &query{}
|
||||||
err = xml.Unmarshal(msg.Body, q)
|
err = xml.Unmarshal(msg.Body, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("is no iq bind: ", err)
|
state.Client.Log.Warn("is no iq bind: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
if q.Resource == "" {
|
if q.Resource == "" {
|
||||||
client.JID.Resource = makeResource()
|
state.Client.JID.Resource = makeResource()
|
||||||
} else {
|
} else {
|
||||||
client.JID.Resource = q.Resource
|
state.Client.JID.Resource = q.Resource
|
||||||
}
|
}
|
||||||
client.Log = client.Log.WithField("jid", client.JID.Full())
|
state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
|
||||||
client.Out.Encode(&messages.IQ{
|
state.Client.Out.Encode(&messages.IQ{
|
||||||
Type: messages.IQTypeResult,
|
Type: messages.IQTypeResult,
|
||||||
To: client.JID.String(),
|
To: state.Client.JID.String(),
|
||||||
From: client.JID.Domain,
|
From: state.Client.JID.Domain,
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
Body: []byte(fmt.Sprintf(
|
Body: []byte(fmt.Sprintf(
|
||||||
`<bind xmlns='%s'>
|
`<bind xmlns='%s'>
|
||||||
<jid>%s</jid>
|
<jid>%s</jid>
|
||||||
</bind>`,
|
</bind>`,
|
||||||
messages.NSBind, client.JID.Full())),
|
messages.NSBind, state.Client.JID.Full())),
|
||||||
})
|
})
|
||||||
|
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,48 +0,0 @@
|
||||||
package toclient
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/genofire/yaja/server/extension"
|
|
||||||
"github.com/genofire/yaja/server/state"
|
|
||||||
"github.com/genofire/yaja/server/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SendingClient state
|
|
||||||
type SendingClient struct {
|
|
||||||
Next state.State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *SendingClient) Process(client *utils.Client) (state.State, *utils.Client) {
|
|
||||||
client.Log = client.Log.WithField("state", "normal")
|
|
||||||
client.Log.Debug("sending")
|
|
||||||
// sending
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case msg := <-client.Messages:
|
|
||||||
err := client.Out.Encode(msg)
|
|
||||||
if err != nil {
|
|
||||||
client.Log.Warn(err)
|
|
||||||
}
|
|
||||||
case <-client.OnClose():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
client.Log.Debug("receiving")
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceivingClient state
|
|
||||||
type ReceivingClient struct {
|
|
||||||
Extensions extension.Extensions
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *ReceivingClient) Process(client *utils.Client) (state.State, *utils.Client) {
|
|
||||||
element, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.Log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
state.Extensions.Process(element, client)
|
|
||||||
return state, client
|
|
||||||
}
|
|
|
@ -13,33 +13,34 @@ import (
|
||||||
|
|
||||||
type RegisterFormRequest struct {
|
type RegisterFormRequest struct {
|
||||||
Next state.State
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
domainRegisterAllowed utils.DomainRegisterAllowed
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
element *xml.StartElement
|
element *xml.StartElement
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message
|
// Process message
|
||||||
func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *utils.Client) {
|
func (state *RegisterFormRequest) Process() state.State {
|
||||||
client.Log = client.Log.WithField("state", "register form request")
|
state.Client.Log = state.Client.Log.WithField("state", "register form request")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
if !state.domainRegisterAllowed(client.JID) {
|
if !state.domainRegisterAllowed(state.Client.JID) {
|
||||||
client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
|
state.Client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg messages.IQ
|
var msg messages.IQ
|
||||||
if err := client.In.DecodeElement(&msg, state.element); err != nil {
|
if err := state.Client.In.DecodeElement(&msg, state.element); err != nil {
|
||||||
client.Log.Warn("is no iq: ", err)
|
state.Client.Log.Warn("is no iq: ", err)
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
if msg.Type != messages.IQTypeGet {
|
if msg.Type != messages.IQTypeGet {
|
||||||
client.Log.Warn("is no get iq")
|
state.Client.Log.Warn("is no get iq")
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
if msg.Error != nil {
|
if msg.Error != nil {
|
||||||
client.Log.Warn("iq with error: ", msg.Error.Code)
|
state.Client.Log.Warn("iq with error: ", msg.Error.Code)
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
type query struct {
|
type query struct {
|
||||||
XMLName xml.Name `xml:"query"`
|
XMLName xml.Name `xml:"query"`
|
||||||
|
@ -48,13 +49,13 @@ func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *u
|
||||||
err := xml.Unmarshal(msg.Body, q)
|
err := xml.Unmarshal(msg.Body, q)
|
||||||
|
|
||||||
if q.XMLName.Space != messages.NSIQRegister || err != nil {
|
if q.XMLName.Space != messages.NSIQRegister || err != nil {
|
||||||
client.Log.Warn("is no iq register: ", err)
|
state.Client.Log.Warn("is no iq register: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
client.Out.Encode(&messages.IQ{
|
state.Client.Out.Encode(&messages.IQ{
|
||||||
Type: messages.IQTypeResult,
|
Type: messages.IQTypeResult,
|
||||||
To: client.JID.String(),
|
To: state.Client.JID.String(),
|
||||||
From: client.JID.Domain,
|
From: state.Client.JID.Domain,
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
|
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
|
||||||
Choose a username and password for use with this service.
|
Choose a username and password for use with this service.
|
||||||
|
@ -63,43 +64,44 @@ func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *u
|
||||||
<password/>
|
<password/>
|
||||||
</query>`, messages.NSIQRegister)),
|
</query>`, messages.NSIQRegister)),
|
||||||
})
|
})
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
||||||
type RegisterRequest struct {
|
type RegisterRequest struct {
|
||||||
Next state.State
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
database *database.State
|
database *database.State
|
||||||
domainRegisterAllowed utils.DomainRegisterAllowed
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message
|
// Process message
|
||||||
func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils.Client) {
|
func (state *RegisterRequest) Process() state.State {
|
||||||
client.Log = client.Log.WithField("state", "register request")
|
state.Client.Log = state.Client.Log.WithField("state", "register request")
|
||||||
client.Log.Debug("running")
|
state.Client.Log.Debug("running")
|
||||||
defer client.Log.Debug("leave")
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
if !state.domainRegisterAllowed(client.JID) {
|
if !state.domainRegisterAllowed(state.Client.JID) {
|
||||||
client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
|
state.Client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
element, err := client.Read()
|
element, err := state.Client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("unable to read: ", err)
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
var msg messages.IQ
|
var msg messages.IQ
|
||||||
if err = client.In.DecodeElement(&msg, element); err != nil {
|
if err = state.Client.In.DecodeElement(&msg, element); err != nil {
|
||||||
client.Log.Warn("is no iq: ", err)
|
state.Client.Log.Warn("is no iq: ", err)
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
if msg.Type != messages.IQTypeGet {
|
if msg.Type != messages.IQTypeGet {
|
||||||
client.Log.Warn("is no get iq")
|
state.Client.Log.Warn("is no get iq")
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
if msg.Error != nil {
|
if msg.Error != nil {
|
||||||
client.Log.Warn("iq with error: ", msg.Error.Code)
|
state.Client.Log.Warn("iq with error: ", msg.Error.Code)
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
type query struct {
|
type query struct {
|
||||||
XMLName xml.Name `xml:"query"`
|
XMLName xml.Name `xml:"query"`
|
||||||
|
@ -109,19 +111,19 @@ func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils
|
||||||
q := &query{}
|
q := &query{}
|
||||||
err = xml.Unmarshal(msg.Body, q)
|
err = xml.Unmarshal(msg.Body, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Log.Warn("is no iq register: ", err)
|
state.Client.Log.Warn("is no iq register: ", err)
|
||||||
return nil, client
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
client.JID.Local = q.Username
|
state.Client.JID.Local = q.Username
|
||||||
client.Log = client.Log.WithField("jid", client.JID.Full())
|
state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
|
||||||
account := model.NewAccount(client.JID, q.Password)
|
account := model.NewAccount(state.Client.JID, q.Password)
|
||||||
err = state.database.AddAccount(account)
|
err = state.database.AddAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.Out.Encode(&messages.IQ{
|
state.Client.Out.Encode(&messages.IQ{
|
||||||
Type: messages.IQTypeResult,
|
Type: messages.IQTypeResult,
|
||||||
To: client.JID.String(),
|
To: state.Client.JID.String(),
|
||||||
From: client.JID.Domain,
|
From: state.Client.JID.Domain,
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
|
Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
|
||||||
<username>%s</username>
|
<username>%s</username>
|
||||||
|
@ -136,16 +138,16 @@ func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
client.Log.Warn("database error: ", err)
|
state.Client.Log.Warn("database error: ", err)
|
||||||
return state, client
|
return state
|
||||||
}
|
}
|
||||||
client.Out.Encode(&messages.IQ{
|
state.Client.Out.Encode(&messages.IQ{
|
||||||
Type: messages.IQTypeResult,
|
Type: messages.IQTypeResult,
|
||||||
To: client.JID.String(),
|
To: state.Client.JID.String(),
|
||||||
From: client.JID.Domain,
|
From: state.Client.JID.Domain,
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
})
|
})
|
||||||
|
|
||||||
client.Log.Infof("registered client %s", client.JID.Bare())
|
state.Client.Log.Infof("registered client %s", state.Client.JID.Bare())
|
||||||
return state.Next, client
|
return state.Next
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,141 @@
|
||||||
|
package toserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/database"
|
||||||
|
"github.com/genofire/yaja/messages"
|
||||||
|
"github.com/genofire/yaja/server/extension"
|
||||||
|
"github.com/genofire/yaja/server/state"
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionStartup return steps through TCP TLS state
|
||||||
|
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, extensions extension.Extensions, c *utils.Client) state.State {
|
||||||
|
receiving := &state.ReceivingClient{Extensions: extensions, Client: c}
|
||||||
|
sending := &state.SendingClient{Next: receiving, Client: c}
|
||||||
|
tlsstream := &TLSStream{
|
||||||
|
Next: sending,
|
||||||
|
Client: c,
|
||||||
|
}
|
||||||
|
tlsupgrade := &state.TLSUpgrade{
|
||||||
|
Next: tlsstream,
|
||||||
|
Client: c,
|
||||||
|
TLSConfig: tlsconfig,
|
||||||
|
TLSManager: tlsmgmt,
|
||||||
|
}
|
||||||
|
dail := &Dailback{
|
||||||
|
Next: tlsupgrade,
|
||||||
|
Client: c,
|
||||||
|
}
|
||||||
|
return &state.Start{Next: dail, Client: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSStream state
|
||||||
|
type Dailback struct {
|
||||||
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *Dailback) Process() state.State {
|
||||||
|
state.Client.Log = state.Client.Log.WithField("state", "dialback")
|
||||||
|
state.Client.Log.Debug("running")
|
||||||
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := state.Client.Read()
|
||||||
|
if err != nil {
|
||||||
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dailback encode
|
||||||
|
type dailback struct {
|
||||||
|
XMLName xml.Name `xml:"urn:xmpp:ping ping"`
|
||||||
|
}
|
||||||
|
db := &dailback{}
|
||||||
|
if err = state.Client.In.DecodeElement(db, element); err != nil {
|
||||||
|
return state.Next
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Client.Log.Info(db)
|
||||||
|
return state.Next
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSStream state
|
||||||
|
type TLSStream struct {
|
||||||
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *TLSStream) Process() state.State {
|
||||||
|
state.Client.Log = state.Client.Log.WithField("state", "tls stream")
|
||||||
|
state.Client.Log.Debug("running")
|
||||||
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := state.Client.Read()
|
||||||
|
if err != nil {
|
||||||
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||||
|
state.Client.Log.Warn("is no stream")
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
|
||||||
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
|
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
|
fmt.Fprintf(state.Client.Conn, `<stream:features>
|
||||||
|
<mechanisms xmlns='%s'>
|
||||||
|
<mechanism>EXTERNAL</mechanism>
|
||||||
|
</mechanisms>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSSASL)
|
||||||
|
|
||||||
|
return state.Next
|
||||||
|
}
|
||||||
|
|
||||||
|
// SASLAuth state
|
||||||
|
type SASLAuth struct {
|
||||||
|
Next state.State
|
||||||
|
Client *utils.Client
|
||||||
|
database *database.State
|
||||||
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *SASLAuth) Process() state.State {
|
||||||
|
state.Client.Log = state.Client.Log.WithField("state", "sasl auth")
|
||||||
|
state.Client.Log.Debug("running")
|
||||||
|
defer state.Client.Log.Debug("leave")
|
||||||
|
|
||||||
|
// read the full auth stanza
|
||||||
|
element, err := state.Client.Read()
|
||||||
|
if err != nil {
|
||||||
|
state.Client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var auth messages.SASLAuth
|
||||||
|
if err = state.Client.In.DecodeElement(&auth, element); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, err := base64.StdEncoding.DecodeString(auth.Body)
|
||||||
|
if err != nil {
|
||||||
|
state.Client.Log.Warn("body decode: ", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state.Client.Log.Debug(auth.Mechanism, string(data))
|
||||||
|
|
||||||
|
state.Client.Log.Info("success auth")
|
||||||
|
fmt.Fprintf(state.Client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
|
||||||
|
return state.Next
|
||||||
|
}
|
|
@ -30,7 +30,7 @@ func NewClient(conn net.Conn, level log.Level) *Client {
|
||||||
Log: log.NewEntry(logger),
|
Log: log.NewEntry(logger),
|
||||||
In: xml.NewDecoder(conn),
|
In: xml.NewDecoder(conn),
|
||||||
Out: xml.NewEncoder(conn),
|
Out: xml.NewEncoder(conn),
|
||||||
Messages: make(chan interface{}, 1000),
|
Messages: make(chan interface{}),
|
||||||
close: make(chan interface{}),
|
close: make(chan interface{}),
|
||||||
}
|
}
|
||||||
return client
|
return client
|
||||||
|
|
Reference in New Issue