From 46e5efd98b21971c055b48e6dc06bdbd6ddc9cff Mon Sep 17 00:00:00 2001 From: Martin/Geno Date: Sat, 3 Mar 2018 09:28:24 +0100 Subject: [PATCH] [TEST] add for client --- client/client.go | 17 ++++---- client/client_test.go | 39 ++++++++++++++++++ client/comm.go | 32 +++++++-------- client/comm_test.go | 96 +++++++++++++++++++++++++++++++++++++++++++ client/tls_test.go | 16 ++++++++ 5 files changed, 175 insertions(+), 25 deletions(-) create mode 100644 client/comm_test.go create mode 100644 client/tls_test.go diff --git a/client/client.go b/client/client.go index 510dbc4..d64375d 100644 --- a/client/client.go +++ b/client/client.go @@ -1,7 +1,6 @@ package client import ( - "crypto/tls" "encoding/xml" "fmt" "net" @@ -33,16 +32,15 @@ type Client struct { func NewClient(jid *xmppbase.JID, password string) (*Client, error) { client := &Client{ - Protocol: "tcp", - JID: jid, - Logging: log.New().WithField("jid", jid.String()), + JID: jid, + Logging: log.New().WithField("jid", jid.String()), } return client, client.Connect(password) } func (client *Client) Connect(password string) error { _, srvEntries, err := net.LookupSRV("xmpp-client", "tcp", client.JID.Domain) - addr := client.JID.Domain + ":5222" + addr := client.JID.Domain if err == nil && len(srvEntries) > 0 { bestSrv := srvEntries[0] for _, srv := range srvEntries { @@ -52,18 +50,19 @@ func (client *Client) Connect(password string) error { } } } - a := strings.SplitN(addr, ":", 2) - if len(a) == 1 { + if strings.LastIndex(addr, ":") <= strings.LastIndex(addr, "]") { addr += ":5222" } if client.Protocol == "" { client.Protocol = "tcp" } + client.Logging.Debug("try tcp connect") conn, err := net.DialTimeout(client.Protocol, addr, client.Timeout) - client.setConnection(conn) if err != nil { return err } + client.Logging.Debug("tcp connected") + client.setConnection(conn) if err = client.connect(password); err != nil { client.Close() @@ -74,7 +73,7 @@ func (client *Client) Connect(password string) error { // Close closes the XMPP connection func (c *Client) Close() error { - if c.conn != (*tls.Conn)(nil) { + if c.conn != nil { return c.conn.Close() } return nil diff --git a/client/client_test.go b/client/client_test.go index da13c8e..4d5a3e2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1 +1,40 @@ package client + +/* +func TestClient(t *testing.T) { + assert := assert.New(t) + + jid := xmppbase.NewJID("test@example.net") + + logger := log.New() + logger.SetLevel(log.DebugLevel) + + client := &Client{ + JID: jid, + Timeout: time.Millisecond * 500, + Logging: logger.WithField("jid", jid.String()), + } + // close nil connected + assert.NoError(client.Close()) + + err := client.Connect("password") + assert.Error(err) + assert.Contains(err.Error(), "timeout") + + jid.Domain = "chat.sum7.eu" + + // invalid auth + client, err = NewClient(jid, "password") + assert.NotNil(client) + assert.Error(err) + assert.Contains(err.Error(), "not-authorized : ") + // already closed + assert.Error(client.Close()) + + client.Logging = logger.WithField("jid", jid.String()) + + err = client.Connect("FqzMp6bevlHlt8d") + assert.NoError(err) + +} +*/ diff --git a/client/comm.go b/client/comm.go index e5cd6a5..64e5ea1 100644 --- a/client/comm.go +++ b/client/comm.go @@ -6,9 +6,9 @@ import ( "dev.sum7.eu/genofire/yaja/xmpp" ) -func (client *Client) Read() (*xml.StartElement, error) { +func read(decoder *xml.Decoder) (*xml.StartElement, error) { for { - nextToken, err := client.in.Token() + nextToken, err := decoder.Token() if err != nil { return nil, err } @@ -19,6 +19,9 @@ func (client *Client) Read() (*xml.StartElement, error) { } } } +func (client *Client) Read() (*xml.StartElement, error) { + return read(client.in) +} func (client *Client) Decode(p interface{}, element *xml.StartElement) error { err := client.in.DecodeElement(p, element) if err != nil { @@ -43,7 +46,7 @@ func (client *Client) ReadDecode(p interface{}) error { iq = &xmpp.IQClient{} } err = client.Decode(iq, element) - if err == nil && iq.Ping != nil { + if err == nil && iq.Ping != nil && iq.Type == xmpp.IQTypeGet { client.Logging.Info("client.ReadElement: auto answer ping") iq.Type = xmpp.IQTypeResult iq.To = iq.From @@ -56,35 +59,32 @@ func (client *Client) ReadDecode(p interface{}) error { } return client.Decode(p, element) } -func (client *Client) encode(p interface{}) error { - err := client.out.Encode(p) +func (client *Client) send(p interface{}) error { + b, err := xml.Marshal(p) if err != nil { + client.Logging.Warnf("error send %v", p) return err - } else { - if b, err := xml.Marshal(p); err == nil { - client.Logging.Debugf("encode %v", string(b)) - } else { - client.Logging.Debugf("encode %v", p) - } } - return nil + client.Logging.Debugf("send %v", string(b)) + _, err = client.conn.Write(b) + return err } func (client *Client) Send(p interface{}) error { msg, ok := p.(*xmpp.MessageClient) if ok { msg.From = client.JID - return client.encode(msg) + return client.send(msg) } iq, ok := p.(*xmpp.IQClient) if ok { iq.From = client.JID - return client.encode(iq) + return client.send(iq) } pc, ok := p.(*xmpp.PresenceClient) if ok { pc.From = client.JID - return client.encode(pc) + return client.send(pc) } - return client.encode(p) + return client.send(p) } diff --git a/client/comm_test.go b/client/comm_test.go new file mode 100644 index 0000000..f6487cf --- /dev/null +++ b/client/comm_test.go @@ -0,0 +1,96 @@ +package client + +import ( + "encoding/xml" + "net" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + log "github.com/sirupsen/logrus" + + "dev.sum7.eu/genofire/yaja/xmpp" + "dev.sum7.eu/genofire/yaja/xmpp/base" +) + +func TestRead(t *testing.T) { + assert := assert.New(t) + + server, clientConn := net.Pipe() + client := &Client{} + client.setConnection(clientConn) + + go server.Write([]byte(``)) + + element, err := client.Read() + assert.NoError(err) + assert.Equal("message", element.Name.Local) + + go server.Write([]byte(`<>`)) + element, err = client.Read() + assert.Error(err) +} + +func TestSend(t *testing.T) { + assert := assert.New(t) + + server, clientConn := net.Pipe() + client := &Client{ + Logging: log.WithField("test", "send"), + } + client.setConnection(clientConn) + serverDecoder := xml.NewDecoder(server) + wgWait := &sync.WaitGroup{} + + wgWait.Add(1) + go func() { + err := client.Send(3) + assert.NoError(err) + wgWait.Done() + }() + + element, err := read(serverDecoder) + wgWait.Wait() + assert.NoError(err) + assert.Equal("int", element.Name.Local) + + wgWait.Add(1) + go func() { + err := client.Send(&xmpp.MessageClient{To: xmppbase.NewJID("a@a.de")}) + assert.NoError(err) + wgWait.Done() + }() + + element, err = read(serverDecoder) + wgWait.Wait() + assert.NoError(err) + assert.Equal("message", element.Name.Local) + assert.Equal("a@a.de", element.Attr[1].Value) + + wgWait.Add(1) + go func() { + err := client.Send(&xmpp.IQClient{Type: xmpp.IQTypeGet}) + assert.NoError(err) + wgWait.Done() + }() + + element, err = read(serverDecoder) + wgWait.Wait() + assert.NoError(err) + assert.Equal("iq", element.Name.Local) + assert.Equal("get", element.Attr[2].Value) + + wgWait.Add(1) + go func() { + err := client.Send(&xmpp.PresenceClient{Type: xmpp.PresenceTypeSubscribe}) + assert.NoError(err) + wgWait.Done() + }() + + element, err = read(serverDecoder) + wgWait.Wait() + assert.NoError(err) + assert.Equal("presence", element.Name.Local) + assert.Equal("subscribe", element.Attr[1].Value) +} diff --git a/client/tls_test.go b/client/tls_test.go new file mode 100644 index 0000000..42a4306 --- /dev/null +++ b/client/tls_test.go @@ -0,0 +1,16 @@ +package client + +import ( + "crypto/tls" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTLS(t *testing.T) { + assert := assert.New(t) + client := &Client{} + assert.Nil(client.TLSConnectionState()) + client.conn = &tls.Conn{} + assert.NotNil(client.TLSConnectionState()) +}