| 
									
										
										
										
											2017-12-16 23:20:46 +01:00
										 |  |  | package state | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							|  |  |  | 	"crypto/tls" | 
					
						
							|  |  |  | 	"fmt" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-07 15:34:18 +01:00
										 |  |  | 	"dev.sum7.eu/genofire/yaja/messages" | 
					
						
							|  |  |  | 	"dev.sum7.eu/genofire/yaja/model" | 
					
						
							|  |  |  | 	"dev.sum7.eu/genofire/yaja/server/utils" | 
					
						
							| 
									
										
										
										
											2017-12-16 23:20:46 +01:00
										 |  |  | 	"golang.org/x/crypto/acme/autocert" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Start state
 | 
					
						
							|  |  |  | type Start struct { | 
					
						
							| 
									
										
										
										
											2017-12-17 17:50:51 +01:00
										 |  |  | 	Next   State | 
					
						
							|  |  |  | 	Client *utils.Client | 
					
						
							| 
									
										
										
										
											2017-12-16 23:20:46 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							// Process performs the opening stream negotiation for a newly accepted
// client connection.
//
// It reads the client's <stream:stream> header and records the domain from
// its 'to' attribute as the client's JID. It then replies with the server's
// stream header and advertises STARTTLS as a required stream feature.
//
// Return values:
//   - state.Next when the header was answered and the features were sent;
//   - the receiver itself when the element read is not a stream header, so
//     the state is retried;
//   - nil when the connection cannot continue (read or write error, or no
//     'to' domain could be determined).
func (state *Start) Process() State {
	state.Client.Log = state.Client.Log.WithField("state", "stream")
	state.Client.Log.Debug("running")
	defer state.Client.Log.Debug("leave")

	element, err := state.Client.Read()
	if err != nil {
		state.Client.Log.Warn("unable to read: ", err)
		return nil
	}
	if element.Name.Space != messages.NSStream || element.Name.Local != "stream" {
		state.Client.Log.Warn("is no stream")
		return state
	}
	for _, attr := range element.Attr {
		if attr.Name.Local == "to" {
			state.Client.JID = &model.JID{Domain: attr.Value}
			state.Client.Log = state.Client.Log.WithField("jid", state.Client.JID.Full())
		}
	}
	if state.Client.JID == nil {
		state.Client.Log.Warn("no 'to' domain readed")
		return nil
	}

	// Answer with our own stream header; the id is a fresh random cookie.
	if _, err = fmt.Fprintf(state.Client.Conn, `<?xml version='1.0'?>
		<stream:stream id='%x' version='1.0' xmlns='%s' xmlns:stream='%s'>`,
		messages.CreateCookie(), messages.NSClient, messages.NSStream); err != nil {
		state.Client.Log.Warn("unable to write: ", err)
		return nil
	}

	// The <starttls/> feature must be qualified by the TLS namespace
	// (RFC 6120 §5.4.1). messages.NSTLS is also the namespace TLSUpgrade
	// expects on the client's <starttls/> request. Advertising it in the
	// stream namespace leaves clients unable to recognize the feature.
	if _, err = fmt.Fprintf(state.Client.Conn, `<stream:features>
			<starttls xmlns='%s'>
				<required/>
			</starttls>
		</stream:features>`,
		messages.NSTLS); err != nil {
		state.Client.Log.Warn("unable to write: ", err)
		return nil
	}

	return state.Next
}
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // TLSUpgrade state
 | 
					
						
							|  |  |  | type TLSUpgrade struct { | 
					
						
							| 
									
										
										
										
											2017-12-17 17:50:51 +01:00
										 |  |  | 	Next       State | 
					
						
							|  |  |  | 	Client     *utils.Client | 
					
						
							|  |  |  | 	TLSConfig  *tls.Config | 
					
						
							|  |  |  | 	TLSManager *autocert.Manager | 
					
						
							| 
									
										
										
										
											2017-12-16 23:20:46 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							// Process handles the STARTTLS negotiation (RFC 6120 §5.4) and upgrades the
// client connection to TLS.
//
// It expects the client's <starttls/> request and selects a TLS
// configuration for the client's domain: TLSManager first, TLSConfig as a
// fallback. It then answers <proceed/>, performs the server-side handshake
// and hands the encrypted connection to the client via SetConnecting. If no
// configuration can be found, the client receives <failure/> and the stream
// is closed instead of being told to proceed.
//
// It returns state.Next on success and nil when the connection cannot
// continue.
func (state *TLSUpgrade) Process() State {
	state.Client.Log = state.Client.Log.WithField("state", "tls upgrade")
	state.Client.Log.Debug("running")
	defer state.Client.Log.Debug("leave")

	// refuse reports that TLS cannot be negotiated. Per RFC 6120 §5.4.2.2 the
	// server answers <failure/> and closes the stream.
	refuse := func() {
		if _, werr := fmt.Fprintf(state.Client.Conn, "<failure xmlns='%s'/></stream:stream>", messages.NSTLS); werr != nil {
			state.Client.Log.Warn("unable to write: ", werr)
		}
	}

	element, err := state.Client.Read()
	if err != nil {
		state.Client.Log.Warn("unable to read: ", err)
		return nil
	}
	if element.Name.Space != messages.NSTLS || element.Name.Local != "starttls" {
		state.Client.Log.Warn("is no starttls", element)
		return nil
	}

	// The domain selects the certificate. Without it the original code would
	// dereference a nil JID and panic.
	if state.Client.JID == nil {
		state.Client.Log.Warn("no domain known for tls upgrade")
		refuse()
		return nil
	}
	domain := state.Client.JID.Domain

	// Resolve the configuration before sending <proceed/>. Once the client
	// has seen <proceed/> it starts the handshake, so a missing certificate
	// can no longer be reported in-band.
	var tlsConfig *tls.Config
	if m := state.TLSManager; m != nil {
		cert, certErr := m.GetCertificate(&tls.ClientHelloInfo{ServerName: domain})
		if certErr != nil {
			state.Client.Log.Warn("no cert in tls manger found: ", certErr)
			refuse()
			return nil
		}
		tlsConfig = &tls.Config{
			Certificates: []tls.Certificate{*cert},
		}
	} else if state.TLSConfig != nil {
		// Clone before customizing. state.TLSConfig may be shared with other
		// connections, and crypto/tls forbids modifying a Config after it has
		// been passed to a TLS function, so writing to it directly would race.
		tlsConfig = state.TLSConfig.Clone()
		tlsConfig.ServerName = domain
	} else {
		state.Client.Log.Warn("no tls config found")
		refuse()
		return nil
	}

	if _, err = fmt.Fprintf(state.Client.Conn, "<proceed xmlns='%s'/>", messages.NSTLS); err != nil {
		state.Client.Log.Warn("unable to write: ", err)
		return nil
	}

	// perform the TLS handshake
	tlsConn := tls.Server(state.Client.Conn, tlsConfig)
	if err = tlsConn.Handshake(); err != nil {
		state.Client.Log.Warn("unable to tls handshake: ", err)
		return nil
	}
	// restart the connection on top of the TLS layer
	state.Client.SetConnecting(tlsConn)

	return state.Next
}