sum7
/
yaja
Archived
1
0
Fork 0

improve client

This commit is contained in:
Martin/Geno 2018-02-11 19:35:32 +01:00
parent f4bc539cd7
commit 654d0306cf
No known key found for this signature in database
GPG Key ID: F0D39A37E925E941
14 changed files with 232 additions and 146 deletions

View File

@ -19,25 +19,27 @@ func (client *Client) auth(password string) error {
} }
//auth: //auth:
mechanism := "" mechanism := ""
challenge := &messages.SASLChallenge{}
response := &messages.SASLResponse{}
for _, m := range f.Mechanisms.Mechanism { for _, m := range f.Mechanisms.Mechanism {
if m == "PLAIN" { if m == "SCRAM-SHA-1" {
/*
mechanism = m mechanism = m
// Plain authentication: send base64-encoded \x00 user \x00 password. TODO
raw := "\x00" + client.JID.Local + "\x00" + password
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw))
fmt.Fprintf(client.conn, "<auth xmlns='%s' mechanism='PLAIN'>%s</auth>\n", messages.NSSASL, enc)
break break
*/
} }
if m == "DIGEST-MD5" { if m == "DIGEST-MD5" {
mechanism = m mechanism = m
// Digest-MD5 authentication // Digest-MD5 authentication
fmt.Fprintf(client.conn, "<auth xmlns='%s' mechanism='DIGEST-MD5'/>\n", messages.NSSASL) client.Out.Encode(&messages.SASLAuth{
var ch string Mechanism: m,
if err := client.ReadElement(&ch); err != nil { })
if err := client.ReadElement(challenge); err != nil {
return err return err
} }
b, err := base64.StdEncoding.DecodeString(string(ch)) b, err := base64.StdEncoding.DecodeString(challenge.Body)
if err != nil { if err != nil {
return err return err
} }
@ -62,29 +64,37 @@ func (client *Client) auth(password string) error {
message := "username=\"" + client.JID.Local + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr + message := "username=\"" + client.JID.Local + "\", realm=\"" + realm + "\", nonce=\"" + nonce + "\", cnonce=\"" + cnonceStr +
"\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset "\", nc=" + nonceCount + ", qop=" + qop + ", digest-uri=\"" + digestURI + "\", response=" + digest + ", charset=" + charset
fmt.Fprintf(client.conn, "<response xmlns='%s'>%s</response>\n", messages.NSSASL, base64.StdEncoding.EncodeToString([]byte(message))) response.Body = base64.StdEncoding.EncodeToString([]byte(message))
client.Out.Encode(response)
break
}
if m == "PLAIN" {
mechanism = m
// Plain authentication: send base64-encoded \x00 user \x00 password.
raw := "\x00" + client.JID.Local + "\x00" + password
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw))
client.Out.Encode(&messages.SASLAuth{
Mechanism: "PLAIN",
Body: string(enc),
})
err = client.ReadElement(&ch)
if err != nil {
return err
}
_, err = base64.StdEncoding.DecodeString(ch)
if err != nil {
return err
}
fmt.Fprintf(client.conn, "<response xmlns='%s'/>\n", messages.NSSASL)
break break
} }
} }
if mechanism == "" { if mechanism == "" {
return fmt.Errorf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism) return fmt.Errorf("PLAIN authentication is not an option: %s", f.Mechanisms.Mechanism)
} }
element, err := client.Read() element, err := client.Read()
if err != nil { if err != nil {
return err return err
} }
if element.Name.Local != "success" { fail := messages.SASLFailure{}
return errors.New("auth failed: " + element.Name.Local) if err := client.In.DecodeElement(&fail, element); err == nil {
return errors.New(messages.XMLChildrenString(fail) + " : " + fail.Body)
}
if err := client.In.DecodeElement(&messages.SASLSuccess{}, element); err != nil {
return errors.New("auth failed - with unexpected answer")
} }
return nil return nil
} }

View File

@ -11,7 +11,6 @@ import (
"dev.sum7.eu/genofire/yaja/messages" "dev.sum7.eu/genofire/yaja/messages"
"dev.sum7.eu/genofire/yaja/model" "dev.sum7.eu/genofire/yaja/model"
"dev.sum7.eu/genofire/yaja/server/utils"
) )
// Client holds XMPP connection opitons // Client holds XMPP connection opitons
@ -24,7 +23,7 @@ type Client struct {
} }
func NewClient(jid *model.JID, password string) (*Client, error) { func NewClient(jid *model.JID, password string) (*Client, error) {
return NewClientProtocolDuration(jid, password, "tcp", -1) return NewClientProtocolDuration(jid, password, "tcp", 0)
} }
func NewClientProtocolDuration(jid *model.JID, password string, proto string, timeout time.Duration) (*Client, error) { func NewClientProtocolDuration(jid *model.JID, password string, proto string, timeout time.Duration) (*Client, error) {
@ -43,12 +42,7 @@ func NewClientProtocolDuration(jid *model.JID, password string, proto string, ti
if len(a) == 1 { if len(a) == 1 {
addr += ":5222" addr += ":5222"
} }
var conn net.Conn conn, err := net.DialTimeout(proto, addr, timeout)
if timeout >= 0 {
conn, err = net.DialTimeout(proto, addr, timeout)
} else {
conn, err = net.Dial(proto, addr)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,12 +118,11 @@ func (client *Client) connect(password string) error {
if err := tlsconn.VerifyHostname(client.JID.Domain); err != nil { if err := tlsconn.VerifyHostname(client.JID.Domain); err != nil {
return err return err
} }
err := client.auth(password) if err := client.auth(password); err != nil {
if err != nil {
return err return err
} }
_, err = client.startStream()
if err != nil { if _, err := client.startStream(); err != nil {
return err return err
} }
// bind to resource // bind to resource
@ -137,23 +130,22 @@ func (client *Client) connect(password string) error {
if client.JID.Resource != "" { if client.JID.Resource != "" {
bind.Resource = client.JID.Resource bind.Resource = client.JID.Resource
} }
client.Out.Encode(&messages.IQClient{ if err := client.Out.Encode(&messages.IQClient{
Type: messages.IQTypeSet, Type: messages.IQTypeSet,
To: model.NewJID(client.JID.Domain),
From: client.JID, From: client.JID,
ID: utils.CreateCookieString(), To: model.NewJID(client.JID.Domain),
Bind: bind, Bind: bind,
}) }); err != nil {
return err
}
var iq messages.IQClient var iq messages.IQClient
if err := client.ReadElement(&iq); err != nil { if err := client.ReadElement(&iq); err != nil {
return err return err
} }
if iq.Error != nil { if iq.Error != nil {
if iq.Error.Type == messages.ErrorClientTypeCancel && iq.Error.ServiceUnavailable != nil { if iq.Error.ServiceUnavailable == nil {
//TODO binding service unavailable return errors.New(fmt.Sprintf("recv error on iq>bind: %s[%s]: %s -> %s -> %s", iq.Error.Code, iq.Error.Type, iq.Error.Text, messages.XMLChildrenString(iq.Error.StanzaErrorGroup), messages.XMLChildrenString(iq.Error.Other)))
} else {
return errors.New(fmt.Sprintf("recv error on iq>bind: %s[%s]: %s -> %v", iq.Error.Code, iq.Error.Type, iq.Error.Text, iq.Error.Other))
} }
} else if iq.Bind == nil { } else if iq.Bind == nil {
return errors.New("<iq> result missing <bind>") return errors.New("<iq> result missing <bind>")
@ -162,10 +154,8 @@ func (client *Client) connect(password string) error {
client.JID.Domain = iq.Bind.JID.Domain client.JID.Domain = iq.Bind.JID.Domain
client.JID.Resource = iq.Bind.JID.Resource client.JID.Resource = iq.Bind.JID.Resource
} else { } else {
return errors.New(fmt.Sprintf("%v", iq.Other)) return errors.New(messages.XMLChildrenString(iq.Other))
} }
// set status // set status
err = client.Send(&messages.PresenceClient{Show: messages.ShowTypeXA, Status: "online"}) return client.Send(&messages.PresenceClient{Show: messages.PresenceShowXA, Status: "online"})
return err
} }

View File

@ -5,7 +5,6 @@ import (
"log" "log"
"dev.sum7.eu/genofire/yaja/messages" "dev.sum7.eu/genofire/yaja/messages"
"dev.sum7.eu/genofire/yaja/server/utils"
) )
func (client *Client) Read() (*xml.StartElement, error) { func (client *Client) Read() (*xml.StartElement, error) {
@ -51,25 +50,16 @@ func (client *Client) Send(p interface{}) error {
msg, ok := p.(*messages.MessageClient) msg, ok := p.(*messages.MessageClient)
if ok { if ok {
msg.From = client.JID msg.From = client.JID
if msg.ID == "" {
msg.ID = utils.CreateCookieString()
}
return client.Out.Encode(msg) return client.Out.Encode(msg)
} }
iq, ok := p.(*messages.IQClient) iq, ok := p.(*messages.IQClient)
if ok { if ok {
iq.From = client.JID iq.From = client.JID
if iq.ID == "" {
iq.ID = utils.CreateCookieString()
}
return client.Out.Encode(iq) return client.Out.Encode(iq)
} }
pc, ok := p.(*messages.PresenceClient) pc, ok := p.(*messages.PresenceClient)
if ok { if ok {
pc.From = client.JID pc.From = client.JID
if pc.ID == "" {
pc.ID = utils.CreateCookieString()
}
return client.Out.Encode(pc) return client.Out.Encode(pc)
} }
return client.Out.Encode(p) return client.Out.Encode(p)

View File

@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"dev.sum7.eu/genofire/yaja/messages" "dev.sum7.eu/genofire/yaja/messages"
"dev.sum7.eu/genofire/yaja/server/utils"
) )
func (t *Tester) StartBot(status *Status) { func (t *Tester) StartBot(status *Status) {
@ -24,7 +23,7 @@ func (t *Tester) StartBot(status *Status) {
errMSG := &messages.StreamError{} errMSG := &messages.StreamError{}
err = status.client.In.DecodeElement(errMSG, element) err = status.client.In.DecodeElement(errMSG, element)
if err == nil { if err == nil {
logCTX.Errorf("recv stream error: %s: %v", errMSG.Text, errMSG.Any) logCTX.Errorf("recv stream error: %s: %s", errMSG.Text, messages.XMLChildrenString(errMSG.Any))
status.client.Close() status.client.Close()
status.Login = false status.Login = false
return return
@ -40,7 +39,7 @@ func (t *Tester) StartBot(status *Status) {
iq.From = status.client.JID iq.From = status.client.JID
status.client.Out.Encode(iq) status.client.Out.Encode(iq)
} else { } else {
logCTX.Warnf("unsupport iq recv: %v", iq) logCTX.Warnf("recv iq unsupport: %s", messages.XMLChildrenString(iq))
} }
continue continue
} }
@ -51,7 +50,7 @@ func (t *Tester) StartBot(status *Status) {
sender := pres.From sender := pres.From
logPres := logCTX.WithField("from", sender.Full()) logPres := logCTX.WithField("from", sender.Full())
if pres.Type == messages.PresenceTypeSubscribe { if pres.Type == messages.PresenceTypeSubscribe {
logPres.Debugf("recv subscribe") logPres.Debugf("recv presence subscribe")
pres.Type = messages.PresenceTypeSubscribed pres.Type = messages.PresenceTypeSubscribed
pres.To = sender pres.To = sender
pres.From = nil pres.From = nil
@ -59,17 +58,19 @@ func (t *Tester) StartBot(status *Status) {
logPres.Debugf("accept new subscribe") logPres.Debugf("accept new subscribe")
pres.Type = messages.PresenceTypeSubscribe pres.Type = messages.PresenceTypeSubscribe
pres.ID = utils.CreateCookieString() pres.ID = ""
status.client.Out.Encode(pres) status.client.Out.Encode(pres)
logPres.Info("request also subscribe") logPres.Info("request also subscribe")
} else if pres.Type == messages.PresenceTypeSubscribed { } else if pres.Type == messages.PresenceTypeSubscribed {
logPres.Info("recv accepted subscribe") logPres.Info("recv presence accepted subscribe")
} else if pres.Type == messages.PresenceTypeUnsubscribe { } else if pres.Type == messages.PresenceTypeUnsubscribe {
logPres.Info("recv remove subscribe") logPres.Info("recv presence remove subscribe")
} else if pres.Type == messages.PresenceTypeUnsubscribed { } else if pres.Type == messages.PresenceTypeUnsubscribed {
logPres.Info("recv removed subscribe") logPres.Info("recv presence removed subscribe")
} else if pres.Type == messages.PresenceTypeUnavailable {
logPres.Debug("recv presence unavailable")
} else { } else {
logCTX.Warnf("unsupported presence recv: %v", pres) logCTX.Warnf("recv presence unsupported: %s -> %s", pres.Type, messages.XMLChildrenString(pres))
} }
continue continue
} }
@ -82,7 +83,13 @@ func (t *Tester) StartBot(status *Status) {
} }
logCTX = logCTX.WithField("from", msg.From.Full()).WithField("msg-recv", msg.Body) logCTX = logCTX.WithField("from", msg.From.Full()).WithField("msg-recv", msg.Body)
if msg.Error != nil { if msg.Error != nil {
logCTX.Debugf("recv msg with error %s[%s]: %s -> %v -> %v", msg.Error.Code, msg.Error.Type, msg.Error.Text, msg.Error.StanzaErrorGroup, msg.Error.Other) if msg.Error.Type == "auth" {
logCTX.Warnf("recv msg with error not auth")
status.Login = false
status.client.Close()
return
}
logCTX.Debugf("recv msg with error %s[%s]: %s -> %s -> %s", msg.Error.Code, msg.Error.Type, msg.Error.Text, messages.XMLChildrenString(msg.Error.StanzaErrorGroup), messages.XMLChildrenString(msg.Error.Other))
continue continue
} }

View File

@ -9,7 +9,6 @@ import (
"dev.sum7.eu/genofire/yaja/client" "dev.sum7.eu/genofire/yaja/client"
"dev.sum7.eu/genofire/yaja/messages" "dev.sum7.eu/genofire/yaja/messages"
"dev.sum7.eu/genofire/yaja/model" "dev.sum7.eu/genofire/yaja/model"
"dev.sum7.eu/genofire/yaja/server/utils"
) )
type Tester struct { type Tester struct {
@ -149,12 +148,12 @@ func (t *Tester) CheckStatus() {
logCTXTo.Debug("could not recv msg") logCTXTo.Debug("could not recv msg")
} }
} }
msg = utils.CreateCookieString() msg = messages.CreateCookieString()
logCTXTo = logCTXTo.WithField("msg-send", msg) logCTXTo = logCTXTo.WithField("msg-send", msg)
own.client.Send(&messages.MessageClient{ own.client.Send(&messages.MessageClient{
Body: "checkmsg " + msg, Body: "checkmsg " + msg,
Type: messages.ChatTypeChat, Type: messages.MessageTypeChat,
To: s.JID, To: s.JID,
}) })
own.MessageForConnection[s.JID.Bare()] = msg own.MessageForConnection[s.JID.Bare()] = msg

35
messages/message.go Normal file
View File

@ -0,0 +1,35 @@
package messages
import (
"encoding/xml"
"dev.sum7.eu/genofire/yaja/model"
)
type MessageType string
const (
MessageTypeChat MessageType = "chat"
MessageTypeGroupchat MessageType = "groupchat"
MessageTypeError MessageType = "error"
MessageTypeHeadline MessageType = "headline"
MessageTypeNormal MessageType = "normal"
)
// MessageClient element
type MessageClient struct {
XMLName xml.Name `xml:"jabber:client message"`
From *model.JID `xml:"from,attr,omitempty"`
ID string `xml:"id,attr,omitempty"`
To *model.JID `xml:"to,attr,omitempty"`
Type MessageType `xml:"type,attr,omitempty"`
Lang string `xml:"lang,attr,omitempty"`
Subject string `xml:"subject"`
Body string `xml:"body"`
Thread string `xml:"thread"`
// Any hasn't matched element
Other []XMLElement `xml:",any"`
Delay *Delay `xml:"delay"`
Error *ErrorClient
}

View File

@ -6,15 +6,6 @@ import (
"dev.sum7.eu/genofire/yaja/model" "dev.sum7.eu/genofire/yaja/model"
) )
type XMLElement struct {
XMLName xml.Name
InnerXML string `xml:",innerxml"`
}
type Delay struct {
Stamp string `xml:"stamp,attr"`
}
type PresenceType string type PresenceType string
const ( const (
@ -27,13 +18,13 @@ const (
PresenceTypeError PresenceType = "error" PresenceTypeError PresenceType = "error"
) )
type ShowType string type PresenceShow string
const ( const (
ShowTypeAway ShowType = "away" PresenceShowAway PresenceShow = "away"
ShowTypeChat ShowType = "chat" PresenceShowChat PresenceShow = "chat"
ShowTypeDND ShowType = "dnd" PresenceShowDND PresenceShow = "dnd"
ShowTypeXA ShowType = "xa" PresenceShowXA PresenceShow = "xa"
) )
// PresenceClient element // PresenceClient element
@ -45,7 +36,7 @@ type PresenceClient struct {
Type PresenceType `xml:"type,attr,omitempty"` Type PresenceType `xml:"type,attr,omitempty"`
Lang string `xml:"lang,attr,omitempty"` Lang string `xml:"lang,attr,omitempty"`
Show ShowType `xml:"show,omitempty"` // away, chat, dnd, xa Show PresenceShow `xml:"show,omitempty"` // away, chat, dnd, xa
Status string `xml:"status,omitempty"` // sb []clientText Status string `xml:"status,omitempty"` // sb []clientText
Priority string `xml:"priority,omitempty"` Priority string `xml:"priority,omitempty"`
// Caps *ClientCaps `xml:"c"` // Caps *ClientCaps `xml:"c"`
@ -53,31 +44,3 @@ type PresenceClient struct {
Error *ErrorClient Error *ErrorClient
} }
type ChatType string
const (
ChatTypeChat ChatType = "chat"
ChatTypeGroupchat ChatType = "groupchat"
ChatTypeError ChatType = "error"
ChatTypeHeadline ChatType = "headline"
ChatTypeNormal ChatType = "normal"
)
// MessageClient element
type MessageClient struct {
XMLName xml.Name `xml:"jabber:client message"`
From *model.JID `xml:"from,attr,omitempty"`
ID string `xml:"id,attr,omitempty"`
To *model.JID `xml:"to,attr,omitempty"`
Type ChatType `xml:"type,attr,omitempty"`
Lang string `xml:"lang,attr,omitempty"`
Subject string `xml:"subject"`
Body string `xml:"body"`
Thread string `xml:"thread"`
// Any hasn't matched element
Other []XMLElement `xml:",any"`
Delay *Delay `xml:"delay"`
Error *ErrorClient
}

View File

@ -1,6 +1,14 @@
package messages package messages
import "encoding/xml" import (
"encoding/xml"
)
// RFC 3920 C.4 SASL name space
type SASLMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
Mechanism []string `xml:"mechanism"`
}
// SASLAuth element // SASLAuth element
type SASLAuth struct { type SASLAuth struct {
@ -9,8 +17,44 @@ type SASLAuth struct {
Body string `xml:",chardata"` Body string `xml:",chardata"`
} }
// RFC 3920 C.4 SASL name space // SASLChallenge element
type SASLMechanisms struct { type SASLChallenge struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl challenge"`
Mechanism []string `xml:"mechanism"` Body string `xml:",chardata"`
}
// SASLResponse element
type SASLResponse struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl response"`
Body string `xml:",chardata"`
}
// SASLSuccess element
type SASLSuccess struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl success"`
Body string `xml:",chardata"`
}
// SASLAbout element
type SASLAbout struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl abort"`
}
// SASLFailure element
type SASLFailure struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl failure"`
Aborted *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl aborted"`
AccountDisabled *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl account-disabled"`
CredentialsExpired *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl credentials-expired"`
EncryptionRequired *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl encryption-required"`
IncorrectEncoding *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl incorrect-encoding"`
InvalidAuthzid *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl invalid-authzid"`
InvalidMechanism *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl invalid-mechanism"`
MalformedRequest *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl malformed-request"`
MechanismTooWeak *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanism-too-weak"`
NotAuthorized *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl not-authorized"`
TemporaryAuthFailure *xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl temporary-auth-failure"`
Body string `xml:",chardata"`
} }

66
messages/utils.go Normal file
View File

@ -0,0 +1,66 @@
package messages
import (
"crypto/rand"
"encoding/binary"
"encoding/xml"
"fmt"
"reflect"
)
type Delay struct {
Stamp string `xml:"stamp,attr"`
}
type XMLElement struct {
XMLName xml.Name
InnerXML string `xml:",innerxml"`
}
func XMLChildrenString(o interface{}) (result string) {
first := true
val := reflect.ValueOf(o)
if val.Kind() == reflect.Interface && !val.IsNil() {
elm := val.Elem()
if elm.Kind() == reflect.Ptr && !elm.IsNil() && elm.Elem().Kind() == reflect.Ptr {
val = elm
}
}
if val.Kind() != reflect.Struct {
return
}
// struct
for i := 0; i < val.NumField(); i++ {
valueField := val.Field(i)
if valueField.Kind() == reflect.Interface && !valueField.IsNil() {
elm := valueField.Elem()
if elm.Kind() == reflect.Ptr && !elm.IsNil() && elm.Elem().Kind() == reflect.Ptr {
valueField = elm
}
}
if xmlElement, ok := valueField.Interface().(*xml.Name); ok && xmlElement != nil {
if first {
first = false
} else {
result += ", "
}
result += xmlElement.Local
}
}
return
}
// Cookie is used to give a unique identifier to each request.
type Cookie uint64
func CreateCookie() Cookie {
var buf [8]byte
if _, err := rand.Reader.Read(buf[:]); err != nil {
panic("Failed to read random bytes: " + err.Error())
}
return Cookie(binary.LittleEndian.Uint64(buf[:]))
}
func CreateCookieString() string {
return fmt.Sprintf("%x", CreateCookie())
}

View File

@ -51,7 +51,7 @@ func (iex IQExtensions) Process(element *xml.StartElement, client *utils.Client)
// not extensions found // not extensions found
if count != 1 { if count != 1 {
log.Debugf("%s - %s: %v", msg.XMLName.Space, msg.Type, msg.Other) log.Debugf("%s - %s: %s", msg.XMLName.Space, msg.Type, messages.XMLChildrenString(msg.Other))
} }
return true return true

View File

@ -44,7 +44,7 @@ func (state *Start) Process() State {
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?> fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`, <stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream) messages.CreateCookie(), messages.NSClient, messages.NSStream)
fmt.Fprintf(state.Client.Conn, `<stream:features> fmt.Fprintf(state.Client.Conn, `<stream:features>
<starttls xmlns='%s'> <starttls xmlns='%s'>

View File

@ -73,7 +73,7 @@ func (state *TLSStream) Process() state.State {
<mechanism>PLAIN</mechanism> <mechanism>PLAIN</mechanism>
</mechanisms> </mechanisms>
</stream:features>`, </stream:features>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream, messages.CreateCookie(), messages.NSClient, messages.NSStream,
messages.NSSASL, messages.NSFeaturesIQRegister) messages.NSSASL, messages.NSFeaturesIQRegister)
} else { } else {
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?> fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
@ -83,7 +83,7 @@ func (state *TLSStream) Process() state.State {
<mechanism>PLAIN</mechanism> <mechanism>PLAIN</mechanism>
</mechanisms> </mechanisms>
</stream:features>`, </stream:features>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream, messages.CreateCookie(), messages.NSClient, messages.NSStream,
messages.NSSASL) messages.NSSASL)
} }
@ -174,7 +174,7 @@ func (state *AuthedStart) Process() state.State {
<required/> <required/>
</bind> </bind>
</stream:features>`, </stream:features>`,
messages.NSStream, state.Client.JID.Domain, utils.CreateCookie(), messages.NSClient, messages.NSStream, state.Client.JID.Domain, messages.CreateCookie(), messages.NSClient,
messages.NSBind) messages.NSBind)
return state.Next return state.Next

View File

@ -97,7 +97,7 @@ func (state *TLSStream) Process() state.State {
</mechanisms> </mechanisms>
<bidi xmlns='urn:xmpp:features:bidi'/> <bidi xmlns='urn:xmpp:features:bidi'/>
</stream:features>`, </stream:features>`,
utils.CreateCookie(), messages.NSClient, messages.NSStream, messages.CreateCookie(), messages.NSClient, messages.NSStream,
messages.NSSASL) messages.NSSASL)
return state.Next return state.Next

View File

@ -1,25 +1,7 @@
package utils package utils
import ( import (
"crypto/rand"
"encoding/binary"
"fmt"
"dev.sum7.eu/genofire/yaja/model" "dev.sum7.eu/genofire/yaja/model"
) )
// Cookie is used to give a unique identifier to each request.
type Cookie uint64
func CreateCookie() Cookie {
var buf [8]byte
if _, err := rand.Reader.Read(buf[:]); err != nil {
panic("Failed to read random bytes: " + err.Error())
}
return Cookie(binary.LittleEndian.Uint64(buf[:]))
}
func CreateCookieString() string {
return fmt.Sprintf("%x", CreateCookie())
}
type DomainRegisterAllowed func(*model.JID) bool type DomainRegisterAllowed func(*model.JID) bool