[TEST] add for client
This commit is contained in:
parent
4732e0d292
commit
46e5efd98b
|
@ -1,7 +1,6 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"encoding/xml"
|
"encoding/xml"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
@ -33,16 +32,15 @@ type Client struct {
|
||||||
|
|
||||||
func NewClient(jid *xmppbase.JID, password string) (*Client, error) {
|
func NewClient(jid *xmppbase.JID, password string) (*Client, error) {
|
||||||
client := &Client{
|
client := &Client{
|
||||||
Protocol: "tcp",
|
JID: jid,
|
||||||
JID: jid,
|
Logging: log.New().WithField("jid", jid.String()),
|
||||||
Logging: log.New().WithField("jid", jid.String()),
|
|
||||||
}
|
}
|
||||||
return client, client.Connect(password)
|
return client, client.Connect(password)
|
||||||
|
|
||||||
}
|
}
|
||||||
func (client *Client) Connect(password string) error {
|
func (client *Client) Connect(password string) error {
|
||||||
_, srvEntries, err := net.LookupSRV("xmpp-client", "tcp", client.JID.Domain)
|
_, srvEntries, err := net.LookupSRV("xmpp-client", "tcp", client.JID.Domain)
|
||||||
addr := client.JID.Domain + ":5222"
|
addr := client.JID.Domain
|
||||||
if err == nil && len(srvEntries) > 0 {
|
if err == nil && len(srvEntries) > 0 {
|
||||||
bestSrv := srvEntries[0]
|
bestSrv := srvEntries[0]
|
||||||
for _, srv := range srvEntries {
|
for _, srv := range srvEntries {
|
||||||
|
@ -52,18 +50,19 @@ func (client *Client) Connect(password string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
a := strings.SplitN(addr, ":", 2)
|
if strings.LastIndex(addr, ":") <= strings.LastIndex(addr, "]") {
|
||||||
if len(a) == 1 {
|
|
||||||
addr += ":5222"
|
addr += ":5222"
|
||||||
}
|
}
|
||||||
if client.Protocol == "" {
|
if client.Protocol == "" {
|
||||||
client.Protocol = "tcp"
|
client.Protocol = "tcp"
|
||||||
}
|
}
|
||||||
|
client.Logging.Debug("try tcp connect")
|
||||||
conn, err := net.DialTimeout(client.Protocol, addr, client.Timeout)
|
conn, err := net.DialTimeout(client.Protocol, addr, client.Timeout)
|
||||||
client.setConnection(conn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
client.Logging.Debug("tcp connected")
|
||||||
|
client.setConnection(conn)
|
||||||
|
|
||||||
if err = client.connect(password); err != nil {
|
if err = client.connect(password); err != nil {
|
||||||
client.Close()
|
client.Close()
|
||||||
|
@ -74,7 +73,7 @@ func (client *Client) Connect(password string) error {
|
||||||
|
|
||||||
// Close closes the XMPP connection
|
// Close closes the XMPP connection
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
if c.conn != (*tls.Conn)(nil) {
|
if c.conn != nil {
|
||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1 +1,40 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
|
/*
|
||||||
|
func TestClient(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
jid := xmppbase.NewJID("test@example.net")
|
||||||
|
|
||||||
|
logger := log.New()
|
||||||
|
logger.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
JID: jid,
|
||||||
|
Timeout: time.Millisecond * 500,
|
||||||
|
Logging: logger.WithField("jid", jid.String()),
|
||||||
|
}
|
||||||
|
// close nil connected
|
||||||
|
assert.NoError(client.Close())
|
||||||
|
|
||||||
|
err := client.Connect("password")
|
||||||
|
assert.Error(err)
|
||||||
|
assert.Contains(err.Error(), "timeout")
|
||||||
|
|
||||||
|
jid.Domain = "chat.sum7.eu"
|
||||||
|
|
||||||
|
// invalid auth
|
||||||
|
client, err = NewClient(jid, "password")
|
||||||
|
assert.NotNil(client)
|
||||||
|
assert.Error(err)
|
||||||
|
assert.Contains(err.Error(), "not-authorized : ")
|
||||||
|
// already closed
|
||||||
|
assert.Error(client.Close())
|
||||||
|
|
||||||
|
client.Logging = logger.WithField("jid", jid.String())
|
||||||
|
|
||||||
|
err = client.Connect("FqzMp6bevlHlt8d")
|
||||||
|
assert.NoError(err)
|
||||||
|
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -6,9 +6,9 @@ import (
|
||||||
"dev.sum7.eu/genofire/yaja/xmpp"
|
"dev.sum7.eu/genofire/yaja/xmpp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (client *Client) Read() (*xml.StartElement, error) {
|
func read(decoder *xml.Decoder) (*xml.StartElement, error) {
|
||||||
for {
|
for {
|
||||||
nextToken, err := client.in.Token()
|
nextToken, err := decoder.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,9 @@ func (client *Client) Read() (*xml.StartElement, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func (client *Client) Read() (*xml.StartElement, error) {
|
||||||
|
return read(client.in)
|
||||||
|
}
|
||||||
func (client *Client) Decode(p interface{}, element *xml.StartElement) error {
|
func (client *Client) Decode(p interface{}, element *xml.StartElement) error {
|
||||||
err := client.in.DecodeElement(p, element)
|
err := client.in.DecodeElement(p, element)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -43,7 +46,7 @@ func (client *Client) ReadDecode(p interface{}) error {
|
||||||
iq = &xmpp.IQClient{}
|
iq = &xmpp.IQClient{}
|
||||||
}
|
}
|
||||||
err = client.Decode(iq, element)
|
err = client.Decode(iq, element)
|
||||||
if err == nil && iq.Ping != nil {
|
if err == nil && iq.Ping != nil && iq.Type == xmpp.IQTypeGet {
|
||||||
client.Logging.Info("client.ReadElement: auto answer ping")
|
client.Logging.Info("client.ReadElement: auto answer ping")
|
||||||
iq.Type = xmpp.IQTypeResult
|
iq.Type = xmpp.IQTypeResult
|
||||||
iq.To = iq.From
|
iq.To = iq.From
|
||||||
|
@ -56,35 +59,32 @@ func (client *Client) ReadDecode(p interface{}) error {
|
||||||
}
|
}
|
||||||
return client.Decode(p, element)
|
return client.Decode(p, element)
|
||||||
}
|
}
|
||||||
func (client *Client) encode(p interface{}) error {
|
func (client *Client) send(p interface{}) error {
|
||||||
err := client.out.Encode(p)
|
b, err := xml.Marshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
client.Logging.Warnf("error send %v", p)
|
||||||
return err
|
return err
|
||||||
} else {
|
|
||||||
if b, err := xml.Marshal(p); err == nil {
|
|
||||||
client.Logging.Debugf("encode %v", string(b))
|
|
||||||
} else {
|
|
||||||
client.Logging.Debugf("encode %v", p)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
client.Logging.Debugf("send %v", string(b))
|
||||||
|
_, err = client.conn.Write(b)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) Send(p interface{}) error {
|
func (client *Client) Send(p interface{}) error {
|
||||||
msg, ok := p.(*xmpp.MessageClient)
|
msg, ok := p.(*xmpp.MessageClient)
|
||||||
if ok {
|
if ok {
|
||||||
msg.From = client.JID
|
msg.From = client.JID
|
||||||
return client.encode(msg)
|
return client.send(msg)
|
||||||
}
|
}
|
||||||
iq, ok := p.(*xmpp.IQClient)
|
iq, ok := p.(*xmpp.IQClient)
|
||||||
if ok {
|
if ok {
|
||||||
iq.From = client.JID
|
iq.From = client.JID
|
||||||
return client.encode(iq)
|
return client.send(iq)
|
||||||
}
|
}
|
||||||
pc, ok := p.(*xmpp.PresenceClient)
|
pc, ok := p.(*xmpp.PresenceClient)
|
||||||
if ok {
|
if ok {
|
||||||
pc.From = client.JID
|
pc.From = client.JID
|
||||||
return client.encode(pc)
|
return client.send(pc)
|
||||||
}
|
}
|
||||||
return client.encode(p)
|
return client.send(p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"dev.sum7.eu/genofire/yaja/xmpp"
|
||||||
|
"dev.sum7.eu/genofire/yaja/xmpp/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRead(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
server, clientConn := net.Pipe()
|
||||||
|
client := &Client{}
|
||||||
|
client.setConnection(clientConn)
|
||||||
|
|
||||||
|
go server.Write([]byte(`<message>`))
|
||||||
|
|
||||||
|
element, err := client.Read()
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("message", element.Name.Local)
|
||||||
|
|
||||||
|
go server.Write([]byte(`<>`))
|
||||||
|
element, err = client.Read()
|
||||||
|
assert.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSend(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
server, clientConn := net.Pipe()
|
||||||
|
client := &Client{
|
||||||
|
Logging: log.WithField("test", "send"),
|
||||||
|
}
|
||||||
|
client.setConnection(clientConn)
|
||||||
|
serverDecoder := xml.NewDecoder(server)
|
||||||
|
wgWait := &sync.WaitGroup{}
|
||||||
|
|
||||||
|
wgWait.Add(1)
|
||||||
|
go func() {
|
||||||
|
err := client.Send(3)
|
||||||
|
assert.NoError(err)
|
||||||
|
wgWait.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
element, err := read(serverDecoder)
|
||||||
|
wgWait.Wait()
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("int", element.Name.Local)
|
||||||
|
|
||||||
|
wgWait.Add(1)
|
||||||
|
go func() {
|
||||||
|
err := client.Send(&xmpp.MessageClient{To: xmppbase.NewJID("a@a.de")})
|
||||||
|
assert.NoError(err)
|
||||||
|
wgWait.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
element, err = read(serverDecoder)
|
||||||
|
wgWait.Wait()
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("message", element.Name.Local)
|
||||||
|
assert.Equal("a@a.de", element.Attr[1].Value)
|
||||||
|
|
||||||
|
wgWait.Add(1)
|
||||||
|
go func() {
|
||||||
|
err := client.Send(&xmpp.IQClient{Type: xmpp.IQTypeGet})
|
||||||
|
assert.NoError(err)
|
||||||
|
wgWait.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
element, err = read(serverDecoder)
|
||||||
|
wgWait.Wait()
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("iq", element.Name.Local)
|
||||||
|
assert.Equal("get", element.Attr[2].Value)
|
||||||
|
|
||||||
|
wgWait.Add(1)
|
||||||
|
go func() {
|
||||||
|
err := client.Send(&xmpp.PresenceClient{Type: xmpp.PresenceTypeSubscribe})
|
||||||
|
assert.NoError(err)
|
||||||
|
wgWait.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
element, err = read(serverDecoder)
|
||||||
|
wgWait.Wait()
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.Equal("presence", element.Name.Local)
|
||||||
|
assert.Equal("subscribe", element.Attr[1].Value)
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTLS(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
client := &Client{}
|
||||||
|
assert.Nil(client.TLSConnectionState())
|
||||||
|
client.conn = &tls.Conn{}
|
||||||
|
assert.NotNil(client.TLSConnectionState())
|
||||||
|
}
|
Reference in New Issue