From 800a5b191748f72b2b2e18553a48ad20cbfb7d58 Mon Sep 17 00:00:00 2001 From: Martin Geno Date: Thu, 14 Dec 2017 21:30:07 +0100 Subject: [PATCH] lets encrypt + registration --- .drone.yml | 2 +- README.md | 25 ++- cmd/{serve.go => server.go} | 66 +++---- config_example.conf | 6 +- database/main.go | 68 +++++++ main.go | 2 +- messages/error.go | 12 ++ messages/iq.go | 25 +++ messages/namespaces.go | 19 ++ messages/presence.go | 32 ++++ messages/sasl.go | 10 ++ model/account.go | 26 ++- model/config/file.go | 24 --- model/config/file_test.go | 25 --- model/config/struct.go | 10 +- model/file.go | 27 --- model/jid.go | 2 + model/state.go | 38 ---- server/client.go | 62 +++++++ server/server.go | 71 +++++++- server/state.go | 6 + server/state_connect.go | 344 ++++++++++++++++++++++++++++++++++++ server/state_register.go | 130 ++++++++++++++ server/utils.go | 29 +++ 24 files changed, 899 insertions(+), 162 deletions(-) rename cmd/{serve.go => server.go} (63%) create mode 100644 database/main.go create mode 100644 messages/error.go create mode 100644 messages/iq.go create mode 100644 messages/namespaces.go create mode 100644 messages/presence.go create mode 100644 messages/sasl.go delete mode 100644 model/config/file.go delete mode 100644 model/config/file_test.go delete mode 100644 model/file.go delete mode 100644 model/state.go create mode 100644 server/client.go create mode 100644 server/state.go create mode 100644 server/state_connect.go create mode 100644 server/state_register.go create mode 100644 server/utils.go diff --git a/.drone.yml b/.drone.yml index c1d9103..cfe9754 100644 --- a/.drone.yml +++ b/.drone.yml @@ -1,6 +1,6 @@ workspace: base: /go - path: src/dev.sum7.eu/genofire/yaja + path: src/github.com/genofire/yaja pipeline: build: diff --git a/README.md b/README.md index 4edea69..6f6c8a1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,24 @@ -# yaja +# yaja (yet another jabber server) -Yet Another Jabber Server \ No newline at end of file +``` +A small standalone jabber server, for easy deployment + +Usage: + yaja [command] + +Available Commands: + help Help about any command + server Runs the yaja server + +Flags: + -h, --help help for yaja + +Use "yaja [command] --help" for more information about a command. +``` + +## Features (works already) +- get certificate by lets encrypt +- registration (for every possible ssl domain) + +## Inspiration + - [tam7t](https://github.com/tam7t/xmpp) a fork of [agl](https://github.com/agl)'s work diff --git a/cmd/serve.go b/cmd/server.go similarity index 63% rename from cmd/serve.go rename to cmd/server.go index 16d39ea..31612b5 100644 --- a/cmd/serve.go +++ b/cmd/server.go @@ -2,6 +2,7 @@ package cmd import ( "crypto/tls" + "net/http" "os" "os/signal" "syscall" @@ -9,11 +10,12 @@ import ( "golang.org/x/crypto/acme/autocert" - "dev.sum7.eu/genofire/yaja/model" - "dev.sum7.eu/genofire/yaja/model/config" + "github.com/genofire/yaja/database" + "github.com/genofire/yaja/model/config" - "dev.sum7.eu/genofire/yaja/server" + "github.com/genofire/golang-lib/file" "github.com/genofire/golang-lib/worker" + "github.com/genofire/yaja/server" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -22,32 +24,34 @@ import ( var configPath string var ( - configData *config.Config - state *model.State + configData = &config.Config{} + db = &database.State{} statesaveWorker *worker.Worker srv *server.Server certs *tls.Config ) -// serveCmd represents the serve command -var serveCmd = &cobra.Command{ - Use: "serve", +// serverCmd represents the serve command +var serverCmd = &cobra.Command{ + Use: "server", Short: "Runs the yaja server", Example: "yaja serve -c /etc/yaja.conf", Run: func(cmd *cobra.Command, args []string) { var err error - configData, err = config.ReadConfigFile(configPath) + err = file.ReadTOML(configPath, configData) if err != nil { log.Fatal("unable to load config file:", err) } - state, err = model.ReadState(configData.StatePath) + log.SetLevel(log.DebugLevel) + + err = file.ReadJSON(configData.StatePath, db) if err != nil { log.Warn("unable to load state file:", err) } statesaveWorker = worker.NewWorker(time.Minute, func() { - model.SaveJSON(state, configData.StatePath) + file.SaveJSON(configData.StatePath, db) log.Info("save state to:", configData.StatePath) }) @@ -56,13 +60,18 @@ var serveCmd = &cobra.Command{ Prompt: autocert.AcceptTOS, } - certs = &tls.Config{GetCertificate: m.GetCertificate} + // https server to handle acme (by letsencrypt) + httpServer := &http.Server{ + Addr: ":https", + TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, + } + go httpServer.ListenAndServeTLS("", "") srv = &server.Server{ - TLSConfig: certs, - State: state, - PortClient: configData.PortClient, - PortServer: configData.PortServer, + TLSManager: &m, + Database: db, + ClientAddr: configData.Address.Client, + ServerAddr: configData.Address.Server, } go statesaveWorker.Start() @@ -95,21 +104,24 @@ func quit() { srv.Close() statesaveWorker.Close() - model.SaveJSON(state, configData.StatePath) + file.SaveJSON(configData.StatePath, db) } func reload() { log.Info("start reloading...") - configNewData, err := config.ReadConfigFile(configPath) + var configNewData *config.Config + err := file.ReadTOML(configPath, configNewData) if err != nil { log.Warn("unable to load config file:", err) return } + //TODO fetch changing address (to set restart) + if configNewData.StatePath != configData.StatePath { statesaveWorker.Close() statesaveWorker := worker.NewWorker(time.Minute, func() { - model.SaveJSON(state, configNewData.StatePath) + file.SaveJSON(configNewData.StatePath, db) log.Info("save state to:", configNewData.StatePath) }) go statesaveWorker.Start() @@ -130,17 +142,11 @@ func reload() { newServer := &server.Server{ TLSConfig: certs, - State: state, - PortClient: configNewData.PortClient, - PortServer: configNewData.PortServer, + Database: db, + ClientAddr: configNewData.Address.Client, + ServerAddr: configNewData.Address.Server, } - if configNewData.PortServer != configData.PortServer { - restartServer = true - } - if configNewData.PortClient != configData.PortClient { - restartServer = true - } if restartServer { go srv.Start() //TODO should fetch new server error @@ -153,6 +159,6 @@ func reload() { } func init() { - RootCmd.AddCommand(serveCmd) - serveCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") + RootCmd.AddCommand(serverCmd) + serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") } diff --git a/config_example.conf b/config_example.conf index 75e686b..d33f8bf 100644 --- a/config_example.conf +++ b/config_example.conf @@ -1,4 +1,6 @@ tlsdir = "/tmp/ssl" state_path = "/tmp/yaja.json" -port_client = 5222 -port_server = 5269 + +[address] +client = [":5222"] +server = [":5269"] diff --git a/database/main.go b/database/main.go new file mode 100644 index 0000000..ba3892d --- /dev/null +++ b/database/main.go @@ -0,0 +1,68 @@ +package database + +import ( + "errors" + "sync" + + "github.com/genofire/yaja/model" + log "github.com/sirupsen/logrus" +) + +type State struct { + Domains map[string]*model.Domain `json:"domains"` + sync.Mutex +} + +func (s *State) AddAccount(a *model.Account) error { + if a.Local == "" { + return errors.New("No localpart exists in account") + } + if d := a.Domain; d != nil { + if d.FQDN == "" { + return errors.New("No fqdn exists in domain") + } + s.Lock() + domain, ok := s.Domains[d.FQDN] + if !ok { + if s.Domains == nil { + s.Domains = make(map[string]*model.Domain) + } + s.Domains[d.FQDN] = d + domain = d + } + s.Unlock() + + domain.Lock() + defer domain.Unlock() + if domain.Accounts == nil { + domain.Accounts = make(map[string]*model.Account) + } + _, ok = domain.Accounts[a.Local] + if ok { + return errors.New("exists already") + } + domain.Accounts[a.Local] = a + a.Domain = d + return nil + } + return errors.New("no give domain") +} + +func (s *State) Authenticate(jid *model.JID, password string) (bool, error) { + logger := log.WithField("database", "auth") + + if domain, ok := s.Domains[jid.Domain]; ok { + if acc, ok := domain.Accounts[jid.Local]; ok { + if acc.ValidatePassword(password) { + return true, nil + } else { + logger.Debug("password not valid") + } + } else { + logger.Debug("account not found") + } + } else { + logger.Debug("domain not found") + } + return false, nil +} diff --git a/main.go b/main.go index e7be7f7..2e00ee8 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,6 @@ package main -import "dev.sum7.eu/genofire/yaja/cmd" +import "github.com/genofire/yaja/cmd" func main() { cmd.Execute() diff --git a/messages/error.go b/messages/error.go new file mode 100644 index 0000000..6acc884 --- /dev/null +++ b/messages/error.go @@ -0,0 +1,12 @@ +package messages + +import "encoding/xml" + +// Error element +type Error struct { + XMLName xml.Name `xml:"jabber:client error"` + Code string `xml:"code,attr"` + Type string `xml:"type,attr"` + Any xml.Name `xml:",any"` + Text string `xml:"text"` +} diff --git a/messages/iq.go b/messages/iq.go new file mode 100644 index 0000000..cf278b0 --- /dev/null +++ b/messages/iq.go @@ -0,0 +1,25 @@ +package messages + +import "encoding/xml" + +type IQType string + +const ( + IQTypeGet IQType = "get" + IQTypeSet IQType = "set" + IQTypeResult IQType = "result" + IQTypeError IQType = "error" +) + +// IQ element - info/query +type IQ struct { + XMLName xml.Name `xml:"jabber:client iq"` + From string `xml:"from,attr"` + ID string `xml:"id,attr"` + To string `xml:"to,attr"` + Type IQType `xml:"type,attr"` + Error *Error `xml:"error"` + //Bind bindBind `xml:"bind"` + Body []byte `xml:",innerxml"` + // RosterRequest - better detection of iq's +} diff --git a/messages/namespaces.go b/messages/namespaces.go new file mode 100644 index 0000000..2ddc0dd --- /dev/null +++ b/messages/namespaces.go @@ -0,0 +1,19 @@ +package messages + +const ( + // NSStream stream namesapce + NSStream = "http://etherx.jabber.org/streams" + // NSTLS xmpp-tls xml namespace + NSTLS = "urn:ietf:params:xml:ns:xmpp-tls" + // NSSASL xmpp-sasl xml namespace + NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl" + + NSBind = "urn:ietf:params:xml:ns:xmpp-bind" + + // NSClient jabbet client namespace + NSClient = "jabber:client" + + NSIQRegister = "jabber:iq:register" + + NSFeaturesIQRegister = "http://jabber.org/features/iq-register" +) diff --git a/messages/presence.go b/messages/presence.go new file mode 100644 index 0000000..f77c90f --- /dev/null +++ b/messages/presence.go @@ -0,0 +1,32 @@ +package messages + +import "encoding/xml" + +type PresenceType string + +const ( + PresenceTypeUnavailable PresenceType = "unavailable" + PresenceTypeSubscribe PresenceType = "subscribe" + PresenceTypeSubscribed PresenceType = "subscribed" + PresenceTypeUnsubscribe PresenceType = "unsubscribe" + PresenceTypeUnsubscribed PresenceType = "unsubscribed" + PresenceTypeProbe PresenceType = "probe" + PresenceTypeError PresenceType = "error" +) + +// Presence element +type Presence struct { + XMLName xml.Name `xml:"jabber:client presence"` + From string `xml:"from,attr,omitempty"` + ID string `xml:"id,attr,omitempty"` + To string `xml:"to,attr,omitempty"` + Type string `xml:"type,attr,omitempty"` + Lang string `xml:"lang,attr,omitempty"` + + Show string `xml:"show,omitempty"` // away, chat, dnd, xa + Status string `xml:"status,omitempty"` // sb []clientText + Priority string `xml:"priority,omitempty"` + // Caps *ClientCaps `xml:"c"` + Error *Error `xml:"error"` + // Delay Delay `xml:"delay"` +} diff --git a/messages/sasl.go b/messages/sasl.go new file mode 100644 index 0000000..294bda1 --- /dev/null +++ b/messages/sasl.go @@ -0,0 +1,10 @@ +package messages + +import "encoding/xml" + +// SASLAuth element +type SASLAuth struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl auth"` + Mechanism string `xml:"mechanism,attr"` + Body string `xml:",chardata"` +} diff --git a/model/account.go b/model/account.go index 6c638af..d59a432 100644 --- a/model/account.go +++ b/model/account.go @@ -6,8 +6,8 @@ import ( ) type Domain struct { - FQDN string - Accounts map[string]*Account + FQDN string `json:"-"` + Accounts map[string]*Account `json:"users"` sync.Mutex } @@ -29,8 +29,22 @@ func (d *Domain) UpdateAccount(a *Account) error { } type Account struct { - Local string - Domain *Domain + Local string `json:"-"` + Domain *Domain `json:"-"` + Password string `json:"password"` +} + +func NewAccount(jid *JID, password string) *Account { + if jid == nil { + return nil + } + return &Account{ + Local: jid.Local, + Domain: &Domain{ + FQDN: jid.Domain, + }, + Password: password, + } } func (a *Account) GetJID() *JID { @@ -39,3 +53,7 @@ func (a *Account) GetJID() *JID { Local: a.Local, } } + +func (a *Account) ValidatePassword(password string) bool { + return a.Password == password +} diff --git a/model/config/file.go b/model/config/file.go deleted file mode 100644 index 29609da..0000000 --- a/model/config/file.go +++ /dev/null @@ -1,24 +0,0 @@ -package config - -import ( - "io/ioutil" - - "github.com/BurntSushi/toml" -) - -// ReadConfigFile reads a config model from path of a yml file -func ReadConfigFile(path string) (config *Config, err error) { - config = &Config{} - - file, err := ioutil.ReadFile(path) - if err != nil { - return nil, err - } - - err = toml.Unmarshal(file, config) - if err != nil { - return nil, err - } - - return -} diff --git a/model/config/file_test.go b/model/config/file_test.go deleted file mode 100644 index 3368c04..0000000 --- a/model/config/file_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestReadConfig(t *testing.T) { - assert := assert.New(t) - - config, err := ReadConfigFile("../../config_example.conf") - assert.NoError(err) - assert.NotNil(config) - - assert.Equal("/tmp/ssl", config.TLSDir) - - config, err = ReadConfigFile("../config_example.co") - assert.Nil(config) - assert.Contains(err.Error(), "no such file or directory") - - config, err = ReadConfigFile("testdata/config_panic.conf") - assert.Nil(config) - assert.Contains(err.Error(), "keys cannot contain") -} diff --git a/model/config/struct.go b/model/config/struct.go index 83fe37f..e902bdf 100644 --- a/model/config/struct.go +++ b/model/config/struct.go @@ -1,8 +1,10 @@ package config type Config struct { - TLSDir string `toml:"tlsdir"` - StatePath string `toml:"state_path"` - PortClient int `toml:"port_client"` - PortServer int `toml:"port_server"` + TLSDir string `toml:"tlsdir"` + StatePath string `toml:"state_path"` + Address struct { + Client []string `toml:"client"` + Server []string `toml:"server"` + } `toml:"address"` } diff --git a/model/file.go b/model/file.go deleted file mode 100644 index 493f62d..0000000 --- a/model/file.go +++ /dev/null @@ -1,27 +0,0 @@ -package model - -import ( - "encoding/json" - "log" - "os" -) - -// SaveJSON to path -func SaveJSON(input interface{}, outputFile string) { - tmpFile := outputFile + ".tmp" - - f, err := os.OpenFile(tmpFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - log.Panic(err) - } - - err = json.NewEncoder(f).Encode(input) - if err != nil { - log.Panic(err) - } - - f.Close() - if err := os.Rename(tmpFile, outputFile); err != nil { - log.Panic(err) - } -} diff --git a/model/jid.go b/model/jid.go index 0c67867..7fc5b5c 100644 --- a/model/jid.go +++ b/model/jid.go @@ -42,6 +42,8 @@ func (jid *JID) Bare() string { return jid.Domain } +func (jid *JID) String() string { return jid.Bare() } + // Full get the "full" jid as string func (jid *JID) Full() string { if jid.Resource != "" { diff --git a/model/state.go b/model/state.go deleted file mode 100644 index 16cbed1..0000000 --- a/model/state.go +++ /dev/null @@ -1,38 +0,0 @@ -package model - -import ( - "encoding/json" - "errors" - "os" - "sync" -) - -type State struct { - Domains map[string]*Domain - sync.Mutex -} - -func ReadState(path string) (state *State, err error) { - state = &State{} - - if f, err := os.Open(path); err == nil { // transform data to legacy meshviewer - if err = json.NewDecoder(f).Decode(state); err == nil { - return state, nil - - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (s *State) UpdateDomain(d *Domain) error { - if d.FQDN == "" { - return errors.New("No fqdn exists in domain") - } - s.Lock() - s.Domains[d.FQDN] = d - s.Unlock() - return nil -} diff --git a/server/client.go b/server/client.go new file mode 100644 index 0000000..85555a0 --- /dev/null +++ b/server/client.go @@ -0,0 +1,62 @@ +package server + +import ( + "encoding/xml" + "net" + + "github.com/genofire/yaja/model" + log "github.com/sirupsen/logrus" +) + +type Client struct { + log *log.Entry + + Conn net.Conn + out *xml.Encoder + in *xml.Decoder + + Server *Server + jid *model.JID + account *model.Account + + messages chan interface{} + close chan interface{} +} + +func NewClient(conn net.Conn, srv *Server) *Client { + logger := log.New() + logger.SetLevel(log.DebugLevel) + client := &Client{ + Conn: conn, + Server: srv, + log: log.NewEntry(logger), + in: xml.NewDecoder(conn), + out: xml.NewEncoder(conn), + } + return client +} + +func (client *Client) NewConnecting(conn net.Conn) { + client.Conn = conn + client.in = xml.NewDecoder(conn) + client.out = xml.NewEncoder(conn) +} + +func (client *Client) Read() (*xml.StartElement, error) { + for { + nextToken, err := client.in.Token() + if err != nil { + return nil, err + } + switch nextToken.(type) { + case xml.StartElement: + element := nextToken.(xml.StartElement) + return &element, nil + } + } +} + +func (client *Client) Close() { + client.close <- true + client.Conn.Close() +} diff --git a/server/server.go b/server/server.go index 5298e87..7463200 100644 --- a/server/server.go +++ b/server/server.go @@ -2,19 +2,82 @@ package server import ( "crypto/tls" + "net" - "dev.sum7.eu/genofire/yaja/model" + "github.com/genofire/yaja/database" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/acme/autocert" ) type Server struct { TLSConfig *tls.Config - PortClient int - PortServer int - State *model.State + TLSManager *autocert.Manager + ClientAddr []string + ServerAddr []string + Database *database.State } func (srv *Server) Start() { + for _, addr := range srv.ServerAddr { + socket, err := net.Listen("tcp", addr) + if err != nil { + log.Warn("create server socket: ", err.Error()) + break + } + go srv.listenServer(socket) + } + for _, addr := range srv.ClientAddr { + socket, err := net.Listen("tcp", addr) + if err != nil { + log.Warn("create client socket: ", err.Error()) + break + } + go srv.listenClient(socket) + } +} + +func (srv *Server) listenServer(s2s net.Listener) { + for { + conn, err := s2s.Accept() + if err != nil { + log.Warn("accepting server connection: ", err.Error()) + break + } + go srv.handleServer(conn) + } +} + +func (srv *Server) listenClient(c2s net.Listener) { + for { + conn, err := c2s.Accept() + if err != nil { + log.Warn("accepting client connection: ", err.Error()) + break + } + go srv.handleClient(conn) + } +} + +func (srv *Server) handleServer(conn net.Conn) { + log.Info("new server connection:", conn.RemoteAddr()) +} + +func (srv *Server) handleClient(conn net.Conn) { + log.Info("new client connection:", conn.RemoteAddr()) + client := NewClient(conn, srv) + state := ConnectionStartup() + + for { + state, client = state.Process(client) + if state == nil { + client.log.Info("disconnect") + client.Close() + //s.DisconnectBus <- Disconnect{Jid: client.jid} + return + } + // run next state + } } func (srv *Server) Close() { diff --git a/server/state.go b/server/state.go new file mode 100644 index 0000000..22b2f13 --- /dev/null +++ b/server/state.go @@ -0,0 +1,6 @@ +package server + +// State processes the stream and moves to the next state +type State interface { + Process(client *Client) (State, *Client) +} diff --git a/server/state_connect.go b/server/state_connect.go new file mode 100644 index 0000000..c54b132 --- /dev/null +++ b/server/state_connect.go @@ -0,0 +1,344 @@ +package server + +import ( + "crypto/tls" + "encoding/base64" + "encoding/xml" + "fmt" + "strings" + + "github.com/genofire/yaja/messages" + "github.com/genofire/yaja/model" +) + +// ConnectionStartup return steps through TCP TLS state +func ConnectionStartup() State { + receiving := &ReceivingClient{} + sending := &SendingClient{Next: receiving} + authedstream := &AuthedStream{Next: sending} + authedstart := &AuthedStart{Next: authedstream} + tlsauth := &SASLAuth{Next: authedstart} + tlsstream := &TLSStream{Next: tlsauth} + tlsupgrade := &TLSUpgrade{Next: tlsstream} + stream := &Start{Next: tlsupgrade} + return stream +} + +// Start state +type Start struct { + Next State +} + +// Process message +func (state *Start) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "stream") + client.log.Debug("running") + defer client.log.Debug("leave") + + element, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { + client.log.Warn("is no stream") + return state, client + } + for _, attr := range element.Attr { + if attr.Name.Local == "to" { + client.jid = &model.JID{Domain: attr.Value} + client.log = client.log.WithField("jid", client.jid.Full()) + } + } + if client.jid == nil { + client.log.Warn("no 'to' domain readed") + return nil, client + } + + fmt.Fprintf(client.Conn, ` + `, + createCookie(), messages.NSClient, messages.NSStream) + + fmt.Fprintf(client.Conn, ` + + + + `, + messages.NSStream) + + return state.Next, client +} + +// TLSUpgrade state +type TLSUpgrade struct { + Next State +} + +// Process message +func (state *TLSUpgrade) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "tls upgrade") + client.log.Debug("running") + defer client.log.Debug("leave") + + element, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" { + client.log.Warn("is no starttls") + return state, client + } + fmt.Fprintf(client.Conn, "", messages.NSTLS) + // perform the TLS handshake + var tlsConfig *tls.Config + if m := client.Server.TLSManager; m != nil { + var cert *tls.Certificate + cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.jid.Domain}) + if err != nil { + client.log.Warn("no cert in tls manger found: ", err) + return nil, client + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + } + if tlsConfig == nil { + tlsConfig = client.Server.TLSConfig + if tlsConfig != nil { + tlsConfig.ServerName = client.jid.Domain + } else { + client.log.Warn("no tls config found: ", err) + return nil, client + } + } + + tlsConn := tls.Server(client.Conn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + client.log.Warn("unable to tls handshake: ", err) + return nil, client + } + // restart the Connection + client.NewConnecting(tlsConn) + + return state.Next, client +} + +// TLSStream state +type TLSStream struct { + Next State +} + +// Process messages +func (state *TLSStream) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "tls stream") + client.log.Debug("running") + defer client.log.Debug("leave") + + element, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + if element.Name.Space != messages.NSStream || element.Name.Local != "stream" { + client.log.Warn("is no stream") + return state, client + } + + fmt.Fprintf(client.Conn, ` + `, + createCookie(), messages.NSClient, messages.NSStream) + + fmt.Fprintf(client.Conn, ` + + PLAIN + + + `, + messages.NSSASL, messages.NSFeaturesIQRegister) + + return state.Next, client +} + +// SASLAuth state +type SASLAuth struct { + Next State +} + +// Process messages +func (state *SASLAuth) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "sasl auth") + client.log.Debug("running") + defer client.log.Debug("leave") + + // read the full auth stanza + element, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + var auth messages.SASLAuth + if err = client.in.DecodeElement(&auth, element); err != nil { + client.log.Info("start substate for registration") + return &RegisterFormRequest{ + element: element, + Next: &RegisterRequest{ + Next: state.Next, + }, + }, client + } + data, err := base64.StdEncoding.DecodeString(auth.Body) + if err != nil { + client.log.Warn("body decode: ", err) + return nil, client + } + info := strings.Split(string(data), "\x00") + // should check that info[1] starts with client.jid + client.jid.Local = info[1] + client.log = client.log.WithField("jid", client.jid.Full()) + success, err := client.Server.Database.Authenticate(client.jid, info[2]) + if err != nil { + client.log.Warn("auth: ", err) + return nil, client + } + if success { + client.log.Info("success auth") + fmt.Fprintf(client.Conn, "", messages.NSSASL) + return state.Next, client + } + client.log.Warn("failed auth") + fmt.Fprintf(client.Conn, "", messages.NSSASL) + return nil, client + +} + +// AuthedStart state +type AuthedStart struct { + Next State +} + +// Process messages +func (state *AuthedStart) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "authed started") + client.log.Debug("running") + defer client.log.Debug("leave") + + _, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + fmt.Fprintf(client.Conn, ` + `, + createCookie(), messages.NSClient, messages.NSStream) + + fmt.Fprintf(client.Conn, ` + + `, + messages.NSBind) + + return state.Next, client +} + +// AuthedStream state +type AuthedStream struct { + Next State +} + +// Process messages +func (state *AuthedStream) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "authed stream") + client.log.Debug("running") + defer client.log.Debug("leave") + + // check that it's a bind request + // read bind request + element, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + var msg messages.IQ + if err = client.in.DecodeElement(&msg, element); err != nil { + client.log.Warn("is no iq: ", err) + return nil, client + } + if msg.Type != messages.IQTypeSet { + client.log.Warn("is no set iq") + return nil, client + } + if msg.Error != nil { + client.log.Warn("iq with error: ", msg.Error.Code) + return nil, client + } + type query struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` + Resource string `xml:"resource"` + } + q := &query{} + err = xml.Unmarshal(msg.Body, q) + if err != nil { + client.log.Warn("is no iq bind: ", err) + return nil, client + } + if q.Resource == "" { + client.jid.Resource = makeResource() + } else { + client.jid.Resource = q.Resource + } + client.log = client.log.WithField("jid", client.jid.Full()) + client.out.Encode(&messages.IQ{ + Type: messages.IQTypeResult, + ID: msg.ID, + Body: []byte(fmt.Sprintf( + ` + %s + `, + messages.NSBind, client.jid.Full())), + }) + + return state.Next, client +} + +// SendingClient state +type SendingClient struct { + Next State +} + +// Process messages +func (state *SendingClient) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "normal") + client.log.Debug("sending") + // sending + go func() { + select { + case msg := <-client.messages: + err := client.out.Encode(msg) + client.log.Info(err) + case <-client.close: + return + } + }() + client.log.Debug("receiving") + return state.Next, client +} + +// ReceivingClient state +type ReceivingClient struct { +} + +// Process messages +func (state *ReceivingClient) Process(client *Client) (State, *Client) { + element, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + /* + for _, extension := range client.Server.Extensions { + extension.Process(element, client) + }*/ + client.log.Debug(element) + return state, client +} diff --git a/server/state_register.go b/server/state_register.go new file mode 100644 index 0000000..eb4dd4a --- /dev/null +++ b/server/state_register.go @@ -0,0 +1,130 @@ +package server + +import ( + "encoding/xml" + "fmt" + + "github.com/genofire/yaja/messages" + "github.com/genofire/yaja/model" +) + +type RegisterFormRequest struct { + Next State + element *xml.StartElement +} + +// Process message +func (state *RegisterFormRequest) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "register form request") + client.log.Debug("running") + defer client.log.Debug("leave") + + var msg messages.IQ + if err := client.in.DecodeElement(&msg, state.element); err != nil { + client.log.Warn("is no iq: ", err) + return state, client + } + if msg.Type != messages.IQTypeGet { + client.log.Warn("is no get iq") + return state, client + } + if msg.Error != nil { + client.log.Warn("iq with error: ", msg.Error.Code) + return state, client + } + type query struct { + XMLName xml.Name `xml:"query"` + } + q := &query{} + err := xml.Unmarshal(msg.Body, q) + + if q.XMLName.Space != messages.NSIQRegister || err != nil { + client.log.Warn("is no iq register: ", err) + return nil, client + } + client.out.Encode(&messages.IQ{ + Type: messages.IQTypeResult, + ID: msg.ID, + Body: []byte(fmt.Sprintf(` + Choose a username and password for use with this service. + + + + `, messages.NSIQRegister)), + }) + return state.Next, client +} + +type RegisterRequest struct { + Next State +} + +// Process message +func (state *RegisterRequest) Process(client *Client) (State, *Client) { + client.log = client.log.WithField("state", "register request") + client.log.Debug("running") + defer client.log.Debug("leave") + + element, err := client.Read() + if err != nil { + client.log.Warn("unable to read: ", err) + return nil, client + } + var msg messages.IQ + if err = client.in.DecodeElement(&msg, element); err != nil { + client.log.Warn("is no iq: ", err) + return state, client + } + if msg.Type != messages.IQTypeGet { + client.log.Warn("is no get iq") + return state, client + } + if msg.Error != nil { + client.log.Warn("iq with error: ", msg.Error.Code) + return state, client + } + type query struct { + XMLName xml.Name `xml:"query"` + Username string `xml:"username"` + Password string `xml:"password"` + } + q := &query{} + err = xml.Unmarshal(msg.Body, q) + if err != nil { + client.log.Warn("is no iq register: ", err) + return nil, client + } + + client.jid.Local = q.Username + client.log = client.log.WithField("jid", client.jid.Full()) + account := model.NewAccount(client.jid, q.Password) + err = client.Server.Database.AddAccount(account) + if err != nil { + client.out.Encode(&messages.IQ{ + Type: messages.IQTypeResult, + ID: msg.ID, + Body: []byte(fmt.Sprintf(` + %s + %s + `, messages.NSIQRegister, q.Username, q.Password)), + Error: &messages.Error{ + Code: "409", + Type: "cancel", + Any: xml.Name{ + Local: "conflict", + Space: "urn:ietf:params:xml:ns:xmpp-stanzas", + }, + }, + }) + client.log.Warn("database error: ", err) + return state, client + } + client.account = account + client.out.Encode(&messages.IQ{ + Type: messages.IQTypeResult, + ID: msg.ID, + }) + + client.log.Infof("registered client %s", client.jid.Bare()) + return state.Next, client +} diff --git a/server/utils.go b/server/utils.go new file mode 100644 index 0000000..561161e --- /dev/null +++ b/server/utils.go @@ -0,0 +1,29 @@ +package server + +import ( + "crypto/rand" + "encoding/binary" + "fmt" +) + +// Cookie is used to give a unique identifier to each request. +type Cookie uint64 + +func createCookie() Cookie { + var buf [8]byte + if _, err := rand.Reader.Read(buf[:]); err != nil { + panic("Failed to read random bytes: " + err.Error()) + } + return Cookie(binary.LittleEndian.Uint64(buf[:])) +} +func createCookieString() string { + return fmt.Sprintf("%x", createCookie()) +} + +func makeResource() string { + var buf [16]byte + if _, err := rand.Reader.Read(buf[:]); err != nil { + panic("Failed to read random bytes: " + err.Error()) + } + return fmt.Sprintf("%x", buf) +}