sum7
/
yaja
Archived
1
0
Fork 0

restructur code in packages

This commit is contained in:
Martin Geno 2017-12-16 23:20:46 +01:00
parent a079961c8b
commit 1aceea7133
No known key found for this signature in database
GPG Key ID: F0D39A37E925E941
17 changed files with 636 additions and 491 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/genofire/yaja/database" "github.com/genofire/yaja/database"
"github.com/genofire/yaja/model/config" "github.com/genofire/yaja/model/config"
"github.com/genofire/yaja/server/extension"
"github.com/genofire/golang-lib/file" "github.com/genofire/golang-lib/file"
"github.com/genofire/golang-lib/worker" "github.com/genofire/golang-lib/worker"
@ -29,6 +30,7 @@ var (
statesaveWorker *worker.Worker statesaveWorker *worker.Worker
srv *server.Server srv *server.Server
certs *tls.Config certs *tls.Config
extensions []extension.Extension
) )
// serverCmd represents the serve command // serverCmd represents the serve command
@ -81,6 +83,7 @@ var serverCmd = &cobra.Command{
LoggingClient: configData.Logging.LevelClient, LoggingClient: configData.Logging.LevelClient,
RegisterEnable: configData.Register.Enable, RegisterEnable: configData.Register.Enable,
RegisterDomains: configData.Register.Domains, RegisterDomains: configData.Register.Domains,
Extensions: extensions,
} }
go statesaveWorker.Start() go statesaveWorker.Start()
@ -161,6 +164,7 @@ func reload() {
LoggingClient: configNewData.Logging.LevelClient, LoggingClient: configNewData.Logging.LevelClient,
RegisterEnable: configNewData.Register.Enable, RegisterEnable: configNewData.Register.Enable,
RegisterDomains: configNewData.Register.Domains, RegisterDomains: configNewData.Register.Domains,
Extensions: extensions,
} }
log.Warn("reloading need a restart:") log.Warn("reloading need a restart:")
go newServer.Start() go newServer.Start()
@ -176,4 +180,6 @@ func reload() {
func init() { func init() {
RootCmd.AddCommand(serverCmd) RootCmd.AddCommand(serverCmd)
serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") serverCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file")
extensions = append(extensions, &extension.Message{}, &extension.Roster{Database: db})
} }

View File

@ -29,9 +29,10 @@ func (d *Domain) UpdateAccount(a *Account) error {
} }
type Account struct { type Account struct {
Local string `json:"-"` Local string `json:"-"`
Domain *Domain `json:"-"` Domain *Domain `json:"-"`
Password string `json:"password"` Password string `json:"password"`
Roster map[string]*Buddy `json:"roster"`
} }
func NewAccount(jid *JID, password string) *Account { func NewAccount(jid *JID, password string) *Account {

17
model/buddy.go Normal file
View File

@ -0,0 +1,17 @@
package model
const (
SubscriptionNone = iota
SubscriptionTo
SubscriptionFrom
SubscriptionBoth
AskNone = iota
AskSubscribe
)
type Buddy struct {
Name string `json:"name"`
Groups []string `json:"groups"`
Subscription int `json:"subscription"`
Ask int `json:"ask"`
}

View File

@ -1,76 +0,0 @@
package server
import (
"encoding/xml"
"net"
"github.com/genofire/yaja/model"
log "github.com/sirupsen/logrus"
)
type Client struct {
log *log.Entry
Conn net.Conn
out *xml.Encoder
in *xml.Decoder
Server *Server
jid *model.JID
account *model.Account
messages chan interface{}
close chan interface{}
}
func NewClient(conn net.Conn, srv *Server) *Client {
logger := log.New()
logger.SetLevel(srv.LoggingClient)
client := &Client{
Conn: conn,
Server: srv,
log: log.NewEntry(logger),
in: xml.NewDecoder(conn),
out: xml.NewEncoder(conn),
}
return client
}
func (client *Client) NewConnecting(conn net.Conn) {
client.Conn = conn
client.in = xml.NewDecoder(conn)
client.out = xml.NewEncoder(conn)
}
func (client *Client) Read() (*xml.StartElement, error) {
for {
nextToken, err := client.in.Token()
if err != nil {
return nil, err
}
switch nextToken.(type) {
case xml.StartElement:
element := nextToken.(xml.StartElement)
return &element, nil
}
}
}
func (client *Client) DomainRegisterAllowed() bool {
if client.jid.Domain == "" {
return false
}
for _, domain := range client.Server.RegisterDomains {
if domain == client.jid.Domain {
return !client.Server.RegisterEnable
}
}
return client.Server.RegisterEnable
}
func (client *Client) Close() {
client.close <- true
client.Conn.Close()
}

11
server/extension/main.go Normal file
View File

@ -0,0 +1,11 @@
package extension
import (
"encoding/xml"
"github.com/genofire/yaja/server/utils"
)
type Extension interface {
Process(*xml.StartElement, *utils.Client) bool
}

View File

@ -0,0 +1,15 @@
package extension
import (
"encoding/xml"
"github.com/genofire/yaja/server/utils"
)
type Message struct {
Extension
}
func (m *Message) Process(element *xml.StartElement, client *utils.Client) bool {
return false
}

View File

@ -0,0 +1,27 @@
package extension
import (
"encoding/xml"
"github.com/genofire/yaja/database"
"github.com/genofire/yaja/messages"
"github.com/genofire/yaja/server/utils"
)
type Roster struct {
Extension
Database *database.State
}
func (r *Roster) Process(element *xml.StartElement, client *utils.Client) bool {
var msg messages.IQ
if err := client.In.DecodeElement(&msg, element); err != nil {
client.Log.Warn("is no iq: ", err)
return false
}
if msg.Type != messages.IQTypeGet {
client.Log.Warn("is no get iq")
return false
}
return true
}

View File

@ -5,6 +5,10 @@ import (
"net" "net"
"github.com/genofire/yaja/database" "github.com/genofire/yaja/database"
"github.com/genofire/yaja/model"
"github.com/genofire/yaja/server/extension"
"github.com/genofire/yaja/server/toclient"
"github.com/genofire/yaja/server/utils"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
) )
@ -16,8 +20,9 @@ type Server struct {
ServerAddr []string ServerAddr []string
Database *database.State Database *database.State
LoggingClient log.Level LoggingClient log.Level
RegisterEnable bool `toml:"enable"` RegisterEnable bool
RegisterDomains []string `toml:"domains"` RegisterDomains []string
Extensions []extension.Extension
} }
func (srv *Server) Start() { func (srv *Server) Start() {
@ -68,13 +73,13 @@ func (srv *Server) handleServer(conn net.Conn) {
func (srv *Server) handleClient(conn net.Conn) { func (srv *Server) handleClient(conn net.Conn) {
log.Info("new client connection:", conn.RemoteAddr()) log.Info("new client connection:", conn.RemoteAddr())
client := NewClient(conn, srv) client := utils.NewClient(conn, srv.LoggingClient)
state := ConnectionStartup() state := toclient.ConnectionStartup(srv.Database, srv.TLSConfig, srv.TLSManager, srv.DomainRegisterAllowed)
for { for {
state, client = state.Process(client) state, client = state.Process(client)
if state == nil { if state == nil {
client.log.Info("disconnect") client.Log.Info("disconnect")
client.Close() client.Close()
//s.DisconnectBus <- Disconnect{Jid: client.jid} //s.DisconnectBus <- Disconnect{Jid: client.jid}
return return
@ -83,6 +88,19 @@ func (srv *Server) handleClient(conn net.Conn) {
} }
} }
func (srv *Server) DomainRegisterAllowed(jid *model.JID) bool {
if jid.Domain == "" {
return false
}
for _, domain := range srv.RegisterDomains {
if domain == jid.Domain {
return !srv.RegisterEnable
}
}
return srv.RegisterEnable
}
func (srv *Server) Close() { func (srv *Server) Close() {
} }

View File

@ -1,6 +0,0 @@
package server
// State processes the stream and moves to the next state
type State interface {
Process(client *Client) (State, *Client)
}

125
server/state/connect.go Normal file
View File

@ -0,0 +1,125 @@
package state
import (
"crypto/tls"
"fmt"
"github.com/genofire/yaja/messages"
"github.com/genofire/yaja/model"
"github.com/genofire/yaja/server/utils"
"golang.org/x/crypto/acme/autocert"
)
// ConnectionStartup return steps through TCP TLS state
func ConnectionStartup(after State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager) State {
tlsupgrade := &TLSUpgrade{
Next: after,
tlsconfig: tlsconfig,
tlsmgmt: tlsmgmt,
}
stream := &Start{Next: tlsupgrade}
return stream
}
// Start state
type Start struct {
Next State
}
// Process message
func (state *Start) Process(client *utils.Client) (State, *utils.Client) {
client.Log = client.Log.WithField("state", "stream")
client.Log.Debug("running")
defer client.Log.Debug("leave")
element, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.Log.Warn("is no stream")
return state, client
}
for _, attr := range element.Attr {
if attr.Name.Local == "to" {
client.JID = &model.JID{Domain: attr.Value}
client.Log = client.Log.WithField("jid", client.JID.Full())
}
}
if client.JID == nil {
client.Log.Warn("no 'to' domain readed")
return nil, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features>
<starttls xmlns='%s'>
<required/>
</starttls>
</stream:features>`,
messages.NSStream)
return state.Next, client
}
// TLSUpgrade state
type TLSUpgrade struct {
Next State
tlsconfig *tls.Config
tlsmgmt *autocert.Manager
}
// Process message
func (state *TLSUpgrade) Process(client *utils.Client) (State, *utils.Client) {
client.Log = client.Log.WithField("state", "tls upgrade")
client.Log.Debug("running")
defer client.Log.Debug("leave")
element, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
client.Log.Warn("is no starttls")
return state, client
}
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
// perform the TLS handshake
var tlsConfig *tls.Config
if m := state.tlsmgmt; m != nil {
var cert *tls.Certificate
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.JID.Domain})
if err != nil {
client.Log.Warn("no cert in tls manger found: ", err)
return nil, client
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{*cert},
}
}
if tlsConfig == nil {
tlsConfig = state.tlsconfig
if tlsConfig != nil {
tlsConfig.ServerName = client.JID.Domain
} else {
client.Log.Warn("no tls config found: ", err)
return nil, client
}
}
tlsConn := tls.Server(client.Conn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
client.Log.Warn("unable to tls handshake: ", err)
return nil, client
}
// restart the Connection
client.SetConnecting(tlsConn)
return state.Next, client
}

8
server/state/state.go Normal file
View File

@ -0,0 +1,8 @@
package state
import "github.com/genofire/yaja/server/utils"
// State processes the stream and moves to the next state
type State interface {
Process(client *utils.Client) (State, *utils.Client)
}

View File

@ -1,353 +0,0 @@
package server
import (
"crypto/tls"
"encoding/base64"
"encoding/xml"
"fmt"
"strings"
"github.com/genofire/yaja/messages"
"github.com/genofire/yaja/model"
)
// ConnectionStartup return steps through TCP TLS state
func ConnectionStartup() State {
receiving := &ReceivingClient{}
sending := &SendingClient{Next: receiving}
authedstream := &AuthedStream{Next: sending}
authedstart := &AuthedStart{Next: authedstream}
tlsauth := &SASLAuth{Next: authedstart}
tlsstream := &TLSStream{Next: tlsauth}
tlsupgrade := &TLSUpgrade{Next: tlsstream}
stream := &Start{Next: tlsupgrade}
return stream
}
// Start state
type Start struct {
Next State
}
// Process message
func (state *Start) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "stream")
client.log.Debug("running")
defer client.log.Debug("leave")
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.log.Warn("is no stream")
return state, client
}
for _, attr := range element.Attr {
if attr.Name.Local == "to" {
client.jid = &model.JID{Domain: attr.Value}
client.log = client.log.WithField("jid", client.jid.Full())
}
}
if client.jid == nil {
client.log.Warn("no 'to' domain readed")
return nil, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
createCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features>
<starttls xmlns='%s'>
<required/>
</starttls>
</stream:features>`,
messages.NSStream)
return state.Next, client
}
// TLSUpgrade state
type TLSUpgrade struct {
Next State
}
// Process message
func (state *TLSUpgrade) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "tls upgrade")
client.log.Debug("running")
defer client.log.Debug("leave")
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
client.log.Warn("is no starttls")
return state, client
}
fmt.Fprintf(client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
// perform the TLS handshake
var tlsConfig *tls.Config
if m := client.Server.TLSManager; m != nil {
var cert *tls.Certificate
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: client.jid.Domain})
if err != nil {
client.log.Warn("no cert in tls manger found: ", err)
return nil, client
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{*cert},
}
}
if tlsConfig == nil {
tlsConfig = client.Server.TLSConfig
if tlsConfig != nil {
tlsConfig.ServerName = client.jid.Domain
} else {
client.log.Warn("no tls config found: ", err)
return nil, client
}
}
tlsConn := tls.Server(client.Conn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
client.log.Warn("unable to tls handshake: ", err)
return nil, client
}
// restart the Connection
client.NewConnecting(tlsConn)
return state.Next, client
}
// TLSStream state
type TLSStream struct {
Next State
}
// Process messages
func (state *TLSStream) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "tls stream")
client.log.Debug("running")
defer client.log.Debug("leave")
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.log.Warn("is no stream")
return state, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
createCookie(), messages.NSClient, messages.NSStream)
if client.DomainRegisterAllowed() {
fmt.Fprintf(client.Conn, `<stream:features>
<mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism>
</mechanisms>
<register xmlns='%s'/>
</stream:features>`,
messages.NSSASL, messages.NSFeaturesIQRegister)
} else {
fmt.Fprintf(client.Conn, `<stream:features>
<mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism>
</mechanisms>
</stream:features>`,
messages.NSSASL)
}
return state.Next, client
}
// SASLAuth state
type SASLAuth struct {
Next State
}
// Process messages
func (state *SASLAuth) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "sasl auth")
client.log.Debug("running")
defer client.log.Debug("leave")
// read the full auth stanza
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
var auth messages.SASLAuth
if err = client.in.DecodeElement(&auth, element); err != nil {
client.log.Info("start substate for registration")
return &RegisterFormRequest{
element: element,
Next: &RegisterRequest{
Next: state.Next,
},
}, client
}
data, err := base64.StdEncoding.DecodeString(auth.Body)
if err != nil {
client.log.Warn("body decode: ", err)
return nil, client
}
info := strings.Split(string(data), "\x00")
// should check that info[1] starts with client.jid
client.jid.Local = info[1]
client.log = client.log.WithField("jid", client.jid.Full())
success, err := client.Server.Database.Authenticate(client.jid, info[2])
if err != nil {
client.log.Warn("auth: ", err)
return nil, client
}
if success {
client.log.Info("success auth")
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
return state.Next, client
}
client.log.Warn("failed auth")
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
return nil, client
}
// AuthedStart state
type AuthedStart struct {
Next State
}
// Process messages
func (state *AuthedStart) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "authed started")
client.log.Debug("running")
defer client.log.Debug("leave")
_, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
createCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features>
<bind xmlns='%s'/>
</stream:features>`,
messages.NSBind)
return state.Next, client
}
// AuthedStream state
type AuthedStream struct {
Next State
}
// Process messages
func (state *AuthedStream) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "authed stream")
client.log.Debug("running")
defer client.log.Debug("leave")
// check that it's a bind request
// read bind request
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
var msg messages.IQ
if err = client.in.DecodeElement(&msg, element); err != nil {
client.log.Warn("is no iq: ", err)
return nil, client
}
if msg.Type != messages.IQTypeSet {
client.log.Warn("is no set iq")
return nil, client
}
if msg.Error != nil {
client.log.Warn("iq with error: ", msg.Error.Code)
return nil, client
}
type query struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
Resource string `xml:"resource"`
}
q := &query{}
err = xml.Unmarshal(msg.Body, q)
if err != nil {
client.log.Warn("is no iq bind: ", err)
return nil, client
}
if q.Resource == "" {
client.jid.Resource = makeResource()
} else {
client.jid.Resource = q.Resource
}
client.log = client.log.WithField("jid", client.jid.Full())
client.out.Encode(&messages.IQ{
Type: messages.IQTypeResult,
ID: msg.ID,
Body: []byte(fmt.Sprintf(
`<bind xmlns='%s'>
<jid>%s</jid>
</bind>`,
messages.NSBind, client.jid.Full())),
})
return state.Next, client
}
// SendingClient state
type SendingClient struct {
Next State
}
// Process messages
func (state *SendingClient) Process(client *Client) (State, *Client) {
client.log = client.log.WithField("state", "normal")
client.log.Debug("sending")
// sending
go func() {
select {
case msg := <-client.messages:
err := client.out.Encode(msg)
client.log.Info(err)
case <-client.close:
return
}
}()
client.log.Debug("receiving")
return state.Next, client
}
// ReceivingClient state
type ReceivingClient struct {
}
// Process messages
func (state *ReceivingClient) Process(client *Client) (State, *Client) {
element, err := client.Read()
if err != nil {
client.log.Warn("unable to read: ", err)
return nil, client
}
/*
for _, extension := range client.Server.Extensions {
extension.Process(element, client)
}*/
client.log.Debug(element)
return state, client
}

272
server/toclient/connect.go Normal file
View File

@ -0,0 +1,272 @@
package toclient
import (
"crypto/tls"
"encoding/base64"
"encoding/xml"
"fmt"
"strings"
"github.com/genofire/yaja/database"
"github.com/genofire/yaja/messages"
"github.com/genofire/yaja/server/extension"
"github.com/genofire/yaja/server/state"
"github.com/genofire/yaja/server/utils"
"golang.org/x/crypto/acme/autocert"
)
// ConnectionStartup return steps through TCP TLS state
func ConnectionStartup(db *database.State, tlsconfig *tls.Config, tlsmgmt *autocert.Manager, registerAllowed utils.DomainRegisterAllowed) state.State {
receiving := &ReceivingClient{}
sending := &SendingClient{Next: receiving}
authedstream := &AuthedStream{Next: sending}
authedstart := &AuthedStart{Next: authedstream}
tlsauth := &SASLAuth{
Next: authedstart,
database: db,
domainRegisterAllowed: registerAllowed,
}
tlsstream := &TLSStream{
Next: tlsauth,
domainRegisterAllowed: registerAllowed,
}
return state.ConnectionStartup(tlsstream, tlsconfig, tlsmgmt)
}
// TLSStream state
type TLSStream struct {
Next state.State
domainRegisterAllowed utils.DomainRegisterAllowed
}
// Process messages
func (state *TLSStream) Process(client *utils.Client) (state.State, *utils.Client) {
client.Log = client.Log.WithField("state", "tls stream")
client.Log.Debug("running")
defer client.Log.Debug("leave")
element, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
client.Log.Warn("is no stream")
return state, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream)
if state.domainRegisterAllowed(client.JID) {
fmt.Fprintf(client.Conn, `<stream:features>
<mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism>
</mechanisms>
<register xmlns='%s'/>
</stream:features>`,
messages.NSSASL, messages.NSFeaturesIQRegister)
} else {
fmt.Fprintf(client.Conn, `<stream:features>
<mechanisms xmlns='%s'>
<mechanism>PLAIN</mechanism>
</mechanisms>
</stream:features>`,
messages.NSSASL)
}
return state.Next, client
}
// SASLAuth state
type SASLAuth struct {
Next state.State
database *database.State
domainRegisterAllowed utils.DomainRegisterAllowed
}
// Process messages
func (state *SASLAuth) Process(client *utils.Client) (state.State, *utils.Client) {
client.Log = client.Log.WithField("state", "sasl auth")
client.Log.Debug("running")
defer client.Log.Debug("leave")
// read the full auth stanza
element, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
var auth messages.SASLAuth
if err = client.In.DecodeElement(&auth, element); err != nil {
client.Log.Info("start substate for registration")
return &RegisterFormRequest{
element: element,
domainRegisterAllowed: state.domainRegisterAllowed,
Next: &RegisterRequest{
domainRegisterAllowed: state.domainRegisterAllowed,
database: state.database,
Next: state.Next,
},
}, client
}
data, err := base64.StdEncoding.DecodeString(auth.Body)
if err != nil {
client.Log.Warn("body decode: ", err)
return nil, client
}
info := strings.Split(string(data), "\x00")
// should check that info[1] starts with client.JID
client.JID.Local = info[1]
client.Log = client.Log.WithField("jid", client.JID.Full())
success, err := state.database.Authenticate(client.JID, info[2])
if err != nil {
client.Log.Warn("auth: ", err)
return nil, client
}
if success {
client.Log.Info("success auth")
fmt.Fprintf(client.Conn, "<success xmlns='%s'/>", messages.NSSASL)
return state.Next, client
}
client.Log.Warn("failed auth")
fmt.Fprintf(client.Conn, "<failure xmlns='%s'><not-authorized/></failure>", messages.NSSASL)
return nil, client
}
// AuthedStart state
type AuthedStart struct {
Next state.State
}
// Process messages
func (state *AuthedStart) Process(client *utils.Client) (state.State, *utils.Client) {
client.Log = client.Log.WithField("state", "authed started")
client.Log.Debug("running")
defer client.Log.Debug("leave")
_, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
fmt.Fprintf(client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(client.Conn, `<stream:features>
<bind xmlns='%s'/>
</stream:features>`,
messages.NSBind)
return state.Next, client
}
// AuthedStream state
type AuthedStream struct {
Next state.State
}
// Process messages
func (state *AuthedStream) Process(client *utils.Client) (state.State, *utils.Client) {
client.Log = client.Log.WithField("state", "authed stream")
client.Log.Debug("running")
defer client.Log.Debug("leave")
// check that it's a bind request
// read bind request
element, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
var msg messages.IQ
if err = client.In.DecodeElement(&msg, element); err != nil {
client.Log.Warn("is no iq: ", err)
return nil, client
}
if msg.Type != messages.IQTypeSet {
client.Log.Warn("is no set iq")
return nil, client
}
if msg.Error != nil {
client.Log.Warn("iq with error: ", msg.Error.Code)
return nil, client
}
type query struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
Resource string `xml:"resource"`
}
q := &query{}
err = xml.Unmarshal(msg.Body, q)
if err != nil {
client.Log.Warn("is no iq bind: ", err)
return nil, client
}
if q.Resource == "" {
client.JID.Resource = makeResource()
} else {
client.JID.Resource = q.Resource
}
client.Log = client.Log.WithField("jid", client.JID.Full())
client.Out.Encode(&messages.IQ{
Type: messages.IQTypeResult,
ID: msg.ID,
Body: []byte(fmt.Sprintf(
`<bind xmlns='%s'>
<jid>%s</jid>
</bind>`,
messages.NSBind, client.JID.Full())),
})
return state.Next, client
}
// SendingClient state
type SendingClient struct {
Next state.State
}
// Process messages
func (state *SendingClient) Process(client *utils.Client) (state.State, *utils.Client) {
client.Log = client.Log.WithField("state", "normal")
client.Log.Debug("sending")
// sending
go func() {
select {
case msg := <-client.Messages:
err := client.Out.Encode(msg)
client.Log.Info(err)
case <-client.OnClose():
return
}
}()
client.Log.Debug("receiving")
return state.Next, client
}
// ReceivingClient state
type ReceivingClient struct {
Extensions []extension.Extension
}
// Process messages
func (state *ReceivingClient) Process(client *utils.Client) (state.State, *utils.Client) {
element, err := client.Read()
if err != nil {
client.Log.Warn("unable to read: ", err)
return nil, client
}
count := 0
for _, extension := range state.Extensions {
if extension.Process(element, client) {
count++
}
}
if count != 1 {
client.Log.WithField("extension", count).Debug(element)
}
return state, client
}

View File

@ -1,40 +1,44 @@
package server package toclient
import ( import (
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"github.com/genofire/yaja/database"
"github.com/genofire/yaja/messages" "github.com/genofire/yaja/messages"
"github.com/genofire/yaja/model" "github.com/genofire/yaja/model"
"github.com/genofire/yaja/server/state"
"github.com/genofire/yaja/server/utils"
) )
type RegisterFormRequest struct { type RegisterFormRequest struct {
Next State Next state.State
element *xml.StartElement domainRegisterAllowed utils.DomainRegisterAllowed
element *xml.StartElement
} }
// Process message // Process message
func (state *RegisterFormRequest) Process(client *Client) (State, *Client) { func (state *RegisterFormRequest) Process(client *utils.Client) (state.State, *utils.Client) {
client.log = client.log.WithField("state", "register form request") client.Log = client.Log.WithField("state", "register form request")
client.log.Debug("running") client.Log.Debug("running")
defer client.log.Debug("leave") defer client.Log.Debug("leave")
if !client.DomainRegisterAllowed() { if !state.domainRegisterAllowed(client.JID) {
client.log.Error("unpossible to reach this state, register on this domain is not allowed") client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
return nil, client return nil, client
} }
var msg messages.IQ var msg messages.IQ
if err := client.in.DecodeElement(&msg, state.element); err != nil { if err := client.In.DecodeElement(&msg, state.element); err != nil {
client.log.Warn("is no iq: ", err) client.Log.Warn("is no iq: ", err)
return state, client return state, client
} }
if msg.Type != messages.IQTypeGet { if msg.Type != messages.IQTypeGet {
client.log.Warn("is no get iq") client.Log.Warn("is no get iq")
return state, client return state, client
} }
if msg.Error != nil { if msg.Error != nil {
client.log.Warn("iq with error: ", msg.Error.Code) client.Log.Warn("iq with error: ", msg.Error.Code)
return state, client return state, client
} }
type query struct { type query struct {
@ -44,10 +48,10 @@ func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
err := xml.Unmarshal(msg.Body, q) err := xml.Unmarshal(msg.Body, q)
if q.XMLName.Space != messages.NSIQRegister || err != nil { if q.XMLName.Space != messages.NSIQRegister || err != nil {
client.log.Warn("is no iq register: ", err) client.Log.Warn("is no iq register: ", err)
return nil, client return nil, client
} }
client.out.Encode(&messages.IQ{ client.Out.Encode(&messages.IQ{
Type: messages.IQTypeResult, Type: messages.IQTypeResult,
ID: msg.ID, ID: msg.ID,
Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions> Body: []byte(fmt.Sprintf(`<query xmlns='%s'><instructions>
@ -61,36 +65,38 @@ func (state *RegisterFormRequest) Process(client *Client) (State, *Client) {
} }
type RegisterRequest struct { type RegisterRequest struct {
Next State Next state.State
database *database.State
domainRegisterAllowed utils.DomainRegisterAllowed
} }
// Process message // Process message
func (state *RegisterRequest) Process(client *Client) (State, *Client) { func (state *RegisterRequest) Process(client *utils.Client) (state.State, *utils.Client) {
client.log = client.log.WithField("state", "register request") client.Log = client.Log.WithField("state", "register request")
client.log.Debug("running") client.Log.Debug("running")
defer client.log.Debug("leave") defer client.Log.Debug("leave")
if !client.DomainRegisterAllowed() { if !state.domainRegisterAllowed(client.JID) {
client.log.Error("unpossible to reach this state, register on this domain is not allowed") client.Log.Error("unpossible to reach this state, register on this domain is not allowed")
return nil, client return nil, client
} }
element, err := client.Read() element, err := client.Read()
if err != nil { if err != nil {
client.log.Warn("unable to read: ", err) client.Log.Warn("unable to read: ", err)
return nil, client return nil, client
} }
var msg messages.IQ var msg messages.IQ
if err = client.in.DecodeElement(&msg, element); err != nil { if err = client.In.DecodeElement(&msg, element); err != nil {
client.log.Warn("is no iq: ", err) client.Log.Warn("is no iq: ", err)
return state, client return state, client
} }
if msg.Type != messages.IQTypeGet { if msg.Type != messages.IQTypeGet {
client.log.Warn("is no get iq") client.Log.Warn("is no get iq")
return state, client return state, client
} }
if msg.Error != nil { if msg.Error != nil {
client.log.Warn("iq with error: ", msg.Error.Code) client.Log.Warn("iq with error: ", msg.Error.Code)
return state, client return state, client
} }
type query struct { type query struct {
@ -101,16 +107,16 @@ func (state *RegisterRequest) Process(client *Client) (State, *Client) {
q := &query{} q := &query{}
err = xml.Unmarshal(msg.Body, q) err = xml.Unmarshal(msg.Body, q)
if err != nil { if err != nil {
client.log.Warn("is no iq register: ", err) client.Log.Warn("is no iq register: ", err)
return nil, client return nil, client
} }
client.jid.Local = q.Username client.JID.Local = q.Username
client.log = client.log.WithField("jid", client.jid.Full()) client.Log = client.Log.WithField("jid", client.JID.Full())
account := model.NewAccount(client.jid, q.Password) account := model.NewAccount(client.JID, q.Password)
err = client.Server.Database.AddAccount(account) err = state.database.AddAccount(account)
if err != nil { if err != nil {
client.out.Encode(&messages.IQ{ client.Out.Encode(&messages.IQ{
Type: messages.IQTypeResult, Type: messages.IQTypeResult,
ID: msg.ID, ID: msg.ID,
Body: []byte(fmt.Sprintf(`<query xmlns='%s'> Body: []byte(fmt.Sprintf(`<query xmlns='%s'>
@ -126,15 +132,14 @@ func (state *RegisterRequest) Process(client *Client) (State, *Client) {
}, },
}, },
}) })
client.log.Warn("database error: ", err) client.Log.Warn("database error: ", err)
return state, client return state, client
} }
client.account = account client.Out.Encode(&messages.IQ{
client.out.Encode(&messages.IQ{
Type: messages.IQTypeResult, Type: messages.IQTypeResult,
ID: msg.ID, ID: msg.ID,
}) })
client.log.Infof("registered client %s", client.jid.Bare()) client.Log.Infof("registered client %s", client.JID.Bare())
return state.Next, client return state.Next, client
} }

14
server/toclient/utils.go Normal file
View File

@ -0,0 +1,14 @@
package toclient
import (
"crypto/rand"
"fmt"
)
func makeResource() string {
var buf [16]byte
if _, err := rand.Reader.Read(buf[:]); err != nil {
panic("Failed to read random bytes: " + err.Error())
}
return fmt.Sprintf("%x", buf)
}

65
server/utils/client.go Normal file
View File

@ -0,0 +1,65 @@
package utils
import (
"encoding/xml"
"net"
"github.com/genofire/yaja/model"
log "github.com/sirupsen/logrus"
)
type Client struct {
Log *log.Entry
Conn net.Conn
Out *xml.Encoder
In *xml.Decoder
JID *model.JID
account *model.Account
Messages chan interface{}
close chan interface{}
}
func NewClient(conn net.Conn, level log.Level) *Client {
logger := log.New()
logger.SetLevel(level)
client := &Client{
Conn: conn,
Log: log.NewEntry(logger),
In: xml.NewDecoder(conn),
Out: xml.NewEncoder(conn),
close: make(chan interface{}),
}
return client
}
func (client *Client) SetConnecting(conn net.Conn) {
client.Conn = conn
client.In = xml.NewDecoder(conn)
client.Out = xml.NewEncoder(conn)
}
func (client *Client) Read() (*xml.StartElement, error) {
for {
nextToken, err := client.In.Token()
if err != nil {
return nil, err
}
switch nextToken.(type) {
case xml.StartElement:
element := nextToken.(xml.StartElement)
return &element, nil
}
}
}
func (client *Client) OnClose() <-chan interface{} {
return client.close
}
func (client *Client) Close() {
client.close <- true
client.Conn.Close()
}

View File

@ -1,29 +1,25 @@
package server package utils
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/genofire/yaja/model"
) )
// Cookie is used to give a unique identifier to each request. // Cookie is used to give a unique identifier to each request.
type Cookie uint64 type Cookie uint64
func createCookie() Cookie { func CreateCookie() Cookie {
var buf [8]byte var buf [8]byte
if _, err := rand.Reader.Read(buf[:]); err != nil { if _, err := rand.Reader.Read(buf[:]); err != nil {
panic("Failed to read random bytes: " + err.Error()) panic("Failed to read random bytes: " + err.Error())
} }
return Cookie(binary.LittleEndian.Uint64(buf[:])) return Cookie(binary.LittleEndian.Uint64(buf[:]))
} }
func createCookieString() string { func CreateCookieString() string {
return fmt.Sprintf("%x", createCookie()) return fmt.Sprintf("%x", CreateCookie())
} }
func makeResource() string { type DomainRegisterAllowed func(*model.JID) bool
var buf [16]byte
if _, err := rand.Reader.Read(buf[:]); err != nil {
panic("Failed to read random bytes: " + err.Error())
}
return fmt.Sprintf("%x", buf)
}