sum7
/
yaja
Archived
1
0
Fork 0
This repository has been archived on 2020-09-27. You can view files and clone it, but cannot push or open issues or pull requests.
yaja/server/state/connect.go

117 lines
2.9 KiB
Go
Raw Normal View History

2017-12-16 23:20:46 +01:00
package state
import (
"crypto/tls"
"fmt"
2018-02-07 15:34:18 +01:00
"dev.sum7.eu/genofire/yaja/messages"
"dev.sum7.eu/genofire/yaja/model"
"dev.sum7.eu/genofire/yaja/server/utils"
2017-12-16 23:20:46 +01:00
"golang.org/x/crypto/acme/autocert"
)
// Start state
type Start struct {
2017-12-17 17:50:51 +01:00
Next State
Client *utils.Client
2017-12-16 23:20:46 +01:00
}
// Process message
2017-12-17 17:50:51 +01:00
func (state *Start) Process() State {
state.Client.Log = state.Client.Log.WithField("state", "stream")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
2017-12-16 23:20:46 +01:00
2017-12-17 17:50:51 +01:00
element, err := state.Client.Read()
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("unable to read: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("is no stream")
return state
2017-12-16 23:20:46 +01:00
}
for _, attr := range element.Attr {
if attr.Name.Local == "to" {
2017-12-17 17:50:51 +01:00
state.Client.JID = &model.JID{Domain: attr.Value}
state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
2017-12-16 23:20:46 +01:00
}
}
2017-12-17 17:50:51 +01:00
if state.Client.JID == nil {
state.Client.Log.Warn("no 'to' domain readed")
return nil
2017-12-16 23:20:46 +01:00
}
2017-12-17 17:50:51 +01:00
fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
2017-12-16 23:20:46 +01:00
<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
2018-02-11 19:35:32 +01:00
messages.CreateCookie(), messages.NSClient, messages.NSStream)
2017-12-16 23:20:46 +01:00
2017-12-17 17:50:51 +01:00
fmt.Fprintf(state.Client.Conn, `<stream:features>
2017-12-16 23:20:46 +01:00
<starttls xmlns='%s'>
<required/>
</starttls>
</stream:features>`,
messages.NSStream)
2017-12-17 17:50:51 +01:00
return state.Next
2017-12-16 23:20:46 +01:00
}
// TLSUpgrade state
type TLSUpgrade struct {
2017-12-17 17:50:51 +01:00
Next State
Client *utils.Client
TLSConfig *tls.Config
TLSManager *autocert.Manager
2017-12-16 23:20:46 +01:00
}
// Process message
2017-12-17 17:50:51 +01:00
func (state *TLSUpgrade) Process() State {
state.Client.Log = state.Client.Log.WithField("state", "tls upgrade")
state.Client.Log.Debug("running")
defer state.Client.Log.Debug("leave")
2017-12-16 23:20:46 +01:00
2017-12-17 17:50:51 +01:00
element, err := state.Client.Read()
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("unable to read: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("is no starttls", element)
return nil
2017-12-16 23:20:46 +01:00
}
2017-12-17 17:50:51 +01:00
fmt.Fprintf(state.Client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS)
2017-12-16 23:20:46 +01:00
// perform the TLS handshake
var tlsConfig *tls.Config
2017-12-17 17:50:51 +01:00
if m := state.TLSManager; m != nil {
2017-12-16 23:20:46 +01:00
var cert *tls.Certificate
2017-12-17 17:50:51 +01:00
cert, err = m.GetCertificate(&tls.ClientHelloInfo{ServerName: state.Client.JID.Domain})
2017-12-16 23:20:46 +01:00
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("no cert in tls manger found: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{*cert},
}
}
if tlsConfig == nil {
2017-12-17 17:50:51 +01:00
tlsConfig = state.TLSConfig
2017-12-16 23:20:46 +01:00
if tlsConfig != nil {
2017-12-17 17:50:51 +01:00
tlsConfig.ServerName = state.Client.JID.Domain
2017-12-16 23:20:46 +01:00
} else {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("no tls config found: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
}
2017-12-17 17:50:51 +01:00
tlsConn := tls.Server(state.Client.Conn, tlsConfig)
2017-12-16 23:20:46 +01:00
err = tlsConn.Handshake()
if err != nil {
2017-12-17 17:50:51 +01:00
state.Client.Log.Warn("unable to tls handshake: ", err)
return nil
2017-12-16 23:20:46 +01:00
}
// restart the Connection
2017-12-17 17:50:51 +01:00
state.Client.SetConnecting(tlsConn)
2017-12-16 23:20:46 +01:00
2017-12-17 17:50:51 +01:00
return state.Next
2017-12-16 23:20:46 +01:00
}