[TEST] add for client
This commit is contained in:
parent
4732e0d292
commit
46e5efd98b
|
@ -1,7 +1,6 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -33,7 +32,6 @@ type Client struct {
|
|||
|
||||
func NewClient(jid *xmppbase.JID, password string) (*Client, error) {
|
||||
client := &Client{
|
||||
Protocol: "tcp",
|
||||
JID: jid,
|
||||
Logging: log.New().WithField("jid", jid.String()),
|
||||
}
|
||||
|
@ -42,7 +40,7 @@ func NewClient(jid *xmppbase.JID, password string) (*Client, error) {
|
|||
}
|
||||
func (client *Client) Connect(password string) error {
|
||||
_, srvEntries, err := net.LookupSRV("xmpp-client", "tcp", client.JID.Domain)
|
||||
addr := client.JID.Domain + ":5222"
|
||||
addr := client.JID.Domain
|
||||
if err == nil && len(srvEntries) > 0 {
|
||||
bestSrv := srvEntries[0]
|
||||
for _, srv := range srvEntries {
|
||||
|
@ -52,18 +50,19 @@ func (client *Client) Connect(password string) error {
|
|||
}
|
||||
}
|
||||
}
|
||||
a := strings.SplitN(addr, ":", 2)
|
||||
if len(a) == 1 {
|
||||
if strings.LastIndex(addr, ":") <= strings.LastIndex(addr, "]") {
|
||||
addr += ":5222"
|
||||
}
|
||||
if client.Protocol == "" {
|
||||
client.Protocol = "tcp"
|
||||
}
|
||||
client.Logging.Debug("try tcp connect")
|
||||
conn, err := net.DialTimeout(client.Protocol, addr, client.Timeout)
|
||||
client.setConnection(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.Logging.Debug("tcp connected")
|
||||
client.setConnection(conn)
|
||||
|
||||
if err = client.connect(password); err != nil {
|
||||
client.Close()
|
||||
|
@ -74,7 +73,7 @@ func (client *Client) Connect(password string) error {
|
|||
|
||||
// Close closes the XMPP connection
|
||||
func (c *Client) Close() error {
|
||||
if c.conn != (*tls.Conn)(nil) {
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -1 +1,40 @@
|
|||
package client
|
||||
|
||||
/*
|
||||
func TestClient(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
jid := xmppbase.NewJID("test@example.net")
|
||||
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.DebugLevel)
|
||||
|
||||
client := &Client{
|
||||
JID: jid,
|
||||
Timeout: time.Millisecond * 500,
|
||||
Logging: logger.WithField("jid", jid.String()),
|
||||
}
|
||||
// close nil connected
|
||||
assert.NoError(client.Close())
|
||||
|
||||
err := client.Connect("password")
|
||||
assert.Error(err)
|
||||
assert.Contains(err.Error(), "timeout")
|
||||
|
||||
jid.Domain = "chat.sum7.eu"
|
||||
|
||||
// invalid auth
|
||||
client, err = NewClient(jid, "password")
|
||||
assert.NotNil(client)
|
||||
assert.Error(err)
|
||||
assert.Contains(err.Error(), "not-authorized : ")
|
||||
// already closed
|
||||
assert.Error(client.Close())
|
||||
|
||||
client.Logging = logger.WithField("jid", jid.String())
|
||||
|
||||
err = client.Connect("FqzMp6bevlHlt8d")
|
||||
assert.NoError(err)
|
||||
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
"dev.sum7.eu/genofire/yaja/xmpp"
|
||||
)
|
||||
|
||||
func (client *Client) Read() (*xml.StartElement, error) {
|
||||
func read(decoder *xml.Decoder) (*xml.StartElement, error) {
|
||||
for {
|
||||
nextToken, err := client.in.Token()
|
||||
nextToken, err := decoder.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -19,6 +19,9 @@ func (client *Client) Read() (*xml.StartElement, error) {
|
|||
}
|
||||
}
|
||||
}
|
||||
func (client *Client) Read() (*xml.StartElement, error) {
|
||||
return read(client.in)
|
||||
}
|
||||
func (client *Client) Decode(p interface{}, element *xml.StartElement) error {
|
||||
err := client.in.DecodeElement(p, element)
|
||||
if err != nil {
|
||||
|
@ -43,7 +46,7 @@ func (client *Client) ReadDecode(p interface{}) error {
|
|||
iq = &xmpp.IQClient{}
|
||||
}
|
||||
err = client.Decode(iq, element)
|
||||
if err == nil && iq.Ping != nil {
|
||||
if err == nil && iq.Ping != nil && iq.Type == xmpp.IQTypeGet {
|
||||
client.Logging.Info("client.ReadElement: auto answer ping")
|
||||
iq.Type = xmpp.IQTypeResult
|
||||
iq.To = iq.From
|
||||
|
@ -56,35 +59,32 @@ func (client *Client) ReadDecode(p interface{}) error {
|
|||
}
|
||||
return client.Decode(p, element)
|
||||
}
|
||||
func (client *Client) encode(p interface{}) error {
|
||||
err := client.out.Encode(p)
|
||||
func (client *Client) send(p interface{}) error {
|
||||
b, err := xml.Marshal(p)
|
||||
if err != nil {
|
||||
client.Logging.Warnf("error send %v", p)
|
||||
return err
|
||||
} else {
|
||||
if b, err := xml.Marshal(p); err == nil {
|
||||
client.Logging.Debugf("encode %v", string(b))
|
||||
} else {
|
||||
client.Logging.Debugf("encode %v", p)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
client.Logging.Debugf("send %v", string(b))
|
||||
_, err = client.conn.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (client *Client) Send(p interface{}) error {
|
||||
msg, ok := p.(*xmpp.MessageClient)
|
||||
if ok {
|
||||
msg.From = client.JID
|
||||
return client.encode(msg)
|
||||
return client.send(msg)
|
||||
}
|
||||
iq, ok := p.(*xmpp.IQClient)
|
||||
if ok {
|
||||
iq.From = client.JID
|
||||
return client.encode(iq)
|
||||
return client.send(iq)
|
||||
}
|
||||
pc, ok := p.(*xmpp.PresenceClient)
|
||||
if ok {
|
||||
pc.From = client.JID
|
||||
return client.encode(pc)
|
||||
return client.send(pc)
|
||||
}
|
||||
return client.encode(p)
|
||||
return client.send(p)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"dev.sum7.eu/genofire/yaja/xmpp"
|
||||
"dev.sum7.eu/genofire/yaja/xmpp/base"
|
||||
)
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
server, clientConn := net.Pipe()
|
||||
client := &Client{}
|
||||
client.setConnection(clientConn)
|
||||
|
||||
go server.Write([]byte(`<message>`))
|
||||
|
||||
element, err := client.Read()
|
||||
assert.NoError(err)
|
||||
assert.Equal("message", element.Name.Local)
|
||||
|
||||
go server.Write([]byte(`<>`))
|
||||
element, err = client.Read()
|
||||
assert.Error(err)
|
||||
}
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
server, clientConn := net.Pipe()
|
||||
client := &Client{
|
||||
Logging: log.WithField("test", "send"),
|
||||
}
|
||||
client.setConnection(clientConn)
|
||||
serverDecoder := xml.NewDecoder(server)
|
||||
wgWait := &sync.WaitGroup{}
|
||||
|
||||
wgWait.Add(1)
|
||||
go func() {
|
||||
err := client.Send(3)
|
||||
assert.NoError(err)
|
||||
wgWait.Done()
|
||||
}()
|
||||
|
||||
element, err := read(serverDecoder)
|
||||
wgWait.Wait()
|
||||
assert.NoError(err)
|
||||
assert.Equal("int", element.Name.Local)
|
||||
|
||||
wgWait.Add(1)
|
||||
go func() {
|
||||
err := client.Send(&xmpp.MessageClient{To: xmppbase.NewJID("a@a.de")})
|
||||
assert.NoError(err)
|
||||
wgWait.Done()
|
||||
}()
|
||||
|
||||
element, err = read(serverDecoder)
|
||||
wgWait.Wait()
|
||||
assert.NoError(err)
|
||||
assert.Equal("message", element.Name.Local)
|
||||
assert.Equal("a@a.de", element.Attr[1].Value)
|
||||
|
||||
wgWait.Add(1)
|
||||
go func() {
|
||||
err := client.Send(&xmpp.IQClient{Type: xmpp.IQTypeGet})
|
||||
assert.NoError(err)
|
||||
wgWait.Done()
|
||||
}()
|
||||
|
||||
element, err = read(serverDecoder)
|
||||
wgWait.Wait()
|
||||
assert.NoError(err)
|
||||
assert.Equal("iq", element.Name.Local)
|
||||
assert.Equal("get", element.Attr[2].Value)
|
||||
|
||||
wgWait.Add(1)
|
||||
go func() {
|
||||
err := client.Send(&xmpp.PresenceClient{Type: xmpp.PresenceTypeSubscribe})
|
||||
assert.NoError(err)
|
||||
wgWait.Done()
|
||||
}()
|
||||
|
||||
element, err = read(serverDecoder)
|
||||
wgWait.Wait()
|
||||
assert.NoError(err)
|
||||
assert.Equal("presence", element.Name.Local)
|
||||
assert.Equal("subscribe", element.Attr[1].Value)
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
client := &Client{}
|
||||
assert.Nil(client.TLSConnectionState())
|
||||
client.conn = &tls.Conn{}
|
||||
assert.NotNil(client.TLSConnectionState())
|
||||
}
|
Reference in New Issue