add cmd reload + state and server struct
This commit is contained in:
		
							parent
							
								
									54f67245d6
								
							
						
					
					
						commit
						8c60ef89c6
					
				
							
								
								
									
										124
									
								
								cmd/serve.go
								
								
								
								
							
							
						
						
									
										124
									
								
								cmd/serve.go
								
								
								
								
							|  | @ -1,12 +1,19 @@ | ||||||
| package cmd | package cmd | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"crypto/tls" | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" | 	"os/signal" | ||||||
| 	"syscall" | 	"syscall" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"golang.org/x/crypto/acme/autocert" | ||||||
|  | 
 | ||||||
|  | 	"dev.sum7.eu/genofire/yaja/model" | ||||||
| 	"dev.sum7.eu/genofire/yaja/model/config" | 	"dev.sum7.eu/genofire/yaja/model/config" | ||||||
| 
 | 
 | ||||||
|  | 	"dev.sum7.eu/genofire/yaja/server" | ||||||
|  | 	"github.com/genofire/golang-lib/worker" | ||||||
| 	log "github.com/sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 
 | 
 | ||||||
| 	"github.com/spf13/cobra" | 	"github.com/spf13/cobra" | ||||||
|  | @ -14,28 +21,137 @@ import ( | ||||||
| 
 | 
 | ||||||
| var configPath string | var configPath string | ||||||
| 
 | 
 | ||||||
|  | var ( | ||||||
|  | 	configData      *config.Config | ||||||
|  | 	state           *model.State | ||||||
|  | 	statesaveWorker *worker.Worker | ||||||
|  | 	srv             *server.Server | ||||||
|  | 	certs           *tls.Config | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| // serveCmd represents the serve command
 | // serveCmd represents the serve command
 | ||||||
| var serveCmd = &cobra.Command{ | var serveCmd = &cobra.Command{ | ||||||
| 	Use:     "serve", | 	Use:     "serve", | ||||||
| 	Short:   "Runs the yaja server", | 	Short:   "Runs the yaja server", | ||||||
| 	Example: "yaja serve -c /etc/yaja.conf", | 	Example: "yaja serve -c /etc/yaja.conf", | ||||||
| 	Run: func(cmd *cobra.Command, args []string) { | 	Run: func(cmd *cobra.Command, args []string) { | ||||||
| 		_, err := config.ReadConfigFile(configPath) | 		var err error | ||||||
|  | 		configData, err = config.ReadConfigFile(configPath) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Fatal("unable to load config file:", err) | 			log.Fatal("unable to load config file:", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		state, err = model.ReadState(configData.StatePath) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Warn("unable to load state file:", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		statesaveWorker = worker.NewWorker(time.Minute, func() { | ||||||
|  | 			model.SaveJSON(state, configData.StatePath) | ||||||
|  | 			log.Info("save state to:", configData.StatePath) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		m := autocert.Manager{ | ||||||
|  | 			Cache:  autocert.DirCache(configData.TLSDir), | ||||||
|  | 			Prompt: autocert.AcceptTOS, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		certs = &tls.Config{GetCertificate: m.GetCertificate} | ||||||
|  | 
 | ||||||
|  | 		srv = &server.Server{ | ||||||
|  | 			TLSConfig:  certs, | ||||||
|  | 			State:      state, | ||||||
|  | 			PortClient: configData.PortClient, | ||||||
|  | 			PortServer: configData.PortServer, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		go statesaveWorker.Start() | ||||||
|  | 		go srv.Start() | ||||||
|  | 
 | ||||||
| 		log.Infoln("yaja started ") | 		log.Infoln("yaja started ") | ||||||
| 
 | 
 | ||||||
| 		// Wait for INT/TERM
 | 		// Wait for INT/TERM
 | ||||||
| 		sigs := make(chan os.Signal, 1) | 		sigs := make(chan os.Signal, 1) | ||||||
| 		signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) | 		signal.Notify(sigs, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGUSR1) | ||||||
| 		sig := <-sigs | 		for sig := range sigs { | ||||||
| 		log.Infoln("received", sig) | 			log.Infoln("received", sig) | ||||||
|  | 			switch sig { | ||||||
|  | 			case syscall.SIGTERM: | ||||||
|  | 				log.Panic("terminated") | ||||||
|  | 				os.Exit(0) | ||||||
|  | 			case syscall.SIGQUIT: | ||||||
|  | 				quit() | ||||||
|  | 			case syscall.SIGHUP: | ||||||
|  | 				quit() | ||||||
|  | 			case syscall.SIGUSR1: | ||||||
|  | 				reload() | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 	}, | 	}, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func quit() { | ||||||
|  | 	srv.Close() | ||||||
|  | 	statesaveWorker.Close() | ||||||
|  | 
 | ||||||
|  | 	model.SaveJSON(state, configData.StatePath) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func reload() { | ||||||
|  | 	log.Info("start reloading...") | ||||||
|  | 	configNewData, err := config.ReadConfigFile(configPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Warn("unable to load config file:", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if configNewData.StatePath != configData.StatePath { | ||||||
|  | 		statesaveWorker.Close() | ||||||
|  | 		statesaveWorker := worker.NewWorker(time.Minute, func() { | ||||||
|  | 			model.SaveJSON(state, configNewData.StatePath) | ||||||
|  | 			log.Info("save state to:", configNewData.StatePath) | ||||||
|  | 		}) | ||||||
|  | 		go statesaveWorker.Start() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	restartServer := false | ||||||
|  | 
 | ||||||
|  | 	if configNewData.TLSDir != configData.TLSDir { | ||||||
|  | 
 | ||||||
|  | 		m := autocert.Manager{ | ||||||
|  | 			Cache:  autocert.DirCache(configData.TLSDir), | ||||||
|  | 			Prompt: autocert.AcceptTOS, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		certs = &tls.Config{GetCertificate: m.GetCertificate} | ||||||
|  | 		restartServer = true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	newServer := &server.Server{ | ||||||
|  | 		TLSConfig:  certs, | ||||||
|  | 		State:      state, | ||||||
|  | 		PortClient: configNewData.PortClient, | ||||||
|  | 		PortServer: configNewData.PortServer, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if configNewData.PortServer != configData.PortServer { | ||||||
|  | 		restartServer = true | ||||||
|  | 	} | ||||||
|  | 	if configNewData.PortClient != configData.PortClient { | ||||||
|  | 		restartServer = true | ||||||
|  | 	} | ||||||
|  | 	if restartServer { | ||||||
|  | 		go srv.Start() | ||||||
|  | 		//TODO should fetch new server error
 | ||||||
|  | 		srv.Close() | ||||||
|  | 		srv = newServer | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	configData = configNewData | ||||||
|  | 	log.Info("reloaded") | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func init() { | func init() { | ||||||
| 	RootCmd.AddCommand(serveCmd) | 	RootCmd.AddCommand(serveCmd) | ||||||
| 	serveCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") | 	serveCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") | ||||||
|  |  | ||||||
|  | @ -1 +1,4 @@ | ||||||
| tlsdir = "/tmp" | tlsdir = "/tmp/ssl" | ||||||
|  | state_path = "/tmp/yaja.json" | ||||||
|  | port_client = 5222 | ||||||
|  | port_server = 5269 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,41 @@ | ||||||
|  | package model | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Domain struct { | ||||||
|  | 	FQDN     string | ||||||
|  | 	Accounts map[string]*Account | ||||||
|  | 	sync.Mutex | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (d *Domain) GetJID() *JID { | ||||||
|  | 	return &JID{ | ||||||
|  | 		Domain: d.FQDN, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (d *Domain) UpdateAccount(a *Account) error { | ||||||
|  | 	if a.Local == "" { | ||||||
|  | 		return errors.New("No localpart exists in account") | ||||||
|  | 	} | ||||||
|  | 	d.Lock() | ||||||
|  | 	d.Accounts[a.Local] = a | ||||||
|  | 	d.Unlock() | ||||||
|  | 	a.Domain = d | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Account struct { | ||||||
|  | 	Local  string | ||||||
|  | 	Domain *Domain | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (a *Account) GetJID() *JID { | ||||||
|  | 	return &JID{ | ||||||
|  | 		Domain: a.Domain.FQDN, | ||||||
|  | 		Local:  a.Local, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -13,7 +13,7 @@ func TestReadConfig(t *testing.T) { | ||||||
| 	assert.NoError(err) | 	assert.NoError(err) | ||||||
| 	assert.NotNil(config) | 	assert.NotNil(config) | ||||||
| 
 | 
 | ||||||
| 	assert.Equal("/tmp", config.TLSDir) | 	assert.Equal("/tmp/ssl", config.TLSDir) | ||||||
| 
 | 
 | ||||||
| 	config, err = ReadConfigFile("../config_example.co") | 	config, err = ReadConfigFile("../config_example.co") | ||||||
| 	assert.Nil(config) | 	assert.Nil(config) | ||||||
|  |  | ||||||
|  | @ -1,19 +1,8 @@ | ||||||
| package config | package config | ||||||
| 
 | 
 | ||||||
| import "dev.sum7.eu/genofire/yaja/model" |  | ||||||
| 
 |  | ||||||
| type Config struct { | type Config struct { | ||||||
| 	TLSDir     string `toml:"tlsdir"` | 	TLSDir     string `toml:"tlsdir"` | ||||||
|  | 	StatePath  string `toml:"state_path"` | ||||||
| 	PortClient int    `toml:"port_client"` | 	PortClient int    `toml:"port_client"` | ||||||
| 	PortServer int    `toml:"port_server"` | 	PortServer int    `toml:"port_server"` | ||||||
| 
 |  | ||||||
| 	Domain []*Domain `toml:"domain"` |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type Domain struct { |  | ||||||
| 	FQDN       string       `toml:"fqdn"` |  | ||||||
| 	Admins     []*model.JID `toml:"admins"` |  | ||||||
| 	TLSDisable bool         `toml:"tls_disable"` |  | ||||||
| 	TLSPrivate string       `toml:"tls_private"` |  | ||||||
| 	TLSPublic  string       `toml:"tls_public"` |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,27 @@ | ||||||
|  | package model | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"log" | ||||||
|  | 	"os" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // SaveJSON to path
 | ||||||
|  | func SaveJSON(input interface{}, outputFile string) { | ||||||
|  | 	tmpFile := outputFile + ".tmp" | ||||||
|  | 
 | ||||||
|  | 	f, err := os.OpenFile(tmpFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Panic(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	err = json.NewEncoder(f).Encode(input) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Panic(err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	f.Close() | ||||||
|  | 	if err := os.Rename(tmpFile, outputFile); err != nil { | ||||||
|  | 		log.Panic(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -50,13 +50,13 @@ func (jid *JID) Full() string { | ||||||
| 	return jid.Bare() | 	return jid.Bare() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //MarshalJSON to bytearray
 | //MarshalTOML to bytearray
 | ||||||
| func (jid JID) MarshalJSON() ([]byte, error) { | func (jid JID) MarshalTOML() ([]byte, error) { | ||||||
| 	return []byte(jid.Full()), nil | 	return []byte(jid.Full()), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UnmarshalJSON from bytearray
 | // UnmarshalTOML from bytearray
 | ||||||
| func (jid *JID) UnmarshalJSON(data []byte) (err error) { | func (jid *JID) UnmarshalTOML(data []byte) (err error) { | ||||||
| 	newJID := NewJID(string(data)) | 	newJID := NewJID(string(data)) | ||||||
| 	if newJID == nil { | 	if newJID == nil { | ||||||
| 		return errors.New("not a valid jid") | 		return errors.New("not a valid jid") | ||||||
|  |  | ||||||
|  | @ -133,18 +133,18 @@ func TestJIDBare(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestJSONMarshal(t *testing.T) { | func TestMarshal(t *testing.T) { | ||||||
| 	assert := assert.New(t) | 	assert := assert.New(t) | ||||||
| 
 | 
 | ||||||
| 	jid := &JID{} | 	jid := &JID{} | ||||||
| 	err := jid.UnmarshalJSON([]byte("juliet@example.com/foo")) | 	err := jid.UnmarshalTOML([]byte("juliet@example.com/foo")) | ||||||
| 
 | 
 | ||||||
| 	assert.NoError(err) | 	assert.NoError(err) | ||||||
| 	assert.Equal(jid.Local, "juliet") | 	assert.Equal(jid.Local, "juliet") | ||||||
| 	assert.Equal(jid.Domain, "example.com") | 	assert.Equal(jid.Domain, "example.com") | ||||||
| 	assert.Equal(jid.Resource, "foo") | 	assert.Equal(jid.Resource, "foo") | ||||||
| 
 | 
 | ||||||
| 	err = jid.UnmarshalJSON([]byte("juliet@example.com/ foo")) | 	err = jid.UnmarshalTOML([]byte("juliet@example.com/ foo")) | ||||||
| 
 | 
 | ||||||
| 	assert.Error(err) | 	assert.Error(err) | ||||||
| 
 | 
 | ||||||
|  | @ -153,7 +153,7 @@ func TestJSONMarshal(t *testing.T) { | ||||||
| 		Domain:   "example.com", | 		Domain:   "example.com", | ||||||
| 		Resource: "bar", | 		Resource: "bar", | ||||||
| 	} | 	} | ||||||
| 	jidString, err := jid.MarshalJSON() | 	jidString, err := jid.MarshalTOML() | ||||||
| 	assert.NoError(err) | 	assert.NoError(err) | ||||||
| 	assert.Equal("romeo@example.com/bar", string(jidString)) | 	assert.Equal("romeo@example.com/bar", string(jidString)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,38 @@ | ||||||
|  | package model | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"os" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type State struct { | ||||||
|  | 	Domains map[string]*Domain | ||||||
|  | 	sync.Mutex | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func ReadState(path string) (state *State, err error) { | ||||||
|  | 	state = &State{} | ||||||
|  | 
 | ||||||
|  | 	if f, err := os.Open(path); err == nil { // transform data to legacy meshviewer
 | ||||||
|  | 		if err = json.NewDecoder(f).Decode(state); err == nil { | ||||||
|  | 			return state, nil | ||||||
|  | 
 | ||||||
|  | 		} else { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *State) UpdateDomain(d *Domain) error { | ||||||
|  | 	if d.FQDN == "" { | ||||||
|  | 		return errors.New("No fqdn exists in domain") | ||||||
|  | 	} | ||||||
|  | 	s.Lock() | ||||||
|  | 	s.Domains[d.FQDN] = d | ||||||
|  | 	s.Unlock() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | @ -0,0 +1,22 @@ | ||||||
|  | package server | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/tls" | ||||||
|  | 
 | ||||||
|  | 	"dev.sum7.eu/genofire/yaja/model" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Server struct { | ||||||
|  | 	TLSConfig  *tls.Config | ||||||
|  | 	PortClient int | ||||||
|  | 	PortServer int | ||||||
|  | 	State      *model.State | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (srv *Server) Start() { | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (srv *Server) Close() { | ||||||
|  | 
 | ||||||
|  | } | ||||||
		Reference in New Issue