sum7
/
yaja
Archived
1
0
Fork 0

[TEST] add for client

This commit is contained in:
Martin/Geno 2018-03-03 09:28:24 +01:00
parent 4732e0d292
commit 46e5efd98b
No known key found for this signature in database
GPG Key ID: 9D7D3C6BFF600C6A
5 changed files with 175 additions and 25 deletions

View File

@ -1,7 +1,6 @@
package client package client
import ( import (
"crypto/tls"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"net" "net"
@ -33,7 +32,6 @@ type Client struct {
func NewClient(jid *xmppbase.JID, password string) (*Client, error) { func NewClient(jid *xmppbase.JID, password string) (*Client, error) {
client := &Client{ client := &Client{
Protocol: "tcp",
JID: jid, JID: jid,
Logging: log.New().WithField("jid", jid.String()), Logging: log.New().WithField("jid", jid.String()),
} }
@ -42,7 +40,7 @@ func NewClient(jid *xmppbase.JID, password string) (*Client, error) {
} }
func (client *Client) Connect(password string) error { func (client *Client) Connect(password string) error {
_, srvEntries, err := net.LookupSRV("xmpp-client", "tcp", client.JID.Domain) _, srvEntries, err := net.LookupSRV("xmpp-client", "tcp", client.JID.Domain)
addr := client.JID.Domain + ":5222" addr := client.JID.Domain
if err == nil && len(srvEntries) > 0 { if err == nil && len(srvEntries) > 0 {
bestSrv := srvEntries[0] bestSrv := srvEntries[0]
for _, srv := range srvEntries { for _, srv := range srvEntries {
@ -52,18 +50,19 @@ func (client *Client) Connect(password string) error {
} }
} }
} }
a := strings.SplitN(addr, ":", 2) if strings.LastIndex(addr, ":") <= strings.LastIndex(addr, "]") {
if len(a) == 1 {
addr += ":5222" addr += ":5222"
} }
if client.Protocol == "" { if client.Protocol == "" {
client.Protocol = "tcp" client.Protocol = "tcp"
} }
client.Logging.Debug("try tcp connect")
conn, err := net.DialTimeout(client.Protocol, addr, client.Timeout) conn, err := net.DialTimeout(client.Protocol, addr, client.Timeout)
client.setConnection(conn)
if err != nil { if err != nil {
return err return err
} }
client.Logging.Debug("tcp connected")
client.setConnection(conn)
if err = client.connect(password); err != nil { if err = client.connect(password); err != nil {
client.Close() client.Close()
@ -74,7 +73,7 @@ func (client *Client) Connect(password string) error {
// Close closes the XMPP connection // Close closes the XMPP connection
func (c *Client) Close() error { func (c *Client) Close() error {
if c.conn != (*tls.Conn)(nil) { if c.conn != nil {
return c.conn.Close() return c.conn.Close()
} }
return nil return nil

View File

@ -1 +1,40 @@
package client package client
/*
func TestClient(t *testing.T) {
assert := assert.New(t)
jid := xmppbase.NewJID("test@example.net")
logger := log.New()
logger.SetLevel(log.DebugLevel)
client := &Client{
JID: jid,
Timeout: time.Millisecond * 500,
Logging: logger.WithField("jid", jid.String()),
}
// close nil connected
assert.NoError(client.Close())
err := client.Connect("password")
assert.Error(err)
assert.Contains(err.Error(), "timeout")
jid.Domain = "chat.sum7.eu"
// invalid auth
client, err = NewClient(jid, "password")
assert.NotNil(client)
assert.Error(err)
assert.Contains(err.Error(), "not-authorized : ")
// already closed
assert.Error(client.Close())
client.Logging = logger.WithField("jid", jid.String())
err = client.Connect("FqzMp6bevlHlt8d")
assert.NoError(err)
}
*/

View File

@ -6,9 +6,9 @@ import (
"dev.sum7.eu/genofire/yaja/xmpp" "dev.sum7.eu/genofire/yaja/xmpp"
) )
func (client *Client) Read() (*xml.StartElement, error) { func read(decoder *xml.Decoder) (*xml.StartElement, error) {
for { for {
nextToken, err := client.in.Token() nextToken, err := decoder.Token()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -19,6 +19,9 @@ func (client *Client) Read() (*xml.StartElement, error) {
} }
} }
} }
func (client *Client) Read() (*xml.StartElement, error) {
return read(client.in)
}
func (client *Client) Decode(p interface{}, element *xml.StartElement) error { func (client *Client) Decode(p interface{}, element *xml.StartElement) error {
err := client.in.DecodeElement(p, element) err := client.in.DecodeElement(p, element)
if err != nil { if err != nil {
@ -43,7 +46,7 @@ func (client *Client) ReadDecode(p interface{}) error {
iq = &xmpp.IQClient{} iq = &xmpp.IQClient{}
} }
err = client.Decode(iq, element) err = client.Decode(iq, element)
if err == nil && iq.Ping != nil { if err == nil && iq.Ping != nil && iq.Type == xmpp.IQTypeGet {
client.Logging.Info("client.ReadElement: auto answer ping") client.Logging.Info("client.ReadElement: auto answer ping")
iq.Type = xmpp.IQTypeResult iq.Type = xmpp.IQTypeResult
iq.To = iq.From iq.To = iq.From
@ -56,35 +59,32 @@ func (client *Client) ReadDecode(p interface{}) error {
} }
return client.Decode(p, element) return client.Decode(p, element)
} }
func (client *Client) encode(p interface{}) error { func (client *Client) send(p interface{}) error {
err := client.out.Encode(p) b, err := xml.Marshal(p)
if err != nil { if err != nil {
client.Logging.Warnf("error send %v", p)
return err return err
} else {
if b, err := xml.Marshal(p); err == nil {
client.Logging.Debugf("encode %v", string(b))
} else {
client.Logging.Debugf("encode %v", p)
} }
} client.Logging.Debugf("send %v", string(b))
return nil _, err = client.conn.Write(b)
return err
} }
func (client *Client) Send(p interface{}) error { func (client *Client) Send(p interface{}) error {
msg, ok := p.(*xmpp.MessageClient) msg, ok := p.(*xmpp.MessageClient)
if ok { if ok {
msg.From = client.JID msg.From = client.JID
return client.encode(msg) return client.send(msg)
} }
iq, ok := p.(*xmpp.IQClient) iq, ok := p.(*xmpp.IQClient)
if ok { if ok {
iq.From = client.JID iq.From = client.JID
return client.encode(iq) return client.send(iq)
} }
pc, ok := p.(*xmpp.PresenceClient) pc, ok := p.(*xmpp.PresenceClient)
if ok { if ok {
pc.From = client.JID pc.From = client.JID
return client.encode(pc) return client.send(pc)
} }
return client.encode(p) return client.send(p)
} }

96
client/comm_test.go Normal file
View File

@ -0,0 +1,96 @@
package client
import (
"encoding/xml"
"net"
"sync"
"testing"
"github.com/stretchr/testify/assert"
log "github.com/sirupsen/logrus"
"dev.sum7.eu/genofire/yaja/xmpp"
"dev.sum7.eu/genofire/yaja/xmpp/base"
)
func TestRead(t *testing.T) {
assert := assert.New(t)
server, clientConn := net.Pipe()
client := &Client{}
client.setConnection(clientConn)
go server.Write([]byte(`<message>`))
element, err := client.Read()
assert.NoError(err)
assert.Equal("message", element.Name.Local)
go server.Write([]byte(`<>`))
element, err = client.Read()
assert.Error(err)
}
func TestSend(t *testing.T) {
assert := assert.New(t)
server, clientConn := net.Pipe()
client := &Client{
Logging: log.WithField("test", "send"),
}
client.setConnection(clientConn)
serverDecoder := xml.NewDecoder(server)
wgWait := &sync.WaitGroup{}
wgWait.Add(1)
go func() {
err := client.Send(3)
assert.NoError(err)
wgWait.Done()
}()
element, err := read(serverDecoder)
wgWait.Wait()
assert.NoError(err)
assert.Equal("int", element.Name.Local)
wgWait.Add(1)
go func() {
err := client.Send(&xmpp.MessageClient{To: xmppbase.NewJID("a@a.de")})
assert.NoError(err)
wgWait.Done()
}()
element, err = read(serverDecoder)
wgWait.Wait()
assert.NoError(err)
assert.Equal("message", element.Name.Local)
assert.Equal("a@a.de", element.Attr[1].Value)
wgWait.Add(1)
go func() {
err := client.Send(&xmpp.IQClient{Type: xmpp.IQTypeGet})
assert.NoError(err)
wgWait.Done()
}()
element, err = read(serverDecoder)
wgWait.Wait()
assert.NoError(err)
assert.Equal("iq", element.Name.Local)
assert.Equal("get", element.Attr[2].Value)
wgWait.Add(1)
go func() {
err := client.Send(&xmpp.PresenceClient{Type: xmpp.PresenceTypeSubscribe})
assert.NoError(err)
wgWait.Done()
}()
element, err = read(serverDecoder)
wgWait.Wait()
assert.NoError(err)
assert.Equal("presence", element.Name.Local)
assert.Equal("subscribe", element.Attr[1].Value)
}

16
client/tls_test.go Normal file
View File

@ -0,0 +1,16 @@
package client
import (
"crypto/tls"
"testing"
"github.com/stretchr/testify/assert"
)
func TestTLS(t *testing.T) {
assert := assert.New(t)
client := &Client{}
assert.Nil(client.TLSConnectionState())
client.conn = &tls.Conn{}
assert.NotNil(client.TLSConnectionState())
}