lets encrypt + registration
This commit is contained in:
parent
8c60ef89c6
commit
800a5b1917
|
@ -1,6 +1,6 @@
|
|||
workspace:
|
||||
base: /go
|
||||
path: src/dev.sum7.eu/genofire/yaja
|
||||
path: src/github.com/genofire/yaja
|
||||
|
||||
pipeline:
|
||||
build:
|
||||
|
|
25
README.md
25
README.md
|
@ -1,3 +1,24 @@
|
|||
# yaja
|
||||
# yaja (yet another jabber server)
|
||||
|
||||
Yet Another Jabber Server
|
||||
```
|
||||
A small standalone jabber server, for easy deployment
|
||||
|
||||
Usage:
|
||||
yaja [command]
|
||||
|
||||
Available Commands:
|
||||
help Help about any command
|
||||
server Runs the yaja server
|
||||
|
||||
Flags:
|
||||
-h, --help help for yaja
|
||||
|
||||
Use "yaja [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Features (works already)
|
||||
- get certificate by lets encrypt
|
||||
- registration (for every possible ssl domain)
|
||||
|
||||
## Inspiration
|
||||
- [tam7t](https://github.com/tam7t/xmpp) a fork of [agl](https://github.com/agl)'s work
|
||||
|
|
|
@ -2,6 +2,7 @@ package cmd
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
@ -9,11 +10,12 @@ import (
|
|||
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
|
||||
"dev.sum7.eu/genofire/yaja/model"
|
||||
"dev.sum7.eu/genofire/yaja/model/config"
|
||||
"github.com/genofire/yaja/database"
|
||||
"github.com/genofire/yaja/model/config"
|
||||
|
||||
"dev.sum7.eu/genofire/yaja/server"
|
||||
"github.com/genofire/golang-lib/file"
|
||||
"github.com/genofire/golang-lib/worker"
|
||||
"github.com/genofire/yaja/server"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -22,32 +24,34 @@ import (
|
|||
var configPath string
|
||||
|
||||
var (
|
||||
configData *config.Config
|
||||
state *model.State
|
||||
configData = &config.Config{}
|
||||
db = &database.State{}
|
||||
statesaveWorker *worker.Worker
|
||||
srv *server.Server
|
||||
certs *tls.Config
|
||||
)
|
||||
|
||||
// serveCmd represents the serve command
|
||||
var serveCmd = &cobra.Command{
|
||||
Use: "serve",
|
||||
// serverCmd represents the serve command
|
||||
var serverCmd = &cobra.Command{
|
||||
Use: "server",
|
||||
Short: "Runs the yaja server",
|
||||
Example: "yaja serve -c /etc/yaja.conf",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
configData, err = config.ReadConfigFile(configPath)
|
||||
err = file.ReadTOML(configPath, configData)
|
||||
if err != nil {
|
||||
log.Fatal("unable to load config file:", err)
|
||||
}
|
||||
|
||||
state, err = model.ReadState(configData.StatePath)
|
||||
log.SetLevel(log.DebugLevel)
|
||||
|
||||
err = file.ReadJSON(configData.StatePath, db)
|
||||
if err != nil {
|
||||
log.Warn("unable to load state file:", err)
|
||||
}
|
||||
|
||||
statesaveWorker = worker.NewWorker(time.Minute, func() {
|
||||
model.SaveJSON(state, configData.StatePath)
|
||||
file.SaveJSON(configData.StatePath, db)
|
||||
log.Info("save state to:", configData.StatePath)
|
||||
})
|
||||
|
||||
|
@ -56,13 +60,18 @@ var serveCmd = &cobra.Command{
|
|||
Prompt: autocert.AcceptTOS,
|
||||
}
|
||||
|
||||
certs = &tls.Config{GetCertificate: m.GetCertificate}
|
||||
// https server to handle acme (by letsencrypt)
|
||||
httpServer := &http.Server{
|
||||
Addr: ":https",
|
||||
TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
|
||||
}
|
||||
go httpServer.ListenAndServeTLS("", "")
|
||||
|
||||
srv = &server.Server{
|
||||
TLSConfig: certs,
|
||||
State: state,
|
||||
PortClient: configData.PortClient,
|
||||
PortServer: configData.PortServer,
|
||||
TLSManager: &m,
|
||||
Database: db,
|
||||
ClientAddr: configData.Address.Client,
|
||||
ServerAddr: configData.Address.Server,
|
||||
}
|
||||
|
||||
go statesaveWorker.Start()
|
||||
|
@ -95,21 +104,24 @@ func quit() {
|
|||
srv.Close()
|
||||
statesaveWorker.Close()
|
||||
|
||||
model.SaveJSON(state, configData.StatePath)
|
||||
file.SaveJSON(configData.StatePath, db)
|
||||
}
|
||||
|
||||
func reload() {
|
||||
log.Info("start reloading...")
|
||||
configNewData, err := config.ReadConfigFile(configPath)
|
||||
var configNewData *config.Config
|
||||
err := file.ReadTOML(configPath, configNewData)
|
||||
if err != nil {
|
||||
log.Warn("unable to load config file:", err)
|
||||
return
|
||||
}
|
||||
|
||||
//TODO fetch changing address (to set restart)
|
||||
|
||||
if configNewData.StatePath != configData.StatePath {
|
||||
statesaveWorker.Close()
|
||||
statesaveWorker := worker.NewWorker(time.Minute, func() {
|
||||
model.SaveJSON(state, configNewData.StatePath)
|
||||
file.SaveJSON(configNewData.StatePath, db)
|
||||
log.Info("save state to:", configNewData.StatePath)
|
||||
})
|
||||
go statesaveWorker.Start()
|
||||
|
@ -130,17 +142,11 @@ func reload() {
|
|||
|
||||
newServer := &server.Server{
|
||||
TLSConfig: certs,
|
||||
State: state,
|
||||
PortClient: configNewData.PortClient,
|
||||
PortServer: configNewData.PortServer,
|
||||
Database: db,
|
||||
ClientAddr: configNewData.Address.Client,
|
||||
ServerAddr: configNewData.Address.Server,
|
||||
}
|
||||
|
||||
if configNewData.PortServer != configData.PortServer {
|
||||
restartServer = true
|
||||
}
|
||||
if configNewData.PortClient != configData.PortClient {
|
||||
restartServer = true
|
||||
}
|
||||
if restartServer {
|
||||
go srv.Start()
|
||||
//TODO should fetch new server error
|
||||
|
@ -153,6 +159,6 @@ func reload() {
|
|||
}
|
||||
|
||||
func init() {
|
||||
RootCmd.AddCommand(serveCmd)
|
||||
serveCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
||||
RootCmd.AddCommand(serverCmd)
|
||||
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
||||
}
|
|
@ -1,4 +1,6 @@
|
|||
tlsdir = "/tmp/ssl"
|
||||
state_path = "/tmp/yaja.json"
|
||||
port_client = 5222
|
||||
port_server = 5269
|
||||
|
||||
[address]
|
||||
client = [":5222"]
|
||||
server = [":5269"]
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/genofire/yaja/model"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
Domains map[string]*model.Domain `json:"domains"`
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (s *State) AddAccount(a *model.Account) error {
|
||||
if a.Local == "" {
|
||||
return errors.New("No localpart exists in account")
|
||||
}
|
||||
if d := a.Domain; d != nil {
|
||||
if d.FQDN == "" {
|
||||
return errors.New("No fqdn exists in domain")
|
||||
}
|
||||
s.Lock()
|
||||
domain, ok := s.Domains[d.FQDN]
|
||||
if !ok {
|
||||
if s.Domains == nil {
|
||||
s.Domains = make(map[string]*model.Domain)
|
||||
}
|
||||
s.Domains[d.FQDN] = d
|
||||
domain = d
|
||||
}
|
||||
s.Unlock()
|
||||
|
||||
domain.Lock()
|
||||
defer domain.Unlock()
|
||||
if domain.Accounts == nil {
|
||||
domain.Accounts = make(map[string]*model.Account)
|
||||
}
|
||||
_, ok = domain.Accounts[a.Local]
|
||||
if ok {
|
||||
return errors.New("exists already")
|
||||
}
|
||||
domain.Accounts[a.Local] = a
|
||||
a.Domain = d
|
||||
return nil
|
||||
}
|
||||
return errors.New("no give domain")
|
||||
}
|
||||
|
||||
func (s *State) Authenticate(jid *model.JID, password string) (bool, error) {
|
||||
logger := log.WithField("database", "auth")
|
||||
|
||||
if domain, ok := s.Domains[jid.Domain]; ok {
|
||||
if acc, ok := domain.Accounts[jid.Local]; ok {
|
||||
if acc.ValidatePassword(password) {
|
||||
return true, nil
|
||||
} else {
|
||||
logger.Debug("password not valid")
|
||||
}
|
||||
} else {
|
||||
logger.Debug("account not found")
|
||||
}
|
||||
} else {
|
||||
logger.Debug("domain not found")
|
||||
}
|
||||
return false, nil
|
||||
}
|
2
main.go
2
main.go
|
@ -1,6 +1,6 @@
|
|||
package main
|
||||
|
||||
import "dev.sum7.eu/genofire/yaja/cmd"
|
||||
import "github.com/genofire/yaja/cmd"
|
||||
|
||||
func main() {
|
||||
cmd.Execute()
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
package messages
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
// Error element
|
||||
type Error struct {
|
||||
XMLName xml.Name `xml:"jabber:client error"`
|
||||
Code string `xml:"code,attr"`
|
||||
Type string `xml:"type,attr"`
|
||||
Any xml.Name `xml:",any"`
|
||||
Text string `xml:"text"`
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
package messages
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
type IQType string
|
||||
|
||||
const (
|
||||
IQTypeGet IQType = "get"
|
||||
IQTypeSet IQType = "set"
|
||||
IQTypeResult IQType = "result"
|
||||
IQTypeError IQType = "error"
|
||||
)
|
||||
|
||||
// IQ element - info/query
|
||||
type IQ struct {
|
||||
XMLName xml.Name `xml:"jabber:client iq"`
|
||||
From string `xml:"from,attr"`
|
||||
ID string `xml:"id,attr"`
|
||||
To string `xml:"to,attr"`
|
||||
Type IQType `xml:"type,attr"`
|
||||
Error *Error `xml:"error"`
|
||||
//Bind bindBind `xml:"bind"`
|
||||
Body []byte `xml:",innerxml"`
|
||||
// RosterRequest - better detection of iq's
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package messages
|
||||
|
||||
const (
|
||||
// NSStream stream namesapce
|
||||
NSStream = "http://etherx.jabber.org/streams"
|
||||
// NSTLS xmpp-tls xml namespace
|
||||
NSTLS = "urn:ietf:params:xml:ns:xmpp-tls"
|
||||
// NSSASL xmpp-sasl xml namespace
|
||||
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
|
||||
|
||||
NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
|
||||
|
||||
// NSClient jabbet client namespace
|
||||
NSClient = "jabber:client"
|
||||
|
||||
NSIQRegister = "jabber:iq:register"
|
||||
|
||||
NSFeaturesIQRegister = "http://jabber.org/features/iq-register"
|
||||
)
|
|
@ -0,0 +1,32 @@
|
|||
package messages
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
type PresenceType string
|
||||
|
||||
const (
|
||||
PresenceTypeUnavailable PresenceType = "unavailable"
|
||||
PresenceTypeSubscribe PresenceType = "subscribe"
|
||||
PresenceTypeSubscribed PresenceType = "subscribed"
|
||||
PresenceTypeUnsubscribe PresenceType = "unsubscribe"
|
||||
PresenceTypeUnsubscribed PresenceType = "unsubscribed"
|
||||
PresenceTypeProbe PresenceType = "probe"
|
||||
PresenceTypeError PresenceType = "error"
|
||||
)
|
||||
|
||||
// Presence element
|
||||
type Presence struct {
|
||||
XMLName xml.Name `xml:"jabber:client presence"`
|
||||
From string `xml:"from,attr,omitempty"`
|
||||
ID string `xml:"id,attr,omitempty"`
|
||||
To string `xml:"to,attr,omitempty"`
|
||||
Type string `xml:"type,attr,omitempty"`
|
||||
Lang string `xml:"lang,attr,omitempty"`
|
||||
|
||||
Show string `xml:"show,omitempty"` // away, chat, dnd, xa
|
||||
Status string `xml:"status,omitempty"` // sb []clientText
|
||||
Priority string `xml:"priority,omitempty"`
|
||||
// Caps *ClientCaps `xml:"c"`
|
||||
Error *Error `xml:"error"`
|
||||
// Delay Delay `xml:"delay"`
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
package messages
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
// SASLAuth element
|
||||
type SASLAuth struct {
|
||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl auth"`
|
||||
Mechanism string `xml:"mechanism,attr"`
|
||||
Body string `xml:",chardata"`
|
||||
}
|
|
@ -6,8 +6,8 @@ import (
|
|||
)
|
||||
|
||||
type Domain struct {
|
||||
FQDN string
|
||||
Accounts map[string]*Account
|
||||
FQDN string `json:"-"`
|
||||
Accounts map[string]*Account `json:"users"`
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -29,8 +29,22 @@ func (d *Domain) UpdateAccount(a *Account) error {
|
|||
}
|
||||
|
||||
type Account struct {
|
||||
Local string
|
||||
Domain *Domain
|
||||
Local string `json:"-"`
|
||||
Domain *Domain `json:"-"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func NewAccount(jid *JID, password string) *Account {
|
||||
if jid == nil {
|
||||
return nil
|
||||
}
|
||||
return &Account{
|
||||
Local: jid.Local,
|
||||
Domain: &Domain{
|
||||
FQDN: jid.Domain,
|
||||
},
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) GetJID() *JID {
|
||||
|
@ -39,3 +53,7 @@ func (a *Account) GetJID() *JID {
|
|||
Local: a.Local,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) ValidatePassword(password string) bool {
|
||||
return a.Password == password
|
||||
}
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
// ReadConfigFile reads a config model from path of a yml file
|
||||
func ReadConfigFile(path string) (config *Config, err error) {
|
||||
config = &Config{}
|
||||
|
||||
file, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = toml.Unmarshal(file, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReadConfig(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
config, err := ReadConfigFile("../../config_example.conf")
|
||||
assert.NoError(err)
|
||||
assert.NotNil(config)
|
||||
|
||||
assert.Equal("/tmp/ssl", config.TLSDir)
|
||||
|
||||
config, err = ReadConfigFile("../config_example.co")
|
||||
assert.Nil(config)
|
||||
assert.Contains(err.Error(), "no such file or directory")
|
||||
|
||||
config, err = ReadConfigFile("testdata/config_panic.conf")
|
||||
assert.Nil(config)
|
||||
assert.Contains(err.Error(), "keys cannot contain")
|
||||
}
|
|
@ -1,8 +1,10 @@
|
|||
package config
|
||||
|
||||
type Config struct {
|
||||
TLSDir string `toml:"tlsdir"`
|
||||
StatePath string `toml:"state_path"`
|
||||
PortClient int `toml:"port_client"`
|
||||
PortServer int `toml:"port_server"`
|
||||
TLSDir string `toml:"tlsdir"`
|
||||
StatePath string `toml:"state_path"`
|
||||
Address struct {
|
||||
Client []string `toml:"client"`
|
||||
Server []string `toml:"server"`
|
||||
} `toml:"address"`
|
||||
}
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// SaveJSON to path
|
||||
func SaveJSON(input interface{}, outputFile string) {
|
||||
tmpFile := outputFile + ".tmp"
|
||||
|
||||
f, err := os.OpenFile(tmpFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
|
||||
err = json.NewEncoder(f).Encode(input)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
|
||||
f.Close()
|
||||
if err := os.Rename(tmpFile, outputFile); err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
}
|
|
@ -42,6 +42,8 @@ func (jid *JID) Bare() string {
|
|||
return jid.Domain
|
||||
}
|
||||
|
||||
func (jid *JID) String() string { return jid.Bare() }
|
||||
|
||||
// Full get the "full" jid as string
|
||||
func (jid *JID) Full() string {
|
||||
if jid.Resource != "" {
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
Domains map[string]*Domain
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func ReadState(path string) (state *State, err error) {
|
||||
state = &State{}
|
||||
|
||||
if f, err := os.Open(path); err == nil { // transform data to legacy meshviewer
|
||||
if err = json.NewDecoder(f).Decode(state); err == nil {
|
||||
return state, nil
|
||||
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) UpdateDomain(d *Domain) error {
|
||||
if d.FQDN == "" {
|
||||
return errors.New("No fqdn exists in domain")
|
||||
}
|
||||
s.Lock()
|
||||
s.Domains[d.FQDN] = d
|
||||
s.Unlock()
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"net"
|
||||
|
||||
"github.com/genofire/yaja/model"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
|
||||
Conn net.Conn
|
||||
out *xml.Encoder
|
||||
in *xml.Decoder
|
||||
|
||||
Server *Server
|
||||
jid *model.JID
|
||||
account *model.Account
|
||||
|
||||
messages chan interface{}
|
||||
close chan interface{}
|
||||
}
|
||||
|
||||
func NewClient(conn net.Conn, srv *Server) *Client {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.DebugLevel)
|
||||
client := &Client{
|
||||
Conn: conn,
|
||||
Server: srv,
|
||||
log: log.NewEntry(logger),
|
||||
in: xml.NewDecoder(conn),
|
||||
out: xml.NewEncoder(conn),
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func (client *Client) NewConnecting(conn net.Conn) {
|
||||
client.Conn = conn
|
||||
client.in = xml.NewDecoder(conn)
|
||||
client.out = xml.NewEncoder(conn)
|
||||
}
|
||||
|
||||
func (client *Client) Read() (*xml.StartElement, error) {
|
||||
for {
|
||||
nextToken, err := client.in.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch nextToken.(type) {
|
||||
case xml.StartElement:
|
||||
element := nextToken.(xml.StartElement)
|
||||
return &element, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) Close() {
|
||||
client.close <- true
|
||||
client.Conn.Close()
|
||||
}
|
|
@ -2,19 +2,82 @@ package server
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"dev.sum7.eu/genofire/yaja/model"
|
||||
"github.com/genofire/yaja/database"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
TLSConfig *tls.Config
|
||||
PortClient int
|
||||
PortServer int
|
||||
State *model.State
|
||||
TLSManager *autocert.Manager
|
||||
ClientAddr []string
|
||||
ServerAddr []string
|
||||
Database *database.State
|
||||
}
|
||||
|
||||
func (srv *Server) Start() {
|
||||
for _, addr := range srv.ServerAddr {
|
||||
socket, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
log.Warn("create server socket: ", err.Error())
|
||||
break
|
||||
}
|
||||
go srv.listenServer(socket)
|
||||
}
|
||||
|
||||
for _, addr := range srv.ClientAddr {
|
||||
socket, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
log.Warn("create client socket: ", err.Error())
|
||||
break
|
||||
}
|
||||
go srv.listenClient(socket)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) listenServer(s2s net.Listener) {
|
||||
for {
|
||||
conn, err := s2s.Accept()
|
||||
if err != nil {
|
||||
log.Warn("accepting server connection: ", err.Error())
|
||||
break
|
||||
}
|
||||
go srv.handleServer(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) listenClient(c2s net.Listener) {
|
||||
for {
|
||||
conn, err := c2s.Accept()
|
||||
if err != nil {
|
||||
log.Warn("accepting client connection: ", err.Error())
|
||||
break
|
||||
}
|
||||
go srv.handleClient(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) handleServer(conn net.Conn) {
|
||||
log.Info("new server connection:", conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (srv *Server) handleClient(conn net.Conn) {
|
||||
log.Info("new client connection:", conn.RemoteAddr())
|
||||
client := NewClient(conn, srv)
|
||||
state := ConnectionStartup()
|
||||
|
||||
for {
|
||||
state, client = state.Process(client)
|
||||
if state == nil {
|
||||
client.log.Info("disconnect")
|
||||
client.Close()
|
||||
//s.DisconnectBus <- Disconnect{Jid: client.jid}
|
||||
return
|
||||
}
|
||||
// run next state
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) Close() {
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
package server
|
||||
|
||||
// State processes the stream and moves to the next state
|
||||
type State interface {
|
||||
Process(client *Client) (State, *Client)
|
||||
}
|
|
@ -0,0 +1,344 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/genofire/yaja/messages"
|
||||
"github.com/genofire/yaja/model"
|
||||
)
|
||||
|
||||
// ConnectionStartup return steps through TCP TLS state
|
||||
func ConnectionStartup() State {
|
||||
receiving := &ReceivingClient{}
|
||||
sending := &SendingClient{Next: receiving}
|
||||
authedstream := &AuthedStream{Next: sending}
|
||||
authedstart := &AuthedStart{Next: authedstream}
|
||||
tlsauth := &SASLAuth{Next: authedstart}
|
||||
tlsstream := &TLSStream{Next: tlsauth}
|
||||
tlsupgrade := &TLSUpgrade{Next: tlsstream}
|
||||
stream := &Start{Next: tlsupgrade}
|
||||
return stream
|
||||
}
|
||||
|
||||
// Start state
|
||||
type Start struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process message
|
||||
func (state *Start) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "stream")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
element, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||
client.log.Warn("is no stream")
|
||||
return state, client
|
||||
}
|
||||
for _, attr := range element.Attr {
|
||||
if attr.Name.Local == "to" {
|
||||
client.jid = &model.JID{Domain: attr.Value}
|
||||
client.log = client.log.WithField("jid", client.jid.Full())
|
||||
}
|
||||
}
|
||||
if client.jid == nil {
|
||||
client.log.Warn("no 'to' domain readed")
|
||||
return nil, client
|
||||
}
|
||||
|
||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||
createCookie(), messages.NSClient, messages.NSStream)
|
||||
|
||||
fmt.Fprintf(client.Conn, `<stream:features>
|
||||
<starttls xmlns='%s'>
|
||||
<required/>
|
||||
</starttls>
|
||||
</stream:features>`,
|
||||
messages.NSStream)
|
||||
|
||||
return state.Next, client
|
||||
}
|
||||
|
||||
// TLSUpgrade state
|
||||
type TLSUpgrade struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process message
|
||||
func (state *TLSUpgrade) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "tls upgrade")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
element, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
|
||||
client.log.Warn("is no starttls")
|
||||
return state, client
|
||||
}
|
||||
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
|
||||
// perform the TLS handshake
|
||||
var tlsConfig *tls.Config
|
||||
if m := client.Server.TLSManager; m != nil {
|
||||
var cert *tls.Certificate
|
||||
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.jid.Domain})
|
||||
if err != nil {
|
||||
client.log.Warn("no cert in tls manger found: ", err)
|
||||
return nil, client
|
||||
}
|
||||
tlsConfig = &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}
|
||||
}
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = client.Server.TLSConfig
|
||||
if tlsConfig != nil {
|
||||
tlsConfig.ServerName = client.jid.Domain
|
||||
} else {
|
||||
client.log.Warn("no tls config found: ", err)
|
||||
return nil, client
|
||||
}
|
||||
}
|
||||
|
||||
tlsConn := tls.Server(client.Conn, tlsConfig)
|
||||
err = tlsConn.Handshake()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to tls handshake: ", err)
|
||||
return nil, client
|
||||
}
|
||||
// restart the Connection
|
||||
client.NewConnecting(tlsConn)
|
||||
|
||||
return state.Next, client
|
||||
}
|
||||
|
||||
// TLSStream state
|
||||
type TLSStream struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process messages
|
||||
func (state *TLSStream) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "tls stream")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
element, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||
client.log.Warn("is no stream")
|
||||
return state, client
|
||||
}
|
||||
|
||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||
createCookie(), messages.NSClient, messages.NSStream)
|
||||
|
||||
fmt.Fprintf(client.Conn, `<stream:features>
|
||||
<mechanisms xmlns='%s'>
|
||||
<mechanism>PLAIN</mechanism>
|
||||
</mechanisms>
|
||||
<register xmlns='%s'/>
|
||||
</stream:features>`,
|
||||
messages.NSSASL, messages.NSFeaturesIQRegister)
|
||||
|
||||
return state.Next, client
|
||||
}
|
||||
|
||||
// SASLAuth state
|
||||
type SASLAuth struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process messages
|
||||
func (state *SASLAuth) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "sasl auth")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
// read the full auth stanza
|
||||
element, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
var auth messages.SASLAuth
|
||||
if err = client.in.DecodeElement(&auth, element); err != nil {
|
||||
client.log.Info("start substate for registration")
|
||||
return &RegisterFormRequest{
|
||||
element: element,
|
||||
Next: &RegisterRequest{
|
||||
Next: state.Next,
|
||||
},
|
||||
}, client
|
||||
}
|
||||
data, err := base64.StdEncoding.DecodeString(auth.Body)
|
||||
if err != nil {
|
||||
client.log.Warn("body decode: ", err)
|
||||
return nil, client
|
||||
}
|
||||
info := strings.Split(string(data), "\x00")
|
||||
// should check that info[1] starts with client.jid
|
||||
client.jid.Local = info[1]
|
||||
client.log = client.log.WithField("jid", client.jid.Full())
|
||||
success, err := client.Server.Database.Authenticate(client.jid, info[2])
|
||||
if err != nil {
|
||||
client.log.Warn("auth: ", err)
|
||||
return nil, client
|
||||
}
|
||||
if success {
|
||||
client.log.Info("success auth")
|
||||
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
|
||||
return state.Next, client
|
||||
}
|
||||
client.log.Warn("failed auth")
|
||||
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
|
||||
return nil, client
|
||||
|
||||
}
|
||||
|
||||
// AuthedStart state
|
||||
type AuthedStart struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process messages
|
||||
func (state *AuthedStart) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "authed started")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
_, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||
createCookie(), messages.NSClient, messages.NSStream)
|
||||
|
||||
fmt.Fprintf(client.Conn, `<stream:features>
|
||||
<bind xmlns='%s'/>
|
||||
</stream:features>`,
|
||||
messages.NSBind)
|
||||
|
||||
return state.Next, client
|
||||
}
|
||||
|
||||
// AuthedStream state
|
||||
type AuthedStream struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process messages
|
||||
func (state *AuthedStream) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "authed stream")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
// check that it's a bind request
|
||||
// read bind request
|
||||
element, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
var msg messages.IQ
|
||||
if err = client.in.DecodeElement(&msg, element); err != nil {
|
||||
client.log.Warn("is no iq: ", err)
|
||||
return nil, client
|
||||
}
|
||||
if msg.Type != messages.IQTypeSet {
|
||||
client.log.Warn("is no set iq")
|
||||
return nil, client
|
||||
}
|
||||
if msg.Error != nil {
|
||||
client.log.Warn("iq with error: ", msg.Error.Code)
|
||||
return nil, client
|
||||
}
|
||||
type query struct {
|
||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
|
||||
Resource string `xml:"resource"`
|
||||
}
|
||||
q := &query{}
|
||||
err = xml.Unmarshal(msg.Body, q)
|
||||
if err != nil {
|
||||
client.log.Warn("is no iq bind: ", err)
|
||||
return nil, client
|
||||
}
|
||||
if q.Resource == "" {
|
||||
client.jid.Resource = makeResource()
|
||||
} else {
|
||||
client.jid.Resource = q.Resource
|
||||
}
|
||||
client.log = client.log.WithField("jid", client.jid.Full())
|
||||
client.out.Encode(&messages.IQ{
|
||||
Type: messages.IQTypeResult,
|
||||
ID: msg.ID,
|
||||
Body: []byte(fmt.Sprintf(
|
||||
`<bind xmlns='%s'>
|
||||
<jid>%s</jid>
|
||||
</bind>`,
|
||||
messages.NSBind, client.jid.Full())),
|
||||
})
|
||||
|
||||
return state.Next, client
|
||||
}
|
||||
|
||||
// SendingClient state
|
||||
type SendingClient struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process messages
|
||||
func (state *SendingClient) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "normal")
|
||||
client.log.Debug("sending")
|
||||
// sending
|
||||
go func() {
|
||||
select {
|
||||
case msg := <-client.messages:
|
||||
err := client.out.Encode(msg)
|
||||
client.log.Info(err)
|
||||
case <-client.close:
|
||||
return
|
||||
}
|
||||
}()
|
||||
client.log.Debug("receiving")
|
||||
return state.Next, client
|
||||
}
|
||||
|
||||
// ReceivingClient state
|
||||
type ReceivingClient struct {
|
||||
}
|
||||
|
||||
// Process messages
|
||||
func (state *ReceivingClient) Process(client *Client) (State, *Client) {
|
||||
element, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
/*
|
||||
for _, extension := range client.Server.Extensions {
|
||||
extension.Process(element, client)
|
||||
}*/
|
||||
client.log.Debug(element)
|
||||
return state, client
|
||||
}
|
|
@ -0,0 +1,130 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
|
||||
"github.com/genofire/yaja/messages"
|
||||
"github.com/genofire/yaja/model"
|
||||
)
|
||||
|
||||
type RegisterFormRequest struct {
|
||||
Next State
|
||||
element *xml.StartElement
|
||||
}
|
||||
|
||||
// Process message
|
||||
func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "register form request")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
var msg messages.IQ
|
||||
if err := client.in.DecodeElement(&msg, state.element); err != nil {
|
||||
client.log.Warn("is no iq: ", err)
|
||||
return state, client
|
||||
}
|
||||
if msg.Type != messages.IQTypeGet {
|
||||
client.log.Warn("is no get iq")
|
||||
return state, client
|
||||
}
|
||||
if msg.Error != nil {
|
||||
client.log.Warn("iq with error: ", msg.Error.Code)
|
||||
return state, client
|
||||
}
|
||||
type query struct {
|
||||
XMLName xml.Name `xml:"query"`
|
||||
}
|
||||
q := &query{}
|
||||
err := xml.Unmarshal(msg.Body, q)
|
||||
|
||||
if q.XMLName.Space != messages.NSIQRegister || err != nil {
|
||||
client.log.Warn("is no iq register: ", err)
|
||||
return nil, client
|
||||
}
|
||||
client.out.Encode(&messages.IQ{
|
||||
Type: messages.IQTypeResult,
|
||||
ID: msg.ID,
|
||||
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
|
||||
Choose a username and password for use with this service.
|
||||
</instructions>
|
||||
<username/>
|
||||
<password/>
|
||||
</query>`, messages.NSIQRegister)),
|
||||
})
|
||||
return state.Next, client
|
||||
}
|
||||
|
||||
type RegisterRequest struct {
|
||||
Next State
|
||||
}
|
||||
|
||||
// Process message
|
||||
func (state *RegisterRequest) Process(client *Client) (State, *Client) {
|
||||
client.log = client.log.WithField("state", "register request")
|
||||
client.log.Debug("running")
|
||||
defer client.log.Debug("leave")
|
||||
|
||||
element, err := client.Read()
|
||||
if err != nil {
|
||||
client.log.Warn("unable to read: ", err)
|
||||
return nil, client
|
||||
}
|
||||
var msg messages.IQ
|
||||
if err = client.in.DecodeElement(&msg, element); err != nil {
|
||||
client.log.Warn("is no iq: ", err)
|
||||
return state, client
|
||||
}
|
||||
if msg.Type != messages.IQTypeGet {
|
||||
client.log.Warn("is no get iq")
|
||||
return state, client
|
||||
}
|
||||
if msg.Error != nil {
|
||||
client.log.Warn("iq with error: ", msg.Error.Code)
|
||||
return state, client
|
||||
}
|
||||
type query struct {
|
||||
XMLName xml.Name `xml:"query"`
|
||||
Username string `xml:"username"`
|
||||
Password string `xml:"password"`
|
||||
}
|
||||
q := &query{}
|
||||
err = xml.Unmarshal(msg.Body, q)
|
||||
if err != nil {
|
||||
client.log.Warn("is no iq register: ", err)
|
||||
return nil, client
|
||||
}
|
||||
|
||||
client.jid.Local = q.Username
|
||||
client.log = client.log.WithField("jid", client.jid.Full())
|
||||
account := model.NewAccount(client.jid, q.Password)
|
||||
err = client.Server.Database.AddAccount(account)
|
||||
if err != nil {
|
||||
client.out.Encode(&messages.IQ{
|
||||
Type: messages.IQTypeResult,
|
||||
ID: msg.ID,
|
||||
Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
|
||||
<username>%s</username>
|
||||
<password>%s</password>
|
||||
</query>`, messages.NSIQRegister, q.Username, q.Password)),
|
||||
Error: &messages.Error{
|
||||
Code: "409",
|
||||
Type: "cancel",
|
||||
Any: xml.Name{
|
||||
Local: "conflict",
|
||||
Space: "urn:ietf:params:xml:ns:xmpp-stanzas",
|
||||
},
|
||||
},
|
||||
})
|
||||
client.log.Warn("database error: ", err)
|
||||
return state, client
|
||||
}
|
||||
client.account = account
|
||||
client.out.Encode(&messages.IQ{
|
||||
Type: messages.IQTypeResult,
|
||||
ID: msg.ID,
|
||||
})
|
||||
|
||||
client.log.Infof("registered client %s", client.jid.Bare())
|
||||
return state.Next, client
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Cookie is used to give a unique identifier to each request.
|
||||
type Cookie uint64
|
||||
|
||||
func createCookie() Cookie {
|
||||
var buf [8]byte
|
||||
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
||||
panic("Failed to read random bytes: " + err.Error())
|
||||
}
|
||||
return Cookie(binary.LittleEndian.Uint64(buf[:]))
|
||||
}
|
||||
func createCookieString() string {
|
||||
return fmt.Sprintf("%x", createCookie())
|
||||
}
|
||||
|
||||
func makeResource() string {
|
||||
var buf [16]byte
|
||||
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
||||
panic("Failed to read random bytes: " + err.Error())
|
||||
}
|
||||
return fmt.Sprintf("%x", buf)
|
||||
}
|
Reference in New Issue