diff --git a/cmd/serve.go b/cmd/serve.go index 009213b..16d39ea 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -1,12 +1,19 @@ package cmd import ( + "crypto/tls" "os" "os/signal" "syscall" + "time" + "golang.org/x/crypto/acme/autocert" + + "dev.sum7.eu/genofire/yaja/model" "dev.sum7.eu/genofire/yaja/model/config" + "dev.sum7.eu/genofire/yaja/server" + "github.com/genofire/golang-lib/worker" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -14,28 +21,137 @@ import ( var configPath string +var ( + configData *config.Config + state *model.State + statesaveWorker *worker.Worker + srv *server.Server + certs *tls.Config +) + // serveCmd represents the serve command var serveCmd = &cobra.Command{ Use: "serve", Short: "Runs the yaja server", Example: "yaja serve -c /etc/yaja.conf", Run: func(cmd *cobra.Command, args []string) { - _, err := config.ReadConfigFile(configPath) + var err error + configData, err = config.ReadConfigFile(configPath) if err != nil { log.Fatal("unable to load config file:", err) } + state, err = model.ReadState(configData.StatePath) + if err != nil { + log.Warn("unable to load state file:", err) + } + + statesaveWorker = worker.NewWorker(time.Minute, func() { + model.SaveJSON(state, configData.StatePath) + log.Info("save state to:", configData.StatePath) + }) + + m := autocert.Manager{ + Cache: autocert.DirCache(configData.TLSDir), + Prompt: autocert.AcceptTOS, + } + + certs = &tls.Config{GetCertificate: m.GetCertificate} + + srv = &server.Server{ + TLSConfig: certs, + State: state, + PortClient: configData.PortClient, + PortServer: configData.PortServer, + } + + go statesaveWorker.Start() + go srv.Start() + log.Infoln("yaja started ") // Wait for INT/TERM sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - sig := <-sigs - log.Infoln("received", sig) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGUSR1) + for sig := range sigs { + log.Infoln("received", sig) + switch sig { + case syscall.SIGTERM: + log.Panic("terminated") + os.Exit(0) + case syscall.SIGQUIT: + quit() + case syscall.SIGHUP: + quit() + case syscall.SIGUSR1: + reload() + } + } }, } +func quit() { + srv.Close() + statesaveWorker.Close() + + model.SaveJSON(state, configData.StatePath) +} + +func reload() { + log.Info("start reloading...") + configNewData, err := config.ReadConfigFile(configPath) + if err != nil { + log.Warn("unable to load config file:", err) + return + } + + if configNewData.StatePath != configData.StatePath { + statesaveWorker.Close() + statesaveWorker := worker.NewWorker(time.Minute, func() { + model.SaveJSON(state, configNewData.StatePath) + log.Info("save state to:", configNewData.StatePath) + }) + go statesaveWorker.Start() + } + + restartServer := false + + if configNewData.TLSDir != configData.TLSDir { + + m := autocert.Manager{ + Cache: autocert.DirCache(configData.TLSDir), + Prompt: autocert.AcceptTOS, + } + + certs = &tls.Config{GetCertificate: m.GetCertificate} + restartServer = true + } + + newServer := &server.Server{ + TLSConfig: certs, + State: state, + PortClient: configNewData.PortClient, + PortServer: configNewData.PortServer, + } + + if configNewData.PortServer != configData.PortServer { + restartServer = true + } + if configNewData.PortClient != configData.PortClient { + restartServer = true + } + if restartServer { + go srv.Start() + //TODO should fetch new server error + srv.Close() + srv = newServer + } + + configData = configNewData + log.Info("reloaded") +} + func init() { RootCmd.AddCommand(serveCmd) serveCmd.Flags().StringVarP(&configPath, "config", "c", "yaja.conf", "Path to configuration file") diff --git a/config_example.conf b/config_example.conf index fcf3123..75e686b 100644 --- a/config_example.conf +++ b/config_example.conf @@ -1 +1,4 @@ -tlsdir = "/tmp" +tlsdir = "/tmp/ssl" +state_path = "/tmp/yaja.json" +port_client = 5222 +port_server = 5269 diff --git a/model/account.go b/model/account.go new file mode 100644 index 0000000..6c638af --- /dev/null +++ b/model/account.go @@ -0,0 +1,41 @@ +package model + +import ( + "errors" + "sync" +) + +type Domain struct { + FQDN string + Accounts map[string]*Account + sync.Mutex +} + +func (d *Domain) GetJID() *JID { + return &JID{ + Domain: d.FQDN, + } +} + +func (d *Domain) UpdateAccount(a *Account) error { + if a.Local == "" { + return errors.New("No localpart exists in account") + } + d.Lock() + d.Accounts[a.Local] = a + d.Unlock() + a.Domain = d + return nil +} + +type Account struct { + Local string + Domain *Domain +} + +func (a *Account) GetJID() *JID { + return &JID{ + Domain: a.Domain.FQDN, + Local: a.Local, + } +} diff --git a/model/config/file_test.go b/model/config/file_test.go index 8b3b892..3368c04 100644 --- a/model/config/file_test.go +++ b/model/config/file_test.go @@ -13,7 +13,7 @@ func TestReadConfig(t *testing.T) { assert.NoError(err) assert.NotNil(config) - assert.Equal("/tmp", config.TLSDir) + assert.Equal("/tmp/ssl", config.TLSDir) config, err = ReadConfigFile("../config_example.co") assert.Nil(config) diff --git a/model/config/struct.go b/model/config/struct.go index 3eafb60..83fe37f 100644 --- a/model/config/struct.go +++ b/model/config/struct.go @@ -1,19 +1,8 @@ package config -import "dev.sum7.eu/genofire/yaja/model" - type Config struct { TLSDir string `toml:"tlsdir"` + StatePath string `toml:"state_path"` PortClient int `toml:"port_client"` PortServer int `toml:"port_server"` - - Domain []*Domain `toml:"domain"` -} - -type Domain struct { - FQDN string `toml:"fqdn"` - Admins []*model.JID `toml:"admins"` - TLSDisable bool `toml:"tls_disable"` - TLSPrivate string `toml:"tls_private"` - TLSPublic string `toml:"tls_public"` } diff --git a/model/file.go b/model/file.go new file mode 100644 index 0000000..493f62d --- /dev/null +++ b/model/file.go @@ -0,0 +1,27 @@ +package model + +import ( + "encoding/json" + "log" + "os" +) + +// SaveJSON to path +func SaveJSON(input interface{}, outputFile string) { + tmpFile := outputFile + ".tmp" + + f, err := os.OpenFile(tmpFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Panic(err) + } + + err = json.NewEncoder(f).Encode(input) + if err != nil { + log.Panic(err) + } + + f.Close() + if err := os.Rename(tmpFile, outputFile); err != nil { + log.Panic(err) + } +} diff --git a/model/jid.go b/model/jid.go index 4a8d1b9..0c67867 100644 --- a/model/jid.go +++ b/model/jid.go @@ -50,13 +50,13 @@ func (jid *JID) Full() string { return jid.Bare() } -//MarshalJSON to bytearray -func (jid JID) MarshalJSON() ([]byte, error) { +//MarshalTOML to bytearray +func (jid JID) MarshalTOML() ([]byte, error) { return []byte(jid.Full()), nil } -// UnmarshalJSON from bytearray -func (jid *JID) UnmarshalJSON(data []byte) (err error) { +// UnmarshalTOML from bytearray +func (jid *JID) UnmarshalTOML(data []byte) (err error) { newJID := NewJID(string(data)) if newJID == nil { return errors.New("not a valid jid") diff --git a/model/jid_test.go b/model/jid_test.go index 3d05224..52ea722 100644 --- a/model/jid_test.go +++ b/model/jid_test.go @@ -133,18 +133,18 @@ func TestJIDBare(t *testing.T) { } } -func TestJSONMarshal(t *testing.T) { +func TestMarshal(t *testing.T) { assert := assert.New(t) jid := &JID{} - err := jid.UnmarshalJSON([]byte("juliet@example.com/foo")) + err := jid.UnmarshalTOML([]byte("juliet@example.com/foo")) assert.NoError(err) assert.Equal(jid.Local, "juliet") assert.Equal(jid.Domain, "example.com") assert.Equal(jid.Resource, "foo") - err = jid.UnmarshalJSON([]byte("juliet@example.com/ foo")) + err = jid.UnmarshalTOML([]byte("juliet@example.com/ foo")) assert.Error(err) @@ -153,7 +153,7 @@ func TestJSONMarshal(t *testing.T) { Domain: "example.com", Resource: "bar", } - jidString, err := jid.MarshalJSON() + jidString, err := jid.MarshalTOML() assert.NoError(err) assert.Equal("romeo@example.com/bar", string(jidString)) } diff --git a/model/state.go b/model/state.go new file mode 100644 index 0000000..16cbed1 --- /dev/null +++ b/model/state.go @@ -0,0 +1,38 @@ +package model + +import ( + "encoding/json" + "errors" + "os" + "sync" +) + +type State struct { + Domains map[string]*Domain + sync.Mutex +} + +func ReadState(path string) (state *State, err error) { + state = &State{} + + if f, err := os.Open(path); err == nil { // transform data to legacy meshviewer + if err = json.NewDecoder(f).Decode(state); err == nil { + return state, nil + + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (s *State) UpdateDomain(d *Domain) error { + if d.FQDN == "" { + return errors.New("No fqdn exists in domain") + } + s.Lock() + s.Domains[d.FQDN] = d + s.Unlock() + return nil +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..5298e87 --- /dev/null +++ b/server/server.go @@ -0,0 +1,22 @@ +package server + +import ( + "crypto/tls" + + "dev.sum7.eu/genofire/yaja/model" +) + +type Server struct { + TLSConfig *tls.Config + PortClient int + PortServer int + State *model.State +} + +func (srv *Server) Start() { + +} + +func (srv *Server) Close() { + +}