lets encrypt + registration
This commit is contained in:
parent
8c60ef89c6
commit
800a5b1917
|
@ -1,6 +1,6 @@
|
||||||
workspace:
|
workspace:
|
||||||
base: /go
|
base: /go
|
||||||
path: src/dev.sum7.eu/genofire/yaja
|
path: src/github.com/genofire/yaja
|
||||||
|
|
||||||
pipeline:
|
pipeline:
|
||||||
build:
|
build:
|
||||||
|
|
25
README.md
25
README.md
|
@ -1,3 +1,24 @@
|
||||||
# yaja
|
# yaja (yet another jabber server)
|
||||||
|
|
||||||
Yet Another Jabber Server
|
```
|
||||||
|
A small standalone jabber server, for easy deployment
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
yaja [command]
|
||||||
|
|
||||||
|
Available Commands:
|
||||||
|
help Help about any command
|
||||||
|
server Runs the yaja server
|
||||||
|
|
||||||
|
Flags:
|
||||||
|
-h, --help help for yaja
|
||||||
|
|
||||||
|
Use "yaja [command] --help" for more information about a command.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Features (works already)
|
||||||
|
- get certificate by lets encrypt
|
||||||
|
- registration (for every possible ssl domain)
|
||||||
|
|
||||||
|
## Inspiration
|
||||||
|
- [tam7t](https://github.com/tam7t/xmpp) a fork of [agl](https://github.com/agl)'s work
|
||||||
|
|
|
@ -2,6 +2,7 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -9,11 +10,12 @@ import (
|
||||||
|
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
|
||||||
"dev.sum7.eu/genofire/yaja/model"
|
"github.com/genofire/yaja/database"
|
||||||
"dev.sum7.eu/genofire/yaja/model/config"
|
"github.com/genofire/yaja/model/config"
|
||||||
|
|
||||||
"dev.sum7.eu/genofire/yaja/server"
|
"github.com/genofire/golang-lib/file"
|
||||||
"github.com/genofire/golang-lib/worker"
|
"github.com/genofire/golang-lib/worker"
|
||||||
|
"github.com/genofire/yaja/server"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
@ -22,32 +24,34 @@ import (
|
||||||
var configPath string
|
var configPath string
|
||||||
|
|
||||||
var (
|
var (
|
||||||
configData *config.Config
|
configData = &config.Config{}
|
||||||
state *model.State
|
db = &database.State{}
|
||||||
statesaveWorker *worker.Worker
|
statesaveWorker *worker.Worker
|
||||||
srv *server.Server
|
srv *server.Server
|
||||||
certs *tls.Config
|
certs *tls.Config
|
||||||
)
|
)
|
||||||
|
|
||||||
// serveCmd represents the serve command
|
// serverCmd represents the serve command
|
||||||
var serveCmd = &cobra.Command{
|
var serverCmd = &cobra.Command{
|
||||||
Use: "serve",
|
Use: "server",
|
||||||
Short: "Runs the yaja server",
|
Short: "Runs the yaja server",
|
||||||
Example: "yaja serve -c /etc/yaja.conf",
|
Example: "yaja serve -c /etc/yaja.conf",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
var err error
|
var err error
|
||||||
configData, err = config.ReadConfigFile(configPath)
|
err = file.ReadTOML(configPath, configData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("unable to load config file:", err)
|
log.Fatal("unable to load config file:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
state, err = model.ReadState(configData.StatePath)
|
log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
|
err = file.ReadJSON(configData.StatePath, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("unable to load state file:", err)
|
log.Warn("unable to load state file:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statesaveWorker = worker.NewWorker(time.Minute, func() {
|
statesaveWorker = worker.NewWorker(time.Minute, func() {
|
||||||
model.SaveJSON(state, configData.StatePath)
|
file.SaveJSON(configData.StatePath, db)
|
||||||
log.Info("save state to:", configData.StatePath)
|
log.Info("save state to:", configData.StatePath)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -56,13 +60,18 @@ var serveCmd = &cobra.Command{
|
||||||
Prompt: autocert.AcceptTOS,
|
Prompt: autocert.AcceptTOS,
|
||||||
}
|
}
|
||||||
|
|
||||||
certs = &tls.Config{GetCertificate: m.GetCertificate}
|
// https server to handle acme (by letsencrypt)
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: ":https",
|
||||||
|
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
|
||||||
|
}
|
||||||
|
go httpServer.ListenAndServeTLS("", "")
|
||||||
|
|
||||||
srv = &server.Server{
|
srv = &server.Server{
|
||||||
TLSConfig: certs,
|
TLSManager: &m,
|
||||||
State: state,
|
Database: db,
|
||||||
PortClient: configData.PortClient,
|
ClientAddr: configData.Address.Client,
|
||||||
PortServer: configData.PortServer,
|
ServerAddr: configData.Address.Server,
|
||||||
}
|
}
|
||||||
|
|
||||||
go statesaveWorker.Start()
|
go statesaveWorker.Start()
|
||||||
|
@ -95,21 +104,24 @@ func quit() {
|
||||||
srv.Close()
|
srv.Close()
|
||||||
statesaveWorker.Close()
|
statesaveWorker.Close()
|
||||||
|
|
||||||
model.SaveJSON(state, configData.StatePath)
|
file.SaveJSON(configData.StatePath, db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func reload() {
|
func reload() {
|
||||||
log.Info("start reloading...")
|
log.Info("start reloading...")
|
||||||
configNewData, err := config.ReadConfigFile(configPath)
|
var configNewData *config.Config
|
||||||
|
err := file.ReadTOML(configPath, configNewData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("unable to load config file:", err)
|
log.Warn("unable to load config file:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//TODO fetch changing address (to set restart)
|
||||||
|
|
||||||
if configNewData.StatePath != configData.StatePath {
|
if configNewData.StatePath != configData.StatePath {
|
||||||
statesaveWorker.Close()
|
statesaveWorker.Close()
|
||||||
statesaveWorker := worker.NewWorker(time.Minute, func() {
|
statesaveWorker := worker.NewWorker(time.Minute, func() {
|
||||||
model.SaveJSON(state, configNewData.StatePath)
|
file.SaveJSON(configNewData.StatePath, db)
|
||||||
log.Info("save state to:", configNewData.StatePath)
|
log.Info("save state to:", configNewData.StatePath)
|
||||||
})
|
})
|
||||||
go statesaveWorker.Start()
|
go statesaveWorker.Start()
|
||||||
|
@ -130,17 +142,11 @@ func reload() {
|
||||||
|
|
||||||
newServer := &server.Server{
|
newServer := &server.Server{
|
||||||
TLSConfig: certs,
|
TLSConfig: certs,
|
||||||
State: state,
|
Database: db,
|
||||||
PortClient: configNewData.PortClient,
|
ClientAddr: configNewData.Address.Client,
|
||||||
PortServer: configNewData.PortServer,
|
ServerAddr: configNewData.Address.Server,
|
||||||
}
|
}
|
||||||
|
|
||||||
if configNewData.PortServer != configData.PortServer {
|
|
||||||
restartServer = true
|
|
||||||
}
|
|
||||||
if configNewData.PortClient != configData.PortClient {
|
|
||||||
restartServer = true
|
|
||||||
}
|
|
||||||
if restartServer {
|
if restartServer {
|
||||||
go srv.Start()
|
go srv.Start()
|
||||||
//TODO should fetch new server error
|
//TODO should fetch new server error
|
||||||
|
@ -153,6 +159,6 @@ func reload() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RootCmd.AddCommand(serveCmd)
|
RootCmd.AddCommand(serverCmd)
|
||||||
serveCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
||||||
}
|
}
|
|
@ -1,4 +1,6 @@
|
||||||
tlsdir = "/tmp/ssl"
|
tlsdir = "/tmp/ssl"
|
||||||
state_path = "/tmp/yaja.json"
|
state_path = "/tmp/yaja.json"
|
||||||
port_client = 5222
|
|
||||||
port_server = 5269
|
[address]
|
||||||
|
client = [":5222"]
|
||||||
|
server = [":5269"]
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
Domains map[string]*model.Domain `json:"domains"`
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) AddAccount(a *model.Account) error {
|
||||||
|
if a.Local == "" {
|
||||||
|
return errors.New("No localpart exists in account")
|
||||||
|
}
|
||||||
|
if d := a.Domain; d != nil {
|
||||||
|
if d.FQDN == "" {
|
||||||
|
return errors.New("No fqdn exists in domain")
|
||||||
|
}
|
||||||
|
s.Lock()
|
||||||
|
domain, ok := s.Domains[d.FQDN]
|
||||||
|
if !ok {
|
||||||
|
if s.Domains == nil {
|
||||||
|
s.Domains = make(map[string]*model.Domain)
|
||||||
|
}
|
||||||
|
s.Domains[d.FQDN] = d
|
||||||
|
domain = d
|
||||||
|
}
|
||||||
|
s.Unlock()
|
||||||
|
|
||||||
|
domain.Lock()
|
||||||
|
defer domain.Unlock()
|
||||||
|
if domain.Accounts == nil {
|
||||||
|
domain.Accounts = make(map[string]*model.Account)
|
||||||
|
}
|
||||||
|
_, ok = domain.Accounts[a.Local]
|
||||||
|
if ok {
|
||||||
|
return errors.New("exists already")
|
||||||
|
}
|
||||||
|
domain.Accounts[a.Local] = a
|
||||||
|
a.Domain = d
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("no give domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) Authenticate(jid *model.JID, password string) (bool, error) {
|
||||||
|
logger := log.WithField("database", "auth")
|
||||||
|
|
||||||
|
if domain, ok := s.Domains[jid.Domain]; ok {
|
||||||
|
if acc, ok := domain.Accounts[jid.Local]; ok {
|
||||||
|
if acc.ValidatePassword(password) {
|
||||||
|
return true, nil
|
||||||
|
} else {
|
||||||
|
logger.Debug("password not valid")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("account not found")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("domain not found")
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
2
main.go
2
main.go
|
@ -1,6 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "dev.sum7.eu/genofire/yaja/cmd"
|
import "github.com/genofire/yaja/cmd"
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cmd.Execute()
|
cmd.Execute()
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import "encoding/xml"
|
||||||
|
|
||||||
|
// Error element
|
||||||
|
type Error struct {
|
||||||
|
XMLName xml.Name `xml:"jabber:client error"`
|
||||||
|
Code string `xml:"code,attr"`
|
||||||
|
Type string `xml:"type,attr"`
|
||||||
|
Any xml.Name `xml:",any"`
|
||||||
|
Text string `xml:"text"`
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import "encoding/xml"
|
||||||
|
|
||||||
|
type IQType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
IQTypeGet IQType = "get"
|
||||||
|
IQTypeSet IQType = "set"
|
||||||
|
IQTypeResult IQType = "result"
|
||||||
|
IQTypeError IQType = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IQ element - info/query
|
||||||
|
type IQ struct {
|
||||||
|
XMLName xml.Name `xml:"jabber:client iq"`
|
||||||
|
From string `xml:"from,attr"`
|
||||||
|
ID string `xml:"id,attr"`
|
||||||
|
To string `xml:"to,attr"`
|
||||||
|
Type IQType `xml:"type,attr"`
|
||||||
|
Error *Error `xml:"error"`
|
||||||
|
//Bind bindBind `xml:"bind"`
|
||||||
|
Body []byte `xml:",innerxml"`
|
||||||
|
// RosterRequest - better detection of iq's
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
package messages
|
||||||
|
|
||||||
|
const (
|
||||||
|
// NSStream stream namesapce
|
||||||
|
NSStream = "http://etherx.jabber.org/streams"
|
||||||
|
// NSTLS xmpp-tls xml namespace
|
||||||
|
NSTLS = "urn:ietf:params:xml:ns:xmpp-tls"
|
||||||
|
// NSSASL xmpp-sasl xml namespace
|
||||||
|
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
|
||||||
|
|
||||||
|
NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
|
||||||
|
|
||||||
|
// NSClient jabbet client namespace
|
||||||
|
NSClient = "jabber:client"
|
||||||
|
|
||||||
|
NSIQRegister = "jabber:iq:register"
|
||||||
|
|
||||||
|
NSFeaturesIQRegister = "http://jabber.org/features/iq-register"
|
||||||
|
)
|
|
@ -0,0 +1,32 @@
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import "encoding/xml"
|
||||||
|
|
||||||
|
type PresenceType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PresenceTypeUnavailable PresenceType = "unavailable"
|
||||||
|
PresenceTypeSubscribe PresenceType = "subscribe"
|
||||||
|
PresenceTypeSubscribed PresenceType = "subscribed"
|
||||||
|
PresenceTypeUnsubscribe PresenceType = "unsubscribe"
|
||||||
|
PresenceTypeUnsubscribed PresenceType = "unsubscribed"
|
||||||
|
PresenceTypeProbe PresenceType = "probe"
|
||||||
|
PresenceTypeError PresenceType = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Presence element
|
||||||
|
type Presence struct {
|
||||||
|
XMLName xml.Name `xml:"jabber:client presence"`
|
||||||
|
From string `xml:"from,attr,omitempty"`
|
||||||
|
ID string `xml:"id,attr,omitempty"`
|
||||||
|
To string `xml:"to,attr,omitempty"`
|
||||||
|
Type string `xml:"type,attr,omitempty"`
|
||||||
|
Lang string `xml:"lang,attr,omitempty"`
|
||||||
|
|
||||||
|
Show string `xml:"show,omitempty"` // away, chat, dnd, xa
|
||||||
|
Status string `xml:"status,omitempty"` // sb []clientText
|
||||||
|
Priority string `xml:"priority,omitempty"`
|
||||||
|
// Caps *ClientCaps `xml:"c"`
|
||||||
|
Error *Error `xml:"error"`
|
||||||
|
// Delay Delay `xml:"delay"`
|
||||||
|
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
package messages
|
||||||
|
|
||||||
|
import "encoding/xml"
|
||||||
|
|
||||||
|
// SASLAuth element
|
||||||
|
type SASLAuth struct {
|
||||||
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl auth"`
|
||||||
|
Mechanism string `xml:"mechanism,attr"`
|
||||||
|
Body string `xml:",chardata"`
|
||||||
|
}
|
|
@ -6,8 +6,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Domain struct {
|
type Domain struct {
|
||||||
FQDN string
|
FQDN string `json:"-"`
|
||||||
Accounts map[string]*Account
|
Accounts map[string]*Account `json:"users"`
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,8 +29,22 @@ func (d *Domain) UpdateAccount(a *Account) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Account struct {
|
type Account struct {
|
||||||
Local string
|
Local string `json:"-"`
|
||||||
Domain *Domain
|
Domain *Domain `json:"-"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAccount(jid *JID, password string) *Account {
|
||||||
|
if jid == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &Account{
|
||||||
|
Local: jid.Local,
|
||||||
|
Domain: &Domain{
|
||||||
|
FQDN: jid.Domain,
|
||||||
|
},
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) GetJID() *JID {
|
func (a *Account) GetJID() *JID {
|
||||||
|
@ -39,3 +53,7 @@ func (a *Account) GetJID() *JID {
|
||||||
Local: a.Local,
|
Local: a.Local,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) ValidatePassword(password string) bool {
|
||||||
|
return a.Password == password
|
||||||
|
}
|
||||||
|
|
|
@ -1,24 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io/ioutil"
|
|
||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ReadConfigFile reads a config model from path of a yml file
|
|
||||||
func ReadConfigFile(path string) (config *Config, err error) {
|
|
||||||
config = &Config{}
|
|
||||||
|
|
||||||
file, err := ioutil.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = toml.Unmarshal(file, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestReadConfig(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
config, err := ReadConfigFile("../../config_example.conf")
|
|
||||||
assert.NoError(err)
|
|
||||||
assert.NotNil(config)
|
|
||||||
|
|
||||||
assert.Equal("/tmp/ssl", config.TLSDir)
|
|
||||||
|
|
||||||
config, err = ReadConfigFile("../config_example.co")
|
|
||||||
assert.Nil(config)
|
|
||||||
assert.Contains(err.Error(), "no such file or directory")
|
|
||||||
|
|
||||||
config, err = ReadConfigFile("testdata/config_panic.conf")
|
|
||||||
assert.Nil(config)
|
|
||||||
assert.Contains(err.Error(), "keys cannot contain")
|
|
||||||
}
|
|
|
@ -3,6 +3,8 @@ package config
|
||||||
type Config struct {
|
type Config struct {
|
||||||
TLSDir string `toml:"tlsdir"`
|
TLSDir string `toml:"tlsdir"`
|
||||||
StatePath string `toml:"state_path"`
|
StatePath string `toml:"state_path"`
|
||||||
PortClient int `toml:"port_client"`
|
Address struct {
|
||||||
PortServer int `toml:"port_server"`
|
Client []string `toml:"client"`
|
||||||
|
Server []string `toml:"server"`
|
||||||
|
} `toml:"address"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SaveJSON to path
|
|
||||||
func SaveJSON(input interface{}, outputFile string) {
|
|
||||||
tmpFile := outputFile + ".tmp"
|
|
||||||
|
|
||||||
f, err := os.OpenFile(tmpFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
|
||||||
if err != nil {
|
|
||||||
log.Panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = json.NewEncoder(f).Encode(input)
|
|
||||||
if err != nil {
|
|
||||||
log.Panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.Close()
|
|
||||||
if err := os.Rename(tmpFile, outputFile); err != nil {
|
|
||||||
log.Panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -42,6 +42,8 @@ func (jid *JID) Bare() string {
|
||||||
return jid.Domain
|
return jid.Domain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (jid *JID) String() string { return jid.Bare() }
|
||||||
|
|
||||||
// Full get the "full" jid as string
|
// Full get the "full" jid as string
|
||||||
func (jid *JID) Full() string {
|
func (jid *JID) Full() string {
|
||||||
if jid.Resource != "" {
|
if jid.Resource != "" {
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type State struct {
|
|
||||||
Domains map[string]*Domain
|
|
||||||
sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadState(path string) (state *State, err error) {
|
|
||||||
state = &State{}
|
|
||||||
|
|
||||||
if f, err := os.Open(path); err == nil { // transform data to legacy meshviewer
|
|
||||||
if err = json.NewDecoder(f).Decode(state); err == nil {
|
|
||||||
return state, nil
|
|
||||||
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *State) UpdateDomain(d *Domain) error {
|
|
||||||
if d.FQDN == "" {
|
|
||||||
return errors.New("No fqdn exists in domain")
|
|
||||||
}
|
|
||||||
s.Lock()
|
|
||||||
s.Domains[d.FQDN] = d
|
|
||||||
s.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
log *log.Entry
|
||||||
|
|
||||||
|
Conn net.Conn
|
||||||
|
out *xml.Encoder
|
||||||
|
in *xml.Decoder
|
||||||
|
|
||||||
|
Server *Server
|
||||||
|
jid *model.JID
|
||||||
|
account *model.Account
|
||||||
|
|
||||||
|
messages chan interface{}
|
||||||
|
close chan interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(conn net.Conn, srv *Server) *Client {
|
||||||
|
logger := log.New()
|
||||||
|
logger.SetLevel(log.DebugLevel)
|
||||||
|
client := &Client{
|
||||||
|
Conn: conn,
|
||||||
|
Server: srv,
|
||||||
|
log: log.NewEntry(logger),
|
||||||
|
in: xml.NewDecoder(conn),
|
||||||
|
out: xml.NewEncoder(conn),
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) NewConnecting(conn net.Conn) {
|
||||||
|
client.Conn = conn
|
||||||
|
client.in = xml.NewDecoder(conn)
|
||||||
|
client.out = xml.NewEncoder(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) Read() (*xml.StartElement, error) {
|
||||||
|
for {
|
||||||
|
nextToken, err := client.in.Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch nextToken.(type) {
|
||||||
|
case xml.StartElement:
|
||||||
|
element := nextToken.(xml.StartElement)
|
||||||
|
return &element, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) Close() {
|
||||||
|
client.close <- true
|
||||||
|
client.Conn.Close()
|
||||||
|
}
|
|
@ -2,19 +2,82 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
|
||||||
"dev.sum7.eu/genofire/yaja/model"
|
"github.com/genofire/yaja/database"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
PortClient int
|
TLSManager *autocert.Manager
|
||||||
PortServer int
|
ClientAddr []string
|
||||||
State *model.State
|
ServerAddr []string
|
||||||
|
Database *database.State
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) Start() {
|
func (srv *Server) Start() {
|
||||||
|
for _, addr := range srv.ServerAddr {
|
||||||
|
socket, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("create server socket: ", err.Error())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
go srv.listenServer(socket)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range srv.ClientAddr {
|
||||||
|
socket, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("create client socket: ", err.Error())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
go srv.listenClient(socket)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) listenServer(s2s net.Listener) {
|
||||||
|
for {
|
||||||
|
conn, err := s2s.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("accepting server connection: ", err.Error())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
go srv.handleServer(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) listenClient(c2s net.Listener) {
|
||||||
|
for {
|
||||||
|
conn, err := c2s.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("accepting client connection: ", err.Error())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
go srv.handleClient(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) handleServer(conn net.Conn) {
|
||||||
|
log.Info("new server connection:", conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) handleClient(conn net.Conn) {
|
||||||
|
log.Info("new client connection:", conn.RemoteAddr())
|
||||||
|
client := NewClient(conn, srv)
|
||||||
|
state := ConnectionStartup()
|
||||||
|
|
||||||
|
for {
|
||||||
|
state, client = state.Process(client)
|
||||||
|
if state == nil {
|
||||||
|
client.log.Info("disconnect")
|
||||||
|
client.Close()
|
||||||
|
//s.DisconnectBus <- Disconnect{Jid: client.jid}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// run next state
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) Close() {
|
func (srv *Server) Close() {
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
// State processes the stream and moves to the next state
|
||||||
|
type State interface {
|
||||||
|
Process(client *Client) (State, *Client)
|
||||||
|
}
|
|
@ -0,0 +1,344 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/messages"
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionStartup return steps through TCP TLS state
|
||||||
|
func ConnectionStartup() State {
|
||||||
|
receiving := &ReceivingClient{}
|
||||||
|
sending := &SendingClient{Next: receiving}
|
||||||
|
authedstream := &AuthedStream{Next: sending}
|
||||||
|
authedstart := &AuthedStart{Next: authedstream}
|
||||||
|
tlsauth := &SASLAuth{Next: authedstart}
|
||||||
|
tlsstream := &TLSStream{Next: tlsauth}
|
||||||
|
tlsupgrade := &TLSUpgrade{Next: tlsstream}
|
||||||
|
stream := &Start{Next: tlsupgrade}
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start state
|
||||||
|
type Start struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process message
|
||||||
|
func (state *Start) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "stream")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||||
|
client.log.Warn("is no stream")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
for _, attr := range element.Attr {
|
||||||
|
if attr.Name.Local == "to" {
|
||||||
|
client.jid = &model.JID{Domain: attr.Value}
|
||||||
|
client.log = client.log.WithField("jid", client.jid.Full())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if client.jid == nil {
|
||||||
|
client.log.Warn("no 'to' domain readed")
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||||
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
|
createCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<stream:features>
|
||||||
|
<starttls xmlns='%s'>
|
||||||
|
<required/>
|
||||||
|
</starttls>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSStream)
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSUpgrade state
|
||||||
|
type TLSUpgrade struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process message
|
||||||
|
func (state *TLSUpgrade) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "tls upgrade")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
|
||||||
|
client.log.Warn("is no starttls")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
|
||||||
|
// perform the TLS handshake
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
if m := client.Server.TLSManager; m != nil {
|
||||||
|
var cert *tls.Certificate
|
||||||
|
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.jid.Domain})
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("no cert in tls manger found: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*cert},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tlsConfig == nil {
|
||||||
|
tlsConfig = client.Server.TLSConfig
|
||||||
|
if tlsConfig != nil {
|
||||||
|
tlsConfig.ServerName = client.jid.Domain
|
||||||
|
} else {
|
||||||
|
client.log.Warn("no tls config found: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Server(client.Conn, tlsConfig)
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to tls handshake: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
// restart the Connection
|
||||||
|
client.NewConnecting(tlsConn)
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSStream state
|
||||||
|
type TLSStream struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *TLSStream) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "tls stream")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||||
|
client.log.Warn("is no stream")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||||
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
|
createCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<stream:features>
|
||||||
|
<mechanisms xmlns='%s'>
|
||||||
|
<mechanism>PLAIN</mechanism>
|
||||||
|
</mechanisms>
|
||||||
|
<register xmlns='%s'/>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSSASL, messages.NSFeaturesIQRegister)
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// SASLAuth state
|
||||||
|
type SASLAuth struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *SASLAuth) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "sasl auth")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
// read the full auth stanza
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
var auth messages.SASLAuth
|
||||||
|
if err = client.in.DecodeElement(&auth, element); err != nil {
|
||||||
|
client.log.Info("start substate for registration")
|
||||||
|
return &RegisterFormRequest{
|
||||||
|
element: element,
|
||||||
|
Next: &RegisterRequest{
|
||||||
|
Next: state.Next,
|
||||||
|
},
|
||||||
|
}, client
|
||||||
|
}
|
||||||
|
data, err := base64.StdEncoding.DecodeString(auth.Body)
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("body decode: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
info := strings.Split(string(data), "\x00")
|
||||||
|
// should check that info[1] starts with client.jid
|
||||||
|
client.jid.Local = info[1]
|
||||||
|
client.log = client.log.WithField("jid", client.jid.Full())
|
||||||
|
success, err := client.Server.Database.Authenticate(client.jid, info[2])
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("auth: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if success {
|
||||||
|
client.log.Info("success auth")
|
||||||
|
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
client.log.Warn("failed auth")
|
||||||
|
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
|
||||||
|
return nil, client
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthedStart state
|
||||||
|
type AuthedStart struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *AuthedStart) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "authed started")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
_, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||||
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
|
createCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<stream:features>
|
||||||
|
<bind xmlns='%s'/>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSBind)
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthedStream state
|
||||||
|
type AuthedStream struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *AuthedStream) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "authed stream")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
// check that it's a bind request
|
||||||
|
// read bind request
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
var msg messages.IQ
|
||||||
|
if err = client.in.DecodeElement(&msg, element); err != nil {
|
||||||
|
client.log.Warn("is no iq: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if msg.Type != messages.IQTypeSet {
|
||||||
|
client.log.Warn("is no set iq")
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if msg.Error != nil {
|
||||||
|
client.log.Warn("iq with error: ", msg.Error.Code)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
type query struct {
|
||||||
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
|
||||||
|
Resource string `xml:"resource"`
|
||||||
|
}
|
||||||
|
q := &query{}
|
||||||
|
err = xml.Unmarshal(msg.Body, q)
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("is no iq bind: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if q.Resource == "" {
|
||||||
|
client.jid.Resource = makeResource()
|
||||||
|
} else {
|
||||||
|
client.jid.Resource = q.Resource
|
||||||
|
}
|
||||||
|
client.log = client.log.WithField("jid", client.jid.Full())
|
||||||
|
client.out.Encode(&messages.IQ{
|
||||||
|
Type: messages.IQTypeResult,
|
||||||
|
ID: msg.ID,
|
||||||
|
Body: []byte(fmt.Sprintf(
|
||||||
|
`<bind xmlns='%s'>
|
||||||
|
<jid>%s</jid>
|
||||||
|
</bind>`,
|
||||||
|
messages.NSBind, client.jid.Full())),
|
||||||
|
})
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendingClient state
|
||||||
|
type SendingClient struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *SendingClient) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "normal")
|
||||||
|
client.log.Debug("sending")
|
||||||
|
// sending
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case msg := <-client.messages:
|
||||||
|
err := client.out.Encode(msg)
|
||||||
|
client.log.Info(err)
|
||||||
|
case <-client.close:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
client.log.Debug("receiving")
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceivingClient state
|
||||||
|
type ReceivingClient struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *ReceivingClient) Process(client *Client) (State, *Client) {
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
for _, extension := range client.Server.Extensions {
|
||||||
|
extension.Process(element, client)
|
||||||
|
}*/
|
||||||
|
client.log.Debug(element)
|
||||||
|
return state, client
|
||||||
|
}
|
|
@ -0,0 +1,130 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/messages"
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RegisterFormRequest struct {
|
||||||
|
Next State
|
||||||
|
element *xml.StartElement
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process message
|
||||||
|
func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "register form request")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
var msg messages.IQ
|
||||||
|
if err := client.in.DecodeElement(&msg, state.element); err != nil {
|
||||||
|
client.log.Warn("is no iq: ", err)
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
if msg.Type != messages.IQTypeGet {
|
||||||
|
client.log.Warn("is no get iq")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
if msg.Error != nil {
|
||||||
|
client.log.Warn("iq with error: ", msg.Error.Code)
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
type query struct {
|
||||||
|
XMLName xml.Name `xml:"query"`
|
||||||
|
}
|
||||||
|
q := &query{}
|
||||||
|
err := xml.Unmarshal(msg.Body, q)
|
||||||
|
|
||||||
|
if q.XMLName.Space != messages.NSIQRegister || err != nil {
|
||||||
|
client.log.Warn("is no iq register: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
client.out.Encode(&messages.IQ{
|
||||||
|
Type: messages.IQTypeResult,
|
||||||
|
ID: msg.ID,
|
||||||
|
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
|
||||||
|
Choose a username and password for use with this service.
|
||||||
|
</instructions>
|
||||||
|
<username/>
|
||||||
|
<password/>
|
||||||
|
</query>`, messages.NSIQRegister)),
|
||||||
|
})
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegisterRequest struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process message
|
||||||
|
func (state *RegisterRequest) Process(client *Client) (State, *Client) {
|
||||||
|
client.log = client.log.WithField("state", "register request")
|
||||||
|
client.log.Debug("running")
|
||||||
|
defer client.log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
var msg messages.IQ
|
||||||
|
if err = client.in.DecodeElement(&msg, element); err != nil {
|
||||||
|
client.log.Warn("is no iq: ", err)
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
if msg.Type != messages.IQTypeGet {
|
||||||
|
client.log.Warn("is no get iq")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
if msg.Error != nil {
|
||||||
|
client.log.Warn("iq with error: ", msg.Error.Code)
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
type query struct {
|
||||||
|
XMLName xml.Name `xml:"query"`
|
||||||
|
Username string `xml:"username"`
|
||||||
|
Password string `xml:"password"`
|
||||||
|
}
|
||||||
|
q := &query{}
|
||||||
|
err = xml.Unmarshal(msg.Body, q)
|
||||||
|
if err != nil {
|
||||||
|
client.log.Warn("is no iq register: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
|
||||||
|
client.jid.Local = q.Username
|
||||||
|
client.log = client.log.WithField("jid", client.jid.Full())
|
||||||
|
account := model.NewAccount(client.jid, q.Password)
|
||||||
|
err = client.Server.Database.AddAccount(account)
|
||||||
|
if err != nil {
|
||||||
|
client.out.Encode(&messages.IQ{
|
||||||
|
Type: messages.IQTypeResult,
|
||||||
|
ID: msg.ID,
|
||||||
|
Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
|
||||||
|
<username>%s</username>
|
||||||
|
<password>%s</password>
|
||||||
|
</query>`, messages.NSIQRegister, q.Username, q.Password)),
|
||||||
|
Error: &messages.Error{
|
||||||
|
Code: "409",
|
||||||
|
Type: "cancel",
|
||||||
|
Any: xml.Name{
|
||||||
|
Local: "conflict",
|
||||||
|
Space: "urn:ietf:params:xml:ns:xmpp-stanzas",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
client.log.Warn("database error: ", err)
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
client.account = account
|
||||||
|
client.out.Encode(&messages.IQ{
|
||||||
|
Type: messages.IQTypeResult,
|
||||||
|
ID: msg.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
client.log.Infof("registered client %s", client.jid.Bare())
|
||||||
|
return state.Next, client
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cookie is used to give a unique identifier to each request.
|
||||||
|
type Cookie uint64
|
||||||
|
|
||||||
|
func createCookie() Cookie {
|
||||||
|
var buf [8]byte
|
||||||
|
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
||||||
|
panic("Failed to read random bytes: " + err.Error())
|
||||||
|
}
|
||||||
|
return Cookie(binary.LittleEndian.Uint64(buf[:]))
|
||||||
|
}
|
||||||
|
func createCookieString() string {
|
||||||
|
return fmt.Sprintf("%x", createCookie())
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeResource() string {
|
||||||
|
var buf [16]byte
|
||||||
|
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
||||||
|
panic("Failed to read random bytes: " + err.Error())
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%x", buf)
|
||||||
|
}
|
Reference in New Issue