sum7
/
yaja
Archived
1
0
Fork 0

lets encrypt + registration

This commit is contained in:
Martin Geno 2017-12-14 21:30:07 +01:00
parent 8c60ef89c6
commit 800a5b1917
No known key found for this signature in database
GPG Key ID: F0D39A37E925E941
24 changed files with 899 additions and 162 deletions

View File

@ -1,6 +1,6 @@
workspace: workspace:
base: /go base: /go
path: src/dev.sum7.eu/genofire/yaja path: src/github.com/genofire/yaja
pipeline: pipeline:
build: build:

View File

@ -1,3 +1,24 @@
# yaja # yaja (yet another jabber server)
Yet Another Jabber Server ```
A small standalone jabber server, for easy deployment
Usage:
yaja [command]
Available Commands:
help Help about any command
server Runs the yaja server
Flags:
-h, --help help for yaja
Use "yaja [command] --help" for more information about a command.
```
## Features (works already)
- get certificate by lets encrypt
- registration (for every possible ssl domain)
## Inspiration
- [tam7t](https://github.com/tam7t/xmpp) a fork of [agl](https://github.com/agl)'s work

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"crypto/tls" "crypto/tls"
"net/http"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
@ -9,11 +10,12 @@ import (
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"dev.sum7.eu/genofire/yaja/model" "github.com/genofire/yaja/database"
"dev.sum7.eu/genofire/yaja/model/config" "github.com/genofire/yaja/model/config"
"dev.sum7.eu/genofire/yaja/server" "github.com/genofire/golang-lib/file"
"github.com/genofire/golang-lib/worker" "github.com/genofire/golang-lib/worker"
"github.com/genofire/yaja/server"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -22,32 +24,34 @@ import (
var configPath string var configPath string
var ( var (
configData *config.Config configData = &config.Config{}
state *model.State db = &database.State{}
statesaveWorker *worker.Worker statesaveWorker *worker.Worker
srv *server.Server srv *server.Server
certs *tls.Config certs *tls.Config
) )
// serveCmd represents the serve command // serverCmd represents the serve command
var serveCmd = &cobra.Command{ var serverCmd = &cobra.Command{
Use: "serve", Use: "server",
Short: "Runs the yaja server", Short: "Runs the yaja server",
Example: "yaja serve -c /etc/yaja.conf", Example: "yaja serve -c /etc/yaja.conf",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
var err error var err error
configData, err = config.ReadConfigFile(configPath) err = file.ReadTOML(configPath, configData)
if err != nil { if err != nil {
log.Fatal("unable to load config file:", err) log.Fatal("unable to load config file:", err)
} }
state, err = model.ReadState(configData.StatePath) log.SetLevel(log.DebugLevel)
err = file.ReadJSON(configData.StatePath, db)
if err != nil { if err != nil {
log.Warn("unable to load state file:", err) log.Warn("unable to load state file:", err)
} }
statesaveWorker = worker.NewWorker(time.Minute, func() { statesaveWorker = worker.NewWorker(time.Minute, func() {
model.SaveJSON(state, configData.StatePath) file.SaveJSON(configData.StatePath, db)
log.Info("save state to:", configData.StatePath) log.Info("save state to:", configData.StatePath)
}) })
@ -56,13 +60,18 @@ var serveCmd = &cobra.Command{
Prompt: autocert.AcceptTOS, Prompt: autocert.AcceptTOS,
} }
certs = &tls.Config{GetCertificate: m.GetCertificate} // https server to handle acme (by letsencrypt)
httpServer := &http.Server{
Addr: ":https",
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
}
go httpServer.ListenAndServeTLS("", "")
srv = &server.Server{ srv = &server.Server{
TLSConfig: certs, TLSManager: &m,
State: state, Database: db,
PortClient: configData.PortClient, ClientAddr: configData.Address.Client,
PortServer: configData.PortServer, ServerAddr: configData.Address.Server,
} }
go statesaveWorker.Start() go statesaveWorker.Start()
@ -95,21 +104,24 @@ func quit() {
srv.Close() srv.Close()
statesaveWorker.Close() statesaveWorker.Close()
model.SaveJSON(state, configData.StatePath) file.SaveJSON(configData.StatePath, db)
} }
func reload() { func reload() {
log.Info("start reloading...") log.Info("start reloading...")
configNewData, err := config.ReadConfigFile(configPath) var configNewData *config.Config
err := file.ReadTOML(configPath, configNewData)
if err != nil { if err != nil {
log.Warn("unable to load config file:", err) log.Warn("unable to load config file:", err)
return return
} }
//TODO fetch changing address (to set restart)
if configNewData.StatePath != configData.StatePath { if configNewData.StatePath != configData.StatePath {
statesaveWorker.Close() statesaveWorker.Close()
statesaveWorker := worker.NewWorker(time.Minute, func() { statesaveWorker := worker.NewWorker(time.Minute, func() {
model.SaveJSON(state, configNewData.StatePath) file.SaveJSON(configNewData.StatePath, db)
log.Info("save state to:", configNewData.StatePath) log.Info("save state to:", configNewData.StatePath)
}) })
go statesaveWorker.Start() go statesaveWorker.Start()
@ -130,17 +142,11 @@ func reload() {
newServer := &server.Server{ newServer := &server.Server{
TLSConfig: certs, TLSConfig: certs,
State: state, Database: db,
PortClient: configNewData.PortClient, ClientAddr: configNewData.Address.Client,
PortServer: configNewData.PortServer, ServerAddr: configNewData.Address.Server,
} }
if configNewData.PortServer != configData.PortServer {
restartServer = true
}
if configNewData.PortClient != configData.PortClient {
restartServer = true
}
if restartServer { if restartServer {
go srv.Start() go srv.Start()
//TODO should fetch new server error //TODO should fetch new server error
@ -153,6 +159,6 @@ func reload() {
} }
func init() { func init() {
RootCmd.AddCommand(serveCmd) RootCmd.AddCommand(serverCmd)
serveCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
} }

View File

@ -1,4 +1,6 @@
tlsdir = "/tmp/ssl" tlsdir = "/tmp/ssl"
state_path = "/tmp/yaja.json" state_path = "/tmp/yaja.json"
port_client = 5222
port_server = 5269 [address]
client = [":5222"]
server = [":5269"]

68
database/main.go Normal file
View File

@ -0,0 +1,68 @@
package database
import (
"errors"
"sync"
"github.com/genofire/yaja/model"
log "github.com/sirupsen/logrus"
)
type State struct {
Domains map[string]*model.Domain `json:"domains"`
sync.Mutex
}
func (s *State) AddAccount(a *model.Account) error {
if a.Local == "" {
return errors.New("No localpart exists in account")
}
if d := a.Domain; d != nil {
if d.FQDN == "" {
return errors.New("No fqdn exists in domain")
}
s.Lock()
domain, ok := s.Domains[d.FQDN]
if !ok {
if s.Domains == nil {
s.Domains = make(map[string]*model.Domain)
}
s.Domains[d.FQDN] = d
domain = d
}
s.Unlock()
domain.Lock()
defer domain.Unlock()
if domain.Accounts == nil {
domain.Accounts = make(map[string]*model.Account)
}
_, ok = domain.Accounts[a.Local]
if ok {
return errors.New("exists already")
}
domain.Accounts[a.Local] = a
a.Domain = d
return nil
}
return errors.New("no give domain")
}
func (s *State) Authenticate(jid *model.JID, password string) (bool, error) {
logger := log.WithField("database", "auth")
if domain, ok := s.Domains[jid.Domain]; ok {
if acc, ok := domain.Accounts[jid.Local]; ok {
if acc.ValidatePassword(password) {
return true, nil
} else {
logger.Debug("password not valid")
}
} else {
logger.Debug("account not found")
}
} else {
logger.Debug("domain not found")
}
return false, nil
}

View File

@ -1,6 +1,6 @@
package main package main
import "dev.sum7.eu/genofire/yaja/cmd" import "github.com/genofire/yaja/cmd"
func main() { func main() {
cmd.Execute() cmd.Execute()

12
messages/error.go Normal file
View File

@ -0,0 +1,12 @@
package messages
import "encoding/xml"
// Error element
type Error struct {
XMLName xml.Name `xml:"jabber:client error"`
Code string `xml:"code,attr"`
Type string `xml:"type,attr"`
Any xml.Name `xml:",any"`
Text string `xml:"text"`
}

25
messages/iq.go Normal file
View File

@ -0,0 +1,25 @@
package messages
import "encoding/xml"
type IQType string
const (
IQTypeGet IQType = "get"
IQTypeSet IQType = "set"
IQTypeResult IQType = "result"
IQTypeError IQType = "error"
)
// IQ element - info/query
type IQ struct {
XMLName xml.Name `xml:"jabber:client iq"`
From string `xml:"from,attr"`
ID string `xml:"id,attr"`
To string `xml:"to,attr"`
Type IQType `xml:"type,attr"`
Error *Error `xml:"error"`
//Bind bindBind `xml:"bind"`
Body []byte `xml:",innerxml"`
// RosterRequest - better detection of iq's
}

19
messages/namespaces.go Normal file
View File

@ -0,0 +1,19 @@
package messages
const (
// NSStream stream namesapce
NSStream = "http://etherx.jabber.org/streams"
// NSTLS xmpp-tls xml namespace
NSTLS = "urn:ietf:params:xml:ns:xmpp-tls"
// NSSASL xmpp-sasl xml namespace
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
// NSClient jabbet client namespace
NSClient = "jabber:client"
NSIQRegister = "jabber:iq:register"
NSFeaturesIQRegister = "http://jabber.org/features/iq-register"
)

32
messages/presence.go Normal file
View File

@ -0,0 +1,32 @@
package messages
import "encoding/xml"
type PresenceType string
const (
PresenceTypeUnavailable PresenceType = "unavailable"
PresenceTypeSubscribe PresenceType = "subscribe"
PresenceTypeSubscribed PresenceType = "subscribed"
PresenceTypeUnsubscribe PresenceType = "unsubscribe"
PresenceTypeUnsubscribed PresenceType = "unsubscribed"
PresenceTypeProbe PresenceType = "probe"
PresenceTypeError PresenceType = "error"
)
// Presence element
type Presence struct {
XMLName xml.Name `xml:"jabber:client presence"`
From string `xml:"from,attr,omitempty"`
ID string `xml:"id,attr,omitempty"`
To string `xml:"to,attr,omitempty"`
Type string `xml:"type,attr,omitempty"`
Lang string `xml:"lang,attr,omitempty"`
Show string `xml:"show,omitempty"` // away, chat, dnd, xa
Status string `xml:"status,omitempty"` // sb []clientText
Priority string `xml:"priority,omitempty"`
// Caps *ClientCaps `xml:"c"`
Error *Error `xml:"error"`
// Delay Delay `xml:"delay"`
}

10
messages/sasl.go Normal file
View File

@ -0,0 +1,10 @@
package messages
import "encoding/xml"
// SASLAuth element
type SASLAuth struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl auth"`
Mechanism string `xml:"mechanism,attr"`
Body string `xml:",chardata"`
}

View File

@ -6,8 +6,8 @@ import (
) )
type Domain struct { type Domain struct {
FQDN string FQDN string `json:"-"`
Accounts map[string]*Account Accounts map[string]*Account `json:"users"`
sync.Mutex sync.Mutex
} }
@ -29,8 +29,22 @@ func (d *Domain) UpdateAccount(a *Account) error {
} }
type Account struct { type Account struct {
Local string Local string `json:"-"`
Domain *Domain Domain *Domain `json:"-"`
Password string `json:"password"`
}
func NewAccount(jid *JID, password string) *Account {
if jid == nil {
return nil
}
return &Account{
Local: jid.Local,
Domain: &Domain{
FQDN: jid.Domain,
},
Password: password,
}
} }
func (a *Account) GetJID() *JID { func (a *Account) GetJID() *JID {
@ -39,3 +53,7 @@ func (a *Account) GetJID() *JID {
Local: a.Local, Local: a.Local,
} }
} }
func (a *Account) ValidatePassword(password string) bool {
return a.Password == password
}

View File

@ -1,24 +0,0 @@
package config
import (
"io/ioutil"
"github.com/BurntSushi/toml"
)
// ReadConfigFile reads a config model from path of a yml file
func ReadConfigFile(path string) (config *Config, err error) {
config = &Config{}
file, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
err = toml.Unmarshal(file, config)
if err != nil {
return nil, err
}
return
}

View File

@ -1,25 +0,0 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestReadConfig(t *testing.T) {
assert := assert.New(t)
config, err := ReadConfigFile("../../config_example.conf")
assert.NoError(err)
assert.NotNil(config)
assert.Equal("/tmp/ssl", config.TLSDir)
config, err = ReadConfigFile("../config_example.co")
assert.Nil(config)
assert.Contains(err.Error(), "no such file or directory")
config, err = ReadConfigFile("testdata/config_panic.conf")
assert.Nil(config)
assert.Contains(err.Error(), "keys cannot contain")
}

View File

@ -3,6 +3,8 @@ package config
type Config struct { type Config struct {
TLSDir string `toml:"tlsdir"` TLSDir string `toml:"tlsdir"`
StatePath string `toml:"state_path"` StatePath string `toml:"state_path"`
PortClient int `toml:"port_client"` Address struct {
PortServer int `toml:"port_server"` Client []string `toml:"client"`
Server []string `toml:"server"`
} `toml:"address"`
} }

View File

@ -1,27 +0,0 @@
package model
import (
"encoding/json"
"log"
"os"
)
// SaveJSON to path
func SaveJSON(input interface{}, outputFile string) {
tmpFile := outputFile + ".tmp"
f, err := os.OpenFile(tmpFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
log.Panic(err)
}
err = json.NewEncoder(f).Encode(input)
if err != nil {
log.Panic(err)
}
f.Close()
if err := os.Rename(tmpFile, outputFile); err != nil {
log.Panic(err)
}
}

View File

@ -42,6 +42,8 @@ func (jid *JID) Bare() string {
return jid.Domain return jid.Domain
} }
func (jid *JID) String() string { return jid.Bare() }
// Full get the "full" jid as string // Full get the "full" jid as string
func (jid *JID) Full() string { func (jid *JID) Full() string {
if jid.Resource != "" { if jid.Resource != "" {

View File

@ -1,38 +0,0 @@
package model
import (
"encoding/json"
"errors"
"os"
"sync"
)
type State struct {
Domains map[string]*Domain
sync.Mutex
}
func ReadState(path string) (state *State, err error) {
state = &State{}
if f, err := os.Open(path); err == nil { // transform data to legacy meshviewer
if err = json.NewDecoder(f).Decode(state); err == nil {
return state, nil
} else {
return nil, err
}
} else {
return nil, err
}
}
func (s *State) UpdateDomain(d *Domain) error {
if d.FQDN == "" {
return errors.New("No fqdn exists in domain")
}
s.Lock()
s.Domains[d.FQDN] = d
s.Unlock()
return nil
}

62
server/client.go Normal file
View File

@ -0,0 +1,62 @@
package server
import (
"encoding/xml"
"net"
"github.com/genofire/yaja/model"
log "github.com/sirupsen/logrus"
)
type Client struct {
log *log.Entry
Conn net.Conn
out *xml.Encoder
in *xml.Decoder
Server *Server
jid *model.JID
account *model.Account
messages chan interface{}
close chan interface{}
}
func NewClient(conn net.Conn, srv *Server) *Client {
logger := log.New()
logger.SetLevel(log.DebugLevel)
client := &Client{
Conn: conn,
Server: srv,
log: log.NewEntry(logger),
in: xml.NewDecoder(conn),
out: xml.NewEncoder(conn),
}
return client
}
func (client *Client) NewConnecting(conn net.Conn) {
client.Conn = conn
client.in = xml.NewDecoder(conn)
client.out = xml.NewEncoder(conn)
}
func (client *Client) Read() (*xml.StartElement, error) {
for {
nextToken, err := client.in.Token()
if err != nil {
return nil, err
}
switch nextToken.(type) {
case xml.StartElement:
element := nextToken.(xml.StartElement)
return &element, nil
}
}
}
func (client *Client) Close() {
client.close <- true
client.Conn.Close()
}

View File

@ -2,19 +2,82 @@ package server
import ( import (
"crypto/tls" "crypto/tls"
"net"
"dev.sum7.eu/genofire/yaja/model" "github.com/genofire/yaja/database"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
) )
type Server struct { type Server struct {
TLSConfig *tls.Config TLSConfig *tls.Config
PortClient int TLSManager *autocert.Manager
PortServer int ClientAddr []string
State *model.State ServerAddr []string
Database *database.State
} }
func (srv *Server) Start() { func (srv *Server) Start() {
for _, addr := range srv.ServerAddr {
socket, err := net.Listen("tcp", addr)
if err != nil {
log.Warn("create server socket: ", err.Error())
break
}
go srv.listenServer(socket)
}
for _, addr := range srv.ClientAddr {
socket, err := net.Listen("tcp", addr)
if err != nil {
log.Warn("create client socket: ", err.Error())
break
}
go srv.listenClient(socket)
}
}
func (srv *Server) listenServer(s2s net.Listener) {
for {
conn, err := s2s.Accept()
if err != nil {
log.Warn("accepting server connection: ", err.Error())
break
}
go srv.handleServer(conn)
}
}
func (srv *Server) listenClient(c2s net.Listener) {
for {
conn, err := c2s.Accept()
if err != nil {
log.Warn("accepting client connection: ", err.Error())
break
}
go srv.handleClient(conn)
}
}
func (srv *Server) handleServer(conn net.Conn) {
log.Info("new server connection:", conn.RemoteAddr())
}
func (srv *Server) handleClient(conn net.Conn) {
log.Info("new client connection:", conn.RemoteAddr())
client := NewClient(conn, srv)
state := ConnectionStartup()
for {
state, client = state.Process(client)
if state == nil {
client.log.Info("disconnect")
client.Close()
//s.DisconnectBus <- Disconnect{Jid: client.jid}
return
}
// run next state
}
} }
func (srv *Server) Close() { func (srv *Server) Close() {

6
server/state.go Normal file
View File

@ -0,0 +1,6 @@
package server
// State processes the stream and moves to the next state
type State interface {
Process(client *Client) (State, *Client)
}

344
server/state_connect.go Normal file
View File

@ -0,0 +1,344 @@
package server
import (
"crypto/tls"
"encoding/base64"
"encoding/xml"
"fmt"
"strings"
"github.com/genofire/yaja/messages"
"github.com/genofire/yaja/model"
)
// ConnectionStartup return steps through TCP TLS state
func ConnectionStartup() State {
receiving := &ReceivingClient{}
sending := &SendingClient{Next: receiving}
authedstream := &AuthedStream{Next: sending}
authedstart := &AuthedStart{Next: authedstream}
tlsauth := &SASLAuth{Next: authedstart}
tlsstream := &TLSStream{Next: tlsauth}
tlsupgrade := &TLSUpgrade{Next: tlsstream}
stream := &Start{Next: tlsupgrade}
return stream
}
// Start state
type Start struct {
Next State
}
// Process message
func (state *Start) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "stream")
client.log.Debug("running")
defer client.log.Debug("leave")
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.log.Warn("is no stream")
return state, client
}
for _, attr := range element.Attr {
if attr.Name.Local == "to" {
client.jid = &model.JID{Domain: attr.Value}
client.log = client.log.WithField("jid", client.jid.Full())
}
}
if client.jid == nil {
client.log.Warn("no 'to' domain readed")
return nil, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
createCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features>
<starttls xmlns='%s'>
<required/>
</starttls>
</stream:features>`,
messages.NSStream)
return state.Next, client
}
// TLSUpgrade state
type TLSUpgrade struct {
Next State
}
// Process message
func (state *TLSUpgrade) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "tls upgrade")
client.log.Debug("running")
defer client.log.Debug("leave")
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
client.log.Warn("is no starttls")
return state, client
}
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
// perform the TLS handshake
var tlsConfig *tls.Config
if m := client.Server.TLSManager; m != nil {
var cert *tls.Certificate
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.jid.Domain})
if err != nil {
client.log.Warn("no cert in tls manger found: ", err)
return nil, client
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{*cert},
}
}
if tlsConfig == nil {
tlsConfig = client.Server.TLSConfig
if tlsConfig != nil {
tlsConfig.ServerName = client.jid.Domain
} else {
client.log.Warn("no tls config found: ", err)
return nil, client
}
}
tlsConn := tls.Server(client.Conn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
client.log.Warn("unable to tls handshake: ", err)
return nil, client
}
// restart the Connection
client.NewConnecting(tlsConn)
return state.Next, client
}
// TLSStream state
type TLSStream struct {
Next State
}
// Process messages
func (state *TLSStream) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "tls stream")
client.log.Debug("running")
defer client.log.Debug("leave")
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.log.Warn("is no stream")
return state, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
createCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features>
<mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism>
</mechanisms>
<register xmlns='%s'/>
</stream:features>`,
messages.NSSASL, messages.NSFeaturesIQRegister)
return state.Next, client
}
// SASLAuth state
type SASLAuth struct {
Next State
}
// Process messages
func (state *SASLAuth) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "sasl auth")
client.log.Debug("running")
defer client.log.Debug("leave")
// read the full auth stanza
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
var auth messages.SASLAuth
if err = client.in.DecodeElement(&auth, element); err != nil {
client.log.Info("start substate for registration")
return &RegisterFormRequest{
element: element,
Next: &RegisterRequest{
Next: state.Next,
},
}, client
}
data, err := base64.StdEncoding.DecodeString(auth.Body)
if err != nil {
client.log.Warn("body decode: ", err)
return nil, client
}
info := strings.Split(string(data), "\x00")
// should check that info[1] starts with client.jid
client.jid.Local = info[1]
client.log = client.log.WithField("jid", client.jid.Full())
success, err := client.Server.Database.Authenticate(client.jid, info[2])
if err != nil {
client.log.Warn("auth: ", err)
return nil, client
}
if success {
client.log.Info("success auth")
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
return state.Next, client
}
client.log.Warn("failed auth")
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
return nil, client
}
// AuthedStart state
type AuthedStart struct {
Next State
}
// Process messages
func (state *AuthedStart) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "authed started")
client.log.Debug("running")
defer client.log.Debug("leave")
_, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
createCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features>
<bind xmlns='%s'/>
</stream:features>`,
messages.NSBind)
return state.Next, client
}
// AuthedStream state
type AuthedStream struct {
Next State
}
// Process messages
func (state *AuthedStream) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "authed stream")
client.log.Debug("running")
defer client.log.Debug("leave")
// check that it's a bind request
// read bind request
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
var msg messages.IQ
if err = client.in.DecodeElement(&msg, element); err != nil {
client.log.Warn("is no iq: ", err)
return nil, client
}
if msg.Type != messages.IQTypeSet {
client.log.Warn("is no set iq")
return nil, client
}
if msg.Error != nil {
client.log.Warn("iq with error: ", msg.Error.Code)
return nil, client
}
type query struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
Resource string `xml:"resource"`
}
q := &query{}
err = xml.Unmarshal(msg.Body, q)
if err != nil {
client.log.Warn("is no iq bind: ", err)
return nil, client
}
if q.Resource == "" {
client.jid.Resource = makeResource()
} else {
client.jid.Resource = q.Resource
}
client.log = client.log.WithField("jid", client.jid.Full())
client.out.Encode(&messages.IQ{
Type: messages.IQTypeResult,
ID: msg.ID,
Body: []byte(fmt.Sprintf(
`<bind xmlns='%s'>
<jid>%s</jid>
</bind>`,
messages.NSBind, client.jid.Full())),
})
return state.Next, client
}
// SendingClient state
type SendingClient struct {
Next State
}
// Process messages
func (state *SendingClient) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "normal")
client.log.Debug("sending")
// sending
go func() {
select {
case msg := <-client.messages:
err := client.out.Encode(msg)
client.log.Info(err)
case <-client.close:
return
}
}()
client.log.Debug("receiving")
return state.Next, client
}
// ReceivingClient state
type ReceivingClient struct {
}
// Process messages
func (state *ReceivingClient) Process(client *Client) (State, *Client) {
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
/*
for _, extension := range client.Server.Extensions {
extension.Process(element, client)
}*/
client.log.Debug(element)
return state, client
}

130
server/state_register.go Normal file
View File

@ -0,0 +1,130 @@
package server
import (
"encoding/xml"
"fmt"
"github.com/genofire/yaja/messages"
"github.com/genofire/yaja/model"
)
type RegisterFormRequest struct {
Next State
element *xml.StartElement
}
// Process message
func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "register form request")
client.log.Debug("running")
defer client.log.Debug("leave")
var msg messages.IQ
if err := client.in.DecodeElement(&msg, state.element); err != nil {
client.log.Warn("is no iq: ", err)
return state, client
}
if msg.Type != messages.IQTypeGet {
client.log.Warn("is no get iq")
return state, client
}
if msg.Error != nil {
client.log.Warn("iq with error: ", msg.Error.Code)
return state, client
}
type query struct {
XMLName xml.Name `xml:"query"`
}
q := &query{}
err := xml.Unmarshal(msg.Body, q)
if q.XMLName.Space != messages.NSIQRegister || err != nil {
client.log.Warn("is no iq register: ", err)
return nil, client
}
client.out.Encode(&messages.IQ{
Type: messages.IQTypeResult,
ID: msg.ID,
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
Choose a username and password for use with this service.
</instructions>
<username/>
<password/>
</query>`, messages.NSIQRegister)),
})
return state.Next, client
}
type RegisterRequest struct {
Next State
}
// Process message
func (state *RegisterRequest) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "register request")
client.log.Debug("running")
defer client.log.Debug("leave")
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
var msg messages.IQ
if err = client.in.DecodeElement(&msg, element); err != nil {
client.log.Warn("is no iq: ", err)
return state, client
}
if msg.Type != messages.IQTypeGet {
client.log.Warn("is no get iq")
return state, client
}
if msg.Error != nil {
client.log.Warn("iq with error: ", msg.Error.Code)
return state, client
}
type query struct {
XMLName xml.Name `xml:"query"`
Username string `xml:"username"`
Password string `xml:"password"`
}
q := &query{}
err = xml.Unmarshal(msg.Body, q)
if err != nil {
client.log.Warn("is no iq register: ", err)
return nil, client
}
client.jid.Local = q.Username
client.log = client.log.WithField("jid", client.jid.Full())
account := model.NewAccount(client.jid, q.Password)
err = client.Server.Database.AddAccount(account)
if err != nil {
client.out.Encode(&messages.IQ{
Type: messages.IQTypeResult,
ID: msg.ID,
Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
<username>%s</username>
<password>%s</password>
</query>`, messages.NSIQRegister, q.Username, q.Password)),
Error: &messages.Error{
Code: "409",
Type: "cancel",
Any: xml.Name{
Local: "conflict",
Space: "urn:ietf:params:xml:ns:xmpp-stanzas",
},
},
})
client.log.Warn("database error: ", err)
return state, client
}
client.account = account
client.out.Encode(&messages.IQ{
Type: messages.IQTypeResult,
ID: msg.ID,
})
client.log.Infof("registered client %s", client.jid.Bare())
return state.Next, client
}

29
server/utils.go Normal file
View File

@ -0,0 +1,29 @@
package server
import (
"crypto/rand"
"encoding/binary"
"fmt"
)
// Cookie is used to give a unique identifier to each request.
type Cookie uint64
func createCookie() Cookie {
var buf [8]byte
if _, err := rand.Reader.Read(buf[:]); err != nil {
panic("Failed to read random bytes: " + err.Error())
}
return Cookie(binary.LittleEndian.Uint64(buf[:]))
}
func createCookieString() string {
return fmt.Sprintf("%x", createCookie())
}
func makeResource() string {
var buf [16]byte
if _, err := rand.Reader.Read(buf[:]); err != nil {
panic("Failed to read random bytes: " + err.Error())
}
return fmt.Sprintf("%x", buf)
}