restructur code in packages
This commit is contained in:
parent
a079961c8b
commit
1aceea7133
|
@ -12,6 +12,7 @@ import (
|
||||||
|
|
||||||
"github.com/genofire/yaja/database"
|
"github.com/genofire/yaja/database"
|
||||||
"github.com/genofire/yaja/model/config"
|
"github.com/genofire/yaja/model/config"
|
||||||
|
"github.com/genofire/yaja/server/extension"
|
||||||
|
|
||||||
"github.com/genofire/golang-lib/file"
|
"github.com/genofire/golang-lib/file"
|
||||||
"github.com/genofire/golang-lib/worker"
|
"github.com/genofire/golang-lib/worker"
|
||||||
|
@ -29,6 +30,7 @@ var (
|
||||||
statesaveWorker *worker.Worker
|
statesaveWorker *worker.Worker
|
||||||
srv *server.Server
|
srv *server.Server
|
||||||
certs *tls.Config
|
certs *tls.Config
|
||||||
|
extensions []extension.Extension
|
||||||
)
|
)
|
||||||
|
|
||||||
// serverCmd represents the serve command
|
// serverCmd represents the serve command
|
||||||
|
@ -81,6 +83,7 @@ var serverCmd = &cobra.Command{
|
||||||
LoggingClient: configData.Logging.LevelClient,
|
LoggingClient: configData.Logging.LevelClient,
|
||||||
RegisterEnable: configData.Register.Enable,
|
RegisterEnable: configData.Register.Enable,
|
||||||
RegisterDomains: configData.Register.Domains,
|
RegisterDomains: configData.Register.Domains,
|
||||||
|
Extensions: extensions,
|
||||||
}
|
}
|
||||||
|
|
||||||
go statesaveWorker.Start()
|
go statesaveWorker.Start()
|
||||||
|
@ -161,6 +164,7 @@ func reload() {
|
||||||
LoggingClient: configNewData.Logging.LevelClient,
|
LoggingClient: configNewData.Logging.LevelClient,
|
||||||
RegisterEnable: configNewData.Register.Enable,
|
RegisterEnable: configNewData.Register.Enable,
|
||||||
RegisterDomains: configNewData.Register.Domains,
|
RegisterDomains: configNewData.Register.Domains,
|
||||||
|
Extensions: extensions,
|
||||||
}
|
}
|
||||||
log.Warn("reloading need a restart:")
|
log.Warn("reloading need a restart:")
|
||||||
go newServer.Start()
|
go newServer.Start()
|
||||||
|
@ -176,4 +180,6 @@ func reload() {
|
||||||
func init() {
|
func init() {
|
||||||
RootCmd.AddCommand(serverCmd)
|
RootCmd.AddCommand(serverCmd)
|
||||||
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
|
||||||
|
|
||||||
|
extensions = append(extensions, &extension.Message{}, &extension.Roster{Database: db})
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,7 @@ type Account struct {
|
||||||
Local string `json:"-"`
|
Local string `json:"-"`
|
||||||
Domain *Domain `json:"-"`
|
Domain *Domain `json:"-"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
|
Roster map[string]*Buddy `json:"roster"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccount(jid *JID, password string) *Account {
|
func NewAccount(jid *JID, password string) *Account {
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
package model
|
||||||
|
|
||||||
|
const (
|
||||||
|
SubscriptionNone = iota
|
||||||
|
SubscriptionTo
|
||||||
|
SubscriptionFrom
|
||||||
|
SubscriptionBoth
|
||||||
|
AskNone = iota
|
||||||
|
AskSubscribe
|
||||||
|
)
|
||||||
|
|
||||||
|
type Buddy struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
Subscription int `json:"subscription"`
|
||||||
|
Ask int `json:"ask"`
|
||||||
|
}
|
|
@ -1,76 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/xml"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/genofire/yaja/model"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
log *log.Entry
|
|
||||||
|
|
||||||
Conn net.Conn
|
|
||||||
out *xml.Encoder
|
|
||||||
in *xml.Decoder
|
|
||||||
|
|
||||||
Server *Server
|
|
||||||
jid *model.JID
|
|
||||||
account *model.Account
|
|
||||||
|
|
||||||
messages chan interface{}
|
|
||||||
close chan interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClient(conn net.Conn, srv *Server) *Client {
|
|
||||||
logger := log.New()
|
|
||||||
logger.SetLevel(srv.LoggingClient)
|
|
||||||
client := &Client{
|
|
||||||
Conn: conn,
|
|
||||||
Server: srv,
|
|
||||||
log: log.NewEntry(logger),
|
|
||||||
in: xml.NewDecoder(conn),
|
|
||||||
out: xml.NewEncoder(conn),
|
|
||||||
}
|
|
||||||
return client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client *Client) NewConnecting(conn net.Conn) {
|
|
||||||
client.Conn = conn
|
|
||||||
client.in = xml.NewDecoder(conn)
|
|
||||||
client.out = xml.NewEncoder(conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client *Client) Read() (*xml.StartElement, error) {
|
|
||||||
for {
|
|
||||||
nextToken, err := client.in.Token()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
switch nextToken.(type) {
|
|
||||||
case xml.StartElement:
|
|
||||||
element := nextToken.(xml.StartElement)
|
|
||||||
return &element, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client *Client) DomainRegisterAllowed() bool {
|
|
||||||
if client.jid.Domain == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range client.Server.RegisterDomains {
|
|
||||||
if domain == client.jid.Domain {
|
|
||||||
|
|
||||||
return !client.Server.RegisterEnable
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return client.Server.RegisterEnable
|
|
||||||
}
|
|
||||||
|
|
||||||
func (client *Client) Close() {
|
|
||||||
client.close <- true
|
|
||||||
client.Conn.Close()
|
|
||||||
}
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Extension interface {
|
||||||
|
Process(*xml.StartElement, *utils.Client) bool
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Extension
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) Process(element *xml.StartElement, client *utils.Client) bool {
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
package extension
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/database"
|
||||||
|
"github.com/genofire/yaja/messages"
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Roster struct {
|
||||||
|
Extension
|
||||||
|
Database *database.State
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Roster) Process(element *xml.StartElement, client *utils.Client) bool {
|
||||||
|
var msg messages.IQ
|
||||||
|
if err := client.In.DecodeElement(&msg, element); err != nil {
|
||||||
|
client.Log.Warn("is no iq: ", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if msg.Type != messages.IQTypeGet {
|
||||||
|
client.Log.Warn("is no get iq")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -5,6 +5,10 @@ import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/genofire/yaja/database"
|
"github.com/genofire/yaja/database"
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
|
"github.com/genofire/yaja/server/extension"
|
||||||
|
"github.com/genofire/yaja/server/toclient"
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
)
|
)
|
||||||
|
@ -16,8 +20,9 @@ type Server struct {
|
||||||
ServerAddr []string
|
ServerAddr []string
|
||||||
Database *database.State
|
Database *database.State
|
||||||
LoggingClient log.Level
|
LoggingClient log.Level
|
||||||
RegisterEnable bool `toml:"enable"`
|
RegisterEnable bool
|
||||||
RegisterDomains []string `toml:"domains"`
|
RegisterDomains []string
|
||||||
|
Extensions []extension.Extension
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) Start() {
|
func (srv *Server) Start() {
|
||||||
|
@ -68,13 +73,13 @@ func (srv *Server) handleServer(conn net.Conn) {
|
||||||
|
|
||||||
func (srv *Server) handleClient(conn net.Conn) {
|
func (srv *Server) handleClient(conn net.Conn) {
|
||||||
log.Info("new client connection:", conn.RemoteAddr())
|
log.Info("new client connection:", conn.RemoteAddr())
|
||||||
client := NewClient(conn, srv)
|
client := utils.NewClient(conn, srv.LoggingClient)
|
||||||
state := ConnectionStartup()
|
state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
state, client = state.Process(client)
|
state, client = state.Process(client)
|
||||||
if state == nil {
|
if state == nil {
|
||||||
client.log.Info("disconnect")
|
client.Log.Info("disconnect")
|
||||||
client.Close()
|
client.Close()
|
||||||
//s.DisconnectBus <- Disconnect{Jid: client.jid}
|
//s.DisconnectBus <- Disconnect{Jid: client.jid}
|
||||||
return
|
return
|
||||||
|
@ -83,6 +88,19 @@ func (srv *Server) handleClient(conn net.Conn) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *Server) DomainRegisterAllowed(jid *model.JID) bool {
|
||||||
|
if jid.Domain == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range srv.RegisterDomains {
|
||||||
|
if domain == jid.Domain {
|
||||||
|
return !srv.RegisterEnable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return srv.RegisterEnable
|
||||||
|
}
|
||||||
|
|
||||||
func (srv *Server) Close() {
|
func (srv *Server) Close() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
// State processes the stream and moves to the next state
|
|
||||||
type State interface {
|
|
||||||
Process(client *Client) (State, *Client)
|
|
||||||
}
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/messages"
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionStartup return steps through TCP TLS state
|
||||||
|
func ConnectionStartup(after State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager) State {
|
||||||
|
tlsupgrade := &TLSUpgrade{
|
||||||
|
Next: after,
|
||||||
|
tlsconfig: tlsconfig,
|
||||||
|
tlsmgmt: tlsmgmt,
|
||||||
|
}
|
||||||
|
stream := &Start{Next: tlsupgrade}
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start state
|
||||||
|
type Start struct {
|
||||||
|
Next State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process message
|
||||||
|
func (state *Start) Process(client *utils.Client) (State, *utils.Client) {
|
||||||
|
client.Log = client.Log.WithField("state", "stream")
|
||||||
|
client.Log.Debug("running")
|
||||||
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||||
|
client.Log.Warn("is no stream")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
for _, attr := range element.Attr {
|
||||||
|
if attr.Name.Local == "to" {
|
||||||
|
client.JID = &model.JID{Domain: attr.Value}
|
||||||
|
client.Log = client.Log.WithField("jid", client.JID.Full())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if client.JID == nil {
|
||||||
|
client.Log.Warn("no 'to' domain readed")
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||||
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
|
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<stream:features>
|
||||||
|
<starttls xmlns='%s'>
|
||||||
|
<required/>
|
||||||
|
</starttls>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSStream)
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSUpgrade state
|
||||||
|
type TLSUpgrade struct {
|
||||||
|
Next State
|
||||||
|
tlsconfig *tls.Config
|
||||||
|
tlsmgmt *autocert.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process message
|
||||||
|
func (state *TLSUpgrade) Process(client *utils.Client) (State, *utils.Client) {
|
||||||
|
client.Log = client.Log.WithField("state", "tls upgrade")
|
||||||
|
client.Log.Debug("running")
|
||||||
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
|
||||||
|
client.Log.Warn("is no starttls")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
|
||||||
|
// perform the TLS handshake
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
if m := state.tlsmgmt; m != nil {
|
||||||
|
var cert *tls.Certificate
|
||||||
|
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.JID.Domain})
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("no cert in tls manger found: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
tlsConfig = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*cert},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tlsConfig == nil {
|
||||||
|
tlsConfig = state.tlsconfig
|
||||||
|
if tlsConfig != nil {
|
||||||
|
tlsConfig.ServerName = client.JID.Domain
|
||||||
|
} else {
|
||||||
|
client.Log.Warn("no tls config found: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Server(client.Conn, tlsConfig)
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to tls handshake: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
// restart the Connection
|
||||||
|
client.SetConnecting(tlsConn)
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
|
@ -0,0 +1,8 @@
|
||||||
|
package state
|
||||||
|
|
||||||
|
import "github.com/genofire/yaja/server/utils"
|
||||||
|
|
||||||
|
// State processes the stream and moves to the next state
|
||||||
|
type State interface {
|
||||||
|
Process(client *utils.Client) (State, *utils.Client)
|
||||||
|
}
|
|
@ -1,353 +0,0 @@
|
||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/xml"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/genofire/yaja/messages"
|
|
||||||
"github.com/genofire/yaja/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConnectionStartup return steps through TCP TLS state
|
|
||||||
func ConnectionStartup() State {
|
|
||||||
receiving := &ReceivingClient{}
|
|
||||||
sending := &SendingClient{Next: receiving}
|
|
||||||
authedstream := &AuthedStream{Next: sending}
|
|
||||||
authedstart := &AuthedStart{Next: authedstream}
|
|
||||||
tlsauth := &SASLAuth{Next: authedstart}
|
|
||||||
tlsstream := &TLSStream{Next: tlsauth}
|
|
||||||
tlsupgrade := &TLSUpgrade{Next: tlsstream}
|
|
||||||
stream := &Start{Next: tlsupgrade}
|
|
||||||
return stream
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start state
|
|
||||||
type Start struct {
|
|
||||||
Next State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process message
|
|
||||||
func (state *Start) Process(client *Client) (State, *Client) {
|
|
||||||
client.log = client.log.WithField("state", "stream")
|
|
||||||
client.log.Debug("running")
|
|
||||||
defer client.log.Debug("leave")
|
|
||||||
|
|
||||||
element, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
|
||||||
client.log.Warn("is no stream")
|
|
||||||
return state, client
|
|
||||||
}
|
|
||||||
for _, attr := range element.Attr {
|
|
||||||
if attr.Name.Local == "to" {
|
|
||||||
client.jid = &model.JID{Domain: attr.Value}
|
|
||||||
client.log = client.log.WithField("jid", client.jid.Full())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if client.jid == nil {
|
|
||||||
client.log.Warn("no 'to' domain readed")
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
|
||||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
|
||||||
createCookie(), messages.NSClient, messages.NSStream)
|
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
|
||||||
<starttls xmlns='%s'>
|
|
||||||
<required/>
|
|
||||||
</starttls>
|
|
||||||
</stream:features>`,
|
|
||||||
messages.NSStream)
|
|
||||||
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
|
|
||||||
// TLSUpgrade state
|
|
||||||
type TLSUpgrade struct {
|
|
||||||
Next State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process message
|
|
||||||
func (state *TLSUpgrade) Process(client *Client) (State, *Client) {
|
|
||||||
client.log = client.log.WithField("state", "tls upgrade")
|
|
||||||
client.log.Debug("running")
|
|
||||||
defer client.log.Debug("leave")
|
|
||||||
|
|
||||||
element, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
|
|
||||||
client.log.Warn("is no starttls")
|
|
||||||
return state, client
|
|
||||||
}
|
|
||||||
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
|
|
||||||
// perform the TLS handshake
|
|
||||||
var tlsConfig *tls.Config
|
|
||||||
if m := client.Server.TLSManager; m != nil {
|
|
||||||
var cert *tls.Certificate
|
|
||||||
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.jid.Domain})
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("no cert in tls manger found: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
tlsConfig = &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{*cert},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if tlsConfig == nil {
|
|
||||||
tlsConfig = client.Server.TLSConfig
|
|
||||||
if tlsConfig != nil {
|
|
||||||
tlsConfig.ServerName = client.jid.Domain
|
|
||||||
} else {
|
|
||||||
client.log.Warn("no tls config found: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConn := tls.Server(client.Conn, tlsConfig)
|
|
||||||
err = tlsConn.Handshake()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to tls handshake: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
// restart the Connection
|
|
||||||
client.NewConnecting(tlsConn)
|
|
||||||
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
|
|
||||||
// TLSStream state
|
|
||||||
type TLSStream struct {
|
|
||||||
Next State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *TLSStream) Process(client *Client) (State, *Client) {
|
|
||||||
client.log = client.log.WithField("state", "tls stream")
|
|
||||||
client.log.Debug("running")
|
|
||||||
defer client.log.Debug("leave")
|
|
||||||
|
|
||||||
element, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
|
||||||
client.log.Warn("is no stream")
|
|
||||||
return state, client
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
|
||||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
|
||||||
createCookie(), messages.NSClient, messages.NSStream)
|
|
||||||
|
|
||||||
if client.DomainRegisterAllowed() {
|
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
|
||||||
<mechanisms xmlns='%s'>
|
|
||||||
<mechanism>PLAIN</mechanism>
|
|
||||||
</mechanisms>
|
|
||||||
<register xmlns='%s'/>
|
|
||||||
</stream:features>`,
|
|
||||||
messages.NSSASL, messages.NSFeaturesIQRegister)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
|
||||||
<mechanisms xmlns='%s'>
|
|
||||||
<mechanism>PLAIN</mechanism>
|
|
||||||
</mechanisms>
|
|
||||||
</stream:features>`,
|
|
||||||
messages.NSSASL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
|
|
||||||
// SASLAuth state
|
|
||||||
type SASLAuth struct {
|
|
||||||
Next State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *SASLAuth) Process(client *Client) (State, *Client) {
|
|
||||||
client.log = client.log.WithField("state", "sasl auth")
|
|
||||||
client.log.Debug("running")
|
|
||||||
defer client.log.Debug("leave")
|
|
||||||
|
|
||||||
// read the full auth stanza
|
|
||||||
element, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
var auth messages.SASLAuth
|
|
||||||
if err = client.in.DecodeElement(&auth, element); err != nil {
|
|
||||||
client.log.Info("start substate for registration")
|
|
||||||
return &RegisterFormRequest{
|
|
||||||
element: element,
|
|
||||||
Next: &RegisterRequest{
|
|
||||||
Next: state.Next,
|
|
||||||
},
|
|
||||||
}, client
|
|
||||||
}
|
|
||||||
data, err := base64.StdEncoding.DecodeString(auth.Body)
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("body decode: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
info := strings.Split(string(data), "\x00")
|
|
||||||
// should check that info[1] starts with client.jid
|
|
||||||
client.jid.Local = info[1]
|
|
||||||
client.log = client.log.WithField("jid", client.jid.Full())
|
|
||||||
success, err := client.Server.Database.Authenticate(client.jid, info[2])
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("auth: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
if success {
|
|
||||||
client.log.Info("success auth")
|
|
||||||
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
client.log.Warn("failed auth")
|
|
||||||
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
|
|
||||||
return nil, client
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthedStart state
|
|
||||||
type AuthedStart struct {
|
|
||||||
Next State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *AuthedStart) Process(client *Client) (State, *Client) {
|
|
||||||
client.log = client.log.WithField("state", "authed started")
|
|
||||||
client.log.Debug("running")
|
|
||||||
defer client.log.Debug("leave")
|
|
||||||
|
|
||||||
_, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
|
||||||
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
|
||||||
createCookie(), messages.NSClient, messages.NSStream)
|
|
||||||
|
|
||||||
fmt.Fprintf(client.Conn, `<stream:features>
|
|
||||||
<bind xmlns='%s'/>
|
|
||||||
</stream:features>`,
|
|
||||||
messages.NSBind)
|
|
||||||
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthedStream state
|
|
||||||
type AuthedStream struct {
|
|
||||||
Next State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *AuthedStream) Process(client *Client) (State, *Client) {
|
|
||||||
client.log = client.log.WithField("state", "authed stream")
|
|
||||||
client.log.Debug("running")
|
|
||||||
defer client.log.Debug("leave")
|
|
||||||
|
|
||||||
// check that it's a bind request
|
|
||||||
// read bind request
|
|
||||||
element, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
var msg messages.IQ
|
|
||||||
if err = client.in.DecodeElement(&msg, element); err != nil {
|
|
||||||
client.log.Warn("is no iq: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
if msg.Type != messages.IQTypeSet {
|
|
||||||
client.log.Warn("is no set iq")
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
if msg.Error != nil {
|
|
||||||
client.log.Warn("iq with error: ", msg.Error.Code)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
type query struct {
|
|
||||||
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
|
|
||||||
Resource string `xml:"resource"`
|
|
||||||
}
|
|
||||||
q := &query{}
|
|
||||||
err = xml.Unmarshal(msg.Body, q)
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("is no iq bind: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
if q.Resource == "" {
|
|
||||||
client.jid.Resource = makeResource()
|
|
||||||
} else {
|
|
||||||
client.jid.Resource = q.Resource
|
|
||||||
}
|
|
||||||
client.log = client.log.WithField("jid", client.jid.Full())
|
|
||||||
client.out.Encode(&messages.IQ{
|
|
||||||
Type: messages.IQTypeResult,
|
|
||||||
ID: msg.ID,
|
|
||||||
Body: []byte(fmt.Sprintf(
|
|
||||||
`<bind xmlns='%s'>
|
|
||||||
<jid>%s</jid>
|
|
||||||
</bind>`,
|
|
||||||
messages.NSBind, client.jid.Full())),
|
|
||||||
})
|
|
||||||
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendingClient state
|
|
||||||
type SendingClient struct {
|
|
||||||
Next State
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *SendingClient) Process(client *Client) (State, *Client) {
|
|
||||||
client.log = client.log.WithField("state", "normal")
|
|
||||||
client.log.Debug("sending")
|
|
||||||
// sending
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case msg := <-client.messages:
|
|
||||||
err := client.out.Encode(msg)
|
|
||||||
client.log.Info(err)
|
|
||||||
case <-client.close:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
client.log.Debug("receiving")
|
|
||||||
return state.Next, client
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReceivingClient state
|
|
||||||
type ReceivingClient struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process messages
|
|
||||||
func (state *ReceivingClient) Process(client *Client) (State, *Client) {
|
|
||||||
element, err := client.Read()
|
|
||||||
if err != nil {
|
|
||||||
client.log.Warn("unable to read: ", err)
|
|
||||||
return nil, client
|
|
||||||
}
|
|
||||||
/*
|
|
||||||
for _, extension := range client.Server.Extensions {
|
|
||||||
extension.Process(element, client)
|
|
||||||
}*/
|
|
||||||
client.log.Debug(element)
|
|
||||||
return state, client
|
|
||||||
}
|
|
|
@ -0,0 +1,272 @@
|
||||||
|
package toclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/database"
|
||||||
|
"github.com/genofire/yaja/messages"
|
||||||
|
"github.com/genofire/yaja/server/extension"
|
||||||
|
"github.com/genofire/yaja/server/state"
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionStartup return steps through TCP TLS state
|
||||||
|
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed) state.State {
|
||||||
|
receiving := &ReceivingClient{}
|
||||||
|
sending := &SendingClient{Next: receiving}
|
||||||
|
authedstream := &AuthedStream{Next: sending}
|
||||||
|
authedstart := &AuthedStart{Next: authedstream}
|
||||||
|
tlsauth := &SASLAuth{
|
||||||
|
Next: authedstart,
|
||||||
|
database: db,
|
||||||
|
domainRegisterAllowed: registerAllowed,
|
||||||
|
}
|
||||||
|
tlsstream := &TLSStream{
|
||||||
|
Next: tlsauth,
|
||||||
|
domainRegisterAllowed: registerAllowed,
|
||||||
|
}
|
||||||
|
return state.ConnectionStartup(tlsstream, tlsconfig, tlsmgmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSStream state
|
||||||
|
type TLSStream struct {
|
||||||
|
Next state.State
|
||||||
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
|
client.Log = client.Log.WithField("state", "tls stream")
|
||||||
|
client.Log.Debug("running")
|
||||||
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
|
||||||
|
client.Log.Warn("is no stream")
|
||||||
|
return state, client
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||||
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
|
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
|
if state.domainRegisterAllowed(client.JID) {
|
||||||
|
fmt.Fprintf(client.Conn, `<stream:features>
|
||||||
|
<mechanisms xmlns='%s'>
|
||||||
|
<mechanism>PLAIN</mechanism>
|
||||||
|
</mechanisms>
|
||||||
|
<register xmlns='%s'/>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSSASL, messages.NSFeaturesIQRegister)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(client.Conn, `<stream:features>
|
||||||
|
<mechanisms xmlns='%s'>
|
||||||
|
<mechanism>PLAIN</mechanism>
|
||||||
|
</mechanisms>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSSASL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// SASLAuth state
|
||||||
|
type SASLAuth struct {
|
||||||
|
Next state.State
|
||||||
|
database *database.State
|
||||||
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *SASLAuth) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
|
client.Log = client.Log.WithField("state", "sasl auth")
|
||||||
|
client.Log.Debug("running")
|
||||||
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
|
// read the full auth stanza
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
var auth messages.SASLAuth
|
||||||
|
if err = client.In.DecodeElement(&auth, element); err != nil {
|
||||||
|
client.Log.Info("start substate for registration")
|
||||||
|
return &RegisterFormRequest{
|
||||||
|
element: element,
|
||||||
|
domainRegisterAllowed: state.domainRegisterAllowed,
|
||||||
|
Next: &RegisterRequest{
|
||||||
|
domainRegisterAllowed: state.domainRegisterAllowed,
|
||||||
|
database: state.database,
|
||||||
|
Next: state.Next,
|
||||||
|
},
|
||||||
|
}, client
|
||||||
|
}
|
||||||
|
data, err := base64.StdEncoding.DecodeString(auth.Body)
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("body decode: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
info := strings.Split(string(data), "\x00")
|
||||||
|
// should check that info[1] starts with client.JID
|
||||||
|
client.JID.Local = info[1]
|
||||||
|
client.Log = client.Log.WithField("jid", client.JID.Full())
|
||||||
|
success, err := state.database.Authenticate(client.JID, info[2])
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("auth: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if success {
|
||||||
|
client.Log.Info("success auth")
|
||||||
|
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
client.Log.Warn("failed auth")
|
||||||
|
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
|
||||||
|
return nil, client
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthedStart state
|
||||||
|
type AuthedStart struct {
|
||||||
|
Next state.State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *AuthedStart) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
|
client.Log = client.Log.WithField("state", "authed started")
|
||||||
|
client.Log.Debug("running")
|
||||||
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
|
_, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
|
||||||
|
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
|
||||||
|
utils.CreateCookie(), messages.NSClient, messages.NSStream)
|
||||||
|
|
||||||
|
fmt.Fprintf(client.Conn, `<stream:features>
|
||||||
|
<bind xmlns='%s'/>
|
||||||
|
</stream:features>`,
|
||||||
|
messages.NSBind)
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthedStream state
|
||||||
|
type AuthedStream struct {
|
||||||
|
Next state.State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
|
client.Log = client.Log.WithField("state", "authed stream")
|
||||||
|
client.Log.Debug("running")
|
||||||
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
|
// check that it's a bind request
|
||||||
|
// read bind request
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
var msg messages.IQ
|
||||||
|
if err = client.In.DecodeElement(&msg, element); err != nil {
|
||||||
|
client.Log.Warn("is no iq: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if msg.Type != messages.IQTypeSet {
|
||||||
|
client.Log.Warn("is no set iq")
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if msg.Error != nil {
|
||||||
|
client.Log.Warn("iq with error: ", msg.Error.Code)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
type query struct {
|
||||||
|
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
|
||||||
|
Resource string `xml:"resource"`
|
||||||
|
}
|
||||||
|
q := &query{}
|
||||||
|
err = xml.Unmarshal(msg.Body, q)
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("is no iq bind: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
if q.Resource == "" {
|
||||||
|
client.JID.Resource = makeResource()
|
||||||
|
} else {
|
||||||
|
client.JID.Resource = q.Resource
|
||||||
|
}
|
||||||
|
client.Log = client.Log.WithField("jid", client.JID.Full())
|
||||||
|
client.Out.Encode(&messages.IQ{
|
||||||
|
Type: messages.IQTypeResult,
|
||||||
|
ID: msg.ID,
|
||||||
|
Body: []byte(fmt.Sprintf(
|
||||||
|
`<bind xmlns='%s'>
|
||||||
|
<jid>%s</jid>
|
||||||
|
</bind>`,
|
||||||
|
messages.NSBind, client.JID.Full())),
|
||||||
|
})
|
||||||
|
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendingClient state
|
||||||
|
type SendingClient struct {
|
||||||
|
Next state.State
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *SendingClient) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
|
client.Log = client.Log.WithField("state", "normal")
|
||||||
|
client.Log.Debug("sending")
|
||||||
|
// sending
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case msg := <-client.Messages:
|
||||||
|
err := client.Out.Encode(msg)
|
||||||
|
client.Log.Info(err)
|
||||||
|
case <-client.OnClose():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
client.Log.Debug("receiving")
|
||||||
|
return state.Next, client
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceivingClient state
|
||||||
|
type ReceivingClient struct {
|
||||||
|
Extensions []extension.Extension
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process messages
|
||||||
|
func (state *ReceivingClient) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
|
element, err := client.Read()
|
||||||
|
if err != nil {
|
||||||
|
client.Log.Warn("unable to read: ", err)
|
||||||
|
return nil, client
|
||||||
|
}
|
||||||
|
count := 0
|
||||||
|
for _, extension := range state.Extensions {
|
||||||
|
if extension.Process(element, client) {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
client.Log.WithField("extension", count).Debug(element)
|
||||||
|
}
|
||||||
|
return state, client
|
||||||
|
}
|
|
@ -1,40 +1,44 @@
|
||||||
package server
|
package toclient
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/database"
|
||||||
"github.com/genofire/yaja/messages"
|
"github.com/genofire/yaja/messages"
|
||||||
"github.com/genofire/yaja/model"
|
"github.com/genofire/yaja/model"
|
||||||
|
"github.com/genofire/yaja/server/state"
|
||||||
|
"github.com/genofire/yaja/server/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RegisterFormRequest struct {
|
type RegisterFormRequest struct {
|
||||||
Next State
|
Next state.State
|
||||||
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
element *xml.StartElement
|
element *xml.StartElement
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message
|
// Process message
|
||||||
func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
|
func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
client.log = client.log.WithField("state", "register form request")
|
client.Log = client.Log.WithField("state", "register form request")
|
||||||
client.log.Debug("running")
|
client.Log.Debug("running")
|
||||||
defer client.log.Debug("leave")
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
if !client.DomainRegisterAllowed() {
|
if !state.domainRegisterAllowed(client.JID) {
|
||||||
client.log.Error("unpossible to reach this state, register on this domain is not allowed")
|
client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
|
||||||
return nil, client
|
return nil, client
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg messages.IQ
|
var msg messages.IQ
|
||||||
if err := client.in.DecodeElement(&msg, state.element); err != nil {
|
if err := client.In.DecodeElement(&msg, state.element); err != nil {
|
||||||
client.log.Warn("is no iq: ", err)
|
client.Log.Warn("is no iq: ", err)
|
||||||
return state, client
|
return state, client
|
||||||
}
|
}
|
||||||
if msg.Type != messages.IQTypeGet {
|
if msg.Type != messages.IQTypeGet {
|
||||||
client.log.Warn("is no get iq")
|
client.Log.Warn("is no get iq")
|
||||||
return state, client
|
return state, client
|
||||||
}
|
}
|
||||||
if msg.Error != nil {
|
if msg.Error != nil {
|
||||||
client.log.Warn("iq with error: ", msg.Error.Code)
|
client.Log.Warn("iq with error: ", msg.Error.Code)
|
||||||
return state, client
|
return state, client
|
||||||
}
|
}
|
||||||
type query struct {
|
type query struct {
|
||||||
|
@ -44,10 +48,10 @@ func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
|
||||||
err := xml.Unmarshal(msg.Body, q)
|
err := xml.Unmarshal(msg.Body, q)
|
||||||
|
|
||||||
if q.XMLName.Space != messages.NSIQRegister || err != nil {
|
if q.XMLName.Space != messages.NSIQRegister || err != nil {
|
||||||
client.log.Warn("is no iq register: ", err)
|
client.Log.Warn("is no iq register: ", err)
|
||||||
return nil, client
|
return nil, client
|
||||||
}
|
}
|
||||||
client.out.Encode(&messages.IQ{
|
client.Out.Encode(&messages.IQ{
|
||||||
Type: messages.IQTypeResult,
|
Type: messages.IQTypeResult,
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
|
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
|
||||||
|
@ -61,36 +65,38 @@ func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type RegisterRequest struct {
|
type RegisterRequest struct {
|
||||||
Next State
|
Next state.State
|
||||||
|
database *database.State
|
||||||
|
domainRegisterAllowed utils.DomainRegisterAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process message
|
// Process message
|
||||||
func (state *RegisterRequest) Process(client *Client) (State, *Client) {
|
func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils.Client) {
|
||||||
client.log = client.log.WithField("state", "register request")
|
client.Log = client.Log.WithField("state", "register request")
|
||||||
client.log.Debug("running")
|
client.Log.Debug("running")
|
||||||
defer client.log.Debug("leave")
|
defer client.Log.Debug("leave")
|
||||||
|
|
||||||
if !client.DomainRegisterAllowed() {
|
if !state.domainRegisterAllowed(client.JID) {
|
||||||
client.log.Error("unpossible to reach this state, register on this domain is not allowed")
|
client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
|
||||||
return nil, client
|
return nil, client
|
||||||
}
|
}
|
||||||
|
|
||||||
element, err := client.Read()
|
element, err := client.Read()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.log.Warn("unable to read: ", err)
|
client.Log.Warn("unable to read: ", err)
|
||||||
return nil, client
|
return nil, client
|
||||||
}
|
}
|
||||||
var msg messages.IQ
|
var msg messages.IQ
|
||||||
if err = client.in.DecodeElement(&msg, element); err != nil {
|
if err = client.In.DecodeElement(&msg, element); err != nil {
|
||||||
client.log.Warn("is no iq: ", err)
|
client.Log.Warn("is no iq: ", err)
|
||||||
return state, client
|
return state, client
|
||||||
}
|
}
|
||||||
if msg.Type != messages.IQTypeGet {
|
if msg.Type != messages.IQTypeGet {
|
||||||
client.log.Warn("is no get iq")
|
client.Log.Warn("is no get iq")
|
||||||
return state, client
|
return state, client
|
||||||
}
|
}
|
||||||
if msg.Error != nil {
|
if msg.Error != nil {
|
||||||
client.log.Warn("iq with error: ", msg.Error.Code)
|
client.Log.Warn("iq with error: ", msg.Error.Code)
|
||||||
return state, client
|
return state, client
|
||||||
}
|
}
|
||||||
type query struct {
|
type query struct {
|
||||||
|
@ -101,16 +107,16 @@ func (state *RegisterRequest) Process(client *Client) (State, *Client) {
|
||||||
q := &query{}
|
q := &query{}
|
||||||
err = xml.Unmarshal(msg.Body, q)
|
err = xml.Unmarshal(msg.Body, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.log.Warn("is no iq register: ", err)
|
client.Log.Warn("is no iq register: ", err)
|
||||||
return nil, client
|
return nil, client
|
||||||
}
|
}
|
||||||
|
|
||||||
client.jid.Local = q.Username
|
client.JID.Local = q.Username
|
||||||
client.log = client.log.WithField("jid", client.jid.Full())
|
client.Log = client.Log.WithField("jid", client.JID.Full())
|
||||||
account := model.NewAccount(client.jid, q.Password)
|
account := model.NewAccount(client.JID, q.Password)
|
||||||
err = client.Server.Database.AddAccount(account)
|
err = state.database.AddAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
client.out.Encode(&messages.IQ{
|
client.Out.Encode(&messages.IQ{
|
||||||
Type: messages.IQTypeResult,
|
Type: messages.IQTypeResult,
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
|
Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
|
||||||
|
@ -126,15 +132,14 @@ func (state *RegisterRequest) Process(client *Client) (State, *Client) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
client.log.Warn("database error: ", err)
|
client.Log.Warn("database error: ", err)
|
||||||
return state, client
|
return state, client
|
||||||
}
|
}
|
||||||
client.account = account
|
client.Out.Encode(&messages.IQ{
|
||||||
client.out.Encode(&messages.IQ{
|
|
||||||
Type: messages.IQTypeResult,
|
Type: messages.IQTypeResult,
|
||||||
ID: msg.ID,
|
ID: msg.ID,
|
||||||
})
|
})
|
||||||
|
|
||||||
client.log.Infof("registered client %s", client.jid.Bare())
|
client.Log.Infof("registered client %s", client.JID.Bare())
|
||||||
return state.Next, client
|
return state.Next, client
|
||||||
}
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
package toclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeResource() string {
|
||||||
|
var buf [16]byte
|
||||||
|
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
||||||
|
panic("Failed to read random bytes: " + err.Error())
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%x", buf)
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
Log *log.Entry
|
||||||
|
|
||||||
|
Conn net.Conn
|
||||||
|
Out *xml.Encoder
|
||||||
|
In *xml.Decoder
|
||||||
|
|
||||||
|
JID *model.JID
|
||||||
|
account *model.Account
|
||||||
|
|
||||||
|
Messages chan interface{}
|
||||||
|
close chan interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(conn net.Conn, level log.Level) *Client {
|
||||||
|
logger := log.New()
|
||||||
|
logger.SetLevel(level)
|
||||||
|
client := &Client{
|
||||||
|
Conn: conn,
|
||||||
|
Log: log.NewEntry(logger),
|
||||||
|
In: xml.NewDecoder(conn),
|
||||||
|
Out: xml.NewEncoder(conn),
|
||||||
|
close: make(chan interface{}),
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) SetConnecting(conn net.Conn) {
|
||||||
|
client.Conn = conn
|
||||||
|
client.In = xml.NewDecoder(conn)
|
||||||
|
client.Out = xml.NewEncoder(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) Read() (*xml.StartElement, error) {
|
||||||
|
for {
|
||||||
|
nextToken, err := client.In.Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch nextToken.(type) {
|
||||||
|
case xml.StartElement:
|
||||||
|
element := nextToken.(xml.StartElement)
|
||||||
|
return &element, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) OnClose() <-chan interface{} {
|
||||||
|
return client.close
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *Client) Close() {
|
||||||
|
client.close <- true
|
||||||
|
client.Conn.Close()
|
||||||
|
}
|
|
@ -1,29 +1,25 @@
|
||||||
package server
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/genofire/yaja/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Cookie is used to give a unique identifier to each request.
|
// Cookie is used to give a unique identifier to each request.
|
||||||
type Cookie uint64
|
type Cookie uint64
|
||||||
|
|
||||||
func createCookie() Cookie {
|
func CreateCookie() Cookie {
|
||||||
var buf [8]byte
|
var buf [8]byte
|
||||||
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
||||||
panic("Failed to read random bytes: " + err.Error())
|
panic("Failed to read random bytes: " + err.Error())
|
||||||
}
|
}
|
||||||
return Cookie(binary.LittleEndian.Uint64(buf[:]))
|
return Cookie(binary.LittleEndian.Uint64(buf[:]))
|
||||||
}
|
}
|
||||||
func createCookieString() string {
|
func CreateCookieString() string {
|
||||||
return fmt.Sprintf("%x", createCookie())
|
return fmt.Sprintf("%x", CreateCookie())
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeResource() string {
|
type DomainRegisterAllowed func(*model.JID) bool
|
||||||
var buf [16]byte
|
|
||||||
if _, err := rand.Reader.Read(buf[:]); err != nil {
|
|
||||||
panic("Failed to read random bytes: " + err.Error())
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%x", buf)
|
|
||||||
}
|
|
Reference in New Issue