[TASK] add ssh List for handling result of everywhere run

This commit is contained in:
Martin Geno 2017-05-30 11:17:16 +02:00
parent f3e9b6c0db
commit 6046e7dd95
No known key found for this signature in database
GPG Key ID: F0D39A37E925E941
8 changed files with 123 additions and 27 deletions

View File

@ -58,7 +58,7 @@ func (nodes *Nodes) LearnNode(n *yanic.Node) {
// session := nodes.ssh.ConnectTo(node.Address) // session := nodes.ssh.ConnectTo(node.Address)
result, err := nodes.ssh.RunOn(node.GetAddress(nodes.iface), "uptime") result, err := nodes.ssh.RunOn(node.GetAddress(nodes.iface), "uptime")
if err != nil { if err != nil {
logger.Error("init ssh command not run") logger.Error("init ssh command not run", err)
return return
} }
uptime := ssh.SSHResultToString(result) uptime := ssh.SSHResultToString(result)

View File

@ -13,25 +13,27 @@ func (m *Manager) ExecuteEverywhere(cmd string) {
} }
} }
func (m *Manager) ExecuteOn(addr net.TCPAddr, cmd string) { func (m *Manager) ExecuteOn(addr net.TCPAddr, cmd string) error {
client := m.ConnectTo(addr) client, err := m.ConnectTo(addr)
if client != nil { if err != nil {
m.execute(addr.IP.String(), client, cmd) return err
} }
return m.execute(addr.IP.String(), client, cmd)
} }
func (m *Manager) execute(host string, client *ssh.Client, cmd string) { func (m *Manager) execute(host string, client *ssh.Client, cmd string) error {
session, err := client.NewSession() session, err := client.NewSession()
defer session.Close() defer session.Close()
if err != nil { if err != nil {
log.Log.Warnf("can not create session on %s: %s", host, err) log.Log.Warnf("can not create session on %s: %s", host, err)
delete(m.clients, host) delete(m.clients, host)
return return err
} }
err = session.Run(cmd) err = session.Run(cmd)
if err != nil { if err != nil {
log.Log.Warnf("could not run %s on %s: %s", cmd, host, err) log.Log.Warnf("could not run %s on %s: %s", cmd, host, err)
delete(m.clients, host) return err
} }
return nil
} }

View File

@ -10,14 +10,21 @@ import (
func TestExecute(t *testing.T) { func TestExecute(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
addr := net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}
mgmt := NewManager("~/.ssh/id_rsa") mgmt := NewManager("~/.ssh/id_rsa")
assert.NotNil(mgmt, "no new manager created") assert.NotNil(mgmt, "no new manager created")
mgmt.ConnectTo(net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}) _, err := mgmt.ConnectTo(addr)
assert.NoError(err)
mgmt.ExecuteEverywhere("echo $HOSTNAME") mgmt.ExecuteEverywhere("echo $HOSTNAME")
mgmt.ExecuteOn(net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}, "uptime") err = mgmt.ExecuteOn(addr, "uptime")
mgmt.ExecuteOn(net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}, "echo $HOSTNAME") assert.NoError(err)
err = mgmt.ExecuteOn(addr, "echo $HOSTNAME")
assert.NoError(err)
err = mgmt.ExecuteOn(addr, "exit 1")
assert.Error(err)
mgmt.Close() mgmt.Close()
} }

52
ssh/list.go Normal file
View File

@ -0,0 +1,52 @@
package ssh
import (
"sync"
"golang.org/x/crypto/ssh"
)
type List struct {
cmd string
Clients map[string]*ListResult
sshManager *Manager
}
type ListResult struct {
ssh *ssh.Client
Runned bool
WithError bool
Result string
}
func (m *Manager) CreateList(cmd string) *List {
list := &List{
cmd: cmd,
sshManager: m,
Clients: make(map[string]*ListResult),
}
for host, client := range m.clients {
list.Clients[host] = &ListResult{Runned: false, ssh: client}
}
return list
}
func (l List) Run() {
wg := new(sync.WaitGroup)
for host, client := range l.Clients {
wg.Add(1)
go l.runlistelement(host, client, wg)
}
wg.Wait()
}
func (l List) runlistelement(host string, client *ListResult, wg *sync.WaitGroup) {
defer wg.Done()
result, err := l.sshManager.run(host, client.ssh, l.cmd)
client.Runned = true
if err != nil {
client.WithError = true
return
}
client.Result = SSHResultToString(result)
}

36
ssh/list_test.go Normal file
View File

@ -0,0 +1,36 @@
package ssh
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestList(t *testing.T) {
assert := assert.New(t)
addr := net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}
mgmt := NewManager("~/.ssh/id_rsa")
assert.NotNil(mgmt, "no new manager created")
_, err := mgmt.ConnectTo(addr)
assert.NoError(err)
list := mgmt.CreateList("exit 1")
assert.Len(list.Clients, 1)
client := list.Clients[addr.IP.String()]
assert.False(client.Runned)
list.Run()
assert.True(client.Runned)
assert.True(client.WithError)
assert.Equal("", client.Result)
list = mgmt.CreateList("echo 15")
assert.Len(list.Clients, 1)
client = list.Clients[addr.IP.String()]
assert.False(client.Runned)
list.Run()
assert.True(client.Runned)
assert.False(client.WithError)
assert.Equal("15", client.Result)
}

View File

@ -1,6 +1,7 @@
package ssh package ssh
import ( import (
"errors"
"net" "net"
"strings" "strings"
"sync" "sync"
@ -40,19 +41,19 @@ func NewManager(file string) *Manager {
} }
} }
func (m *Manager) ConnectTo(addr net.TCPAddr) *ssh.Client { func (m *Manager) ConnectTo(addr net.TCPAddr) (*ssh.Client, error) {
m.clientsMUX.Lock() m.clientsMUX.Lock()
defer m.clientsMUX.Unlock() defer m.clientsMUX.Unlock()
if t, ok := m.clientsBlacklist[addr.IP.String()]; ok { if t, ok := m.clientsBlacklist[addr.IP.String()]; ok {
if time.Now().Add(-time.Hour * 24).Before(t) { if time.Now().Add(-time.Hour * 24).Before(t) {
return nil return nil, errors.New("node on blacklist")
} else { } else {
delete(m.clientsBlacklist, addr.IP.String()) delete(m.clientsBlacklist, addr.IP.String())
} }
} }
if client, ok := m.clients[addr.IP.String()]; ok { if client, ok := m.clients[addr.IP.String()]; ok {
return client return client, nil
} }
client, err := ssh.Dial("tcp", addr.String(), m.config) client, err := ssh.Dial("tcp", addr.String(), m.config)
@ -60,14 +61,13 @@ func (m *Manager) ConnectTo(addr net.TCPAddr) *ssh.Client {
if strings.Contains(err.Error(), "no supported methods remain") { if strings.Contains(err.Error(), "no supported methods remain") {
m.clientsBlacklist[addr.IP.String()] = time.Now() m.clientsBlacklist[addr.IP.String()] = time.Now()
log.Log.Warnf("node was set on the blacklist: %s", err) log.Log.Warnf("node was set on the blacklist: %s", err)
} else { return nil, errors.New("node on blacklist")
log.Log.Error(err)
} }
return nil return nil, err
} }
m.clients[addr.IP.String()] = client m.clients[addr.IP.String()] = client
return client return client, nil
} }
func (m *Manager) Close() { func (m *Manager) Close() {

View File

@ -2,7 +2,6 @@ package ssh
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"net" "net"
@ -35,11 +34,11 @@ func (m *Manager) RunEverywhere(cmd string, handler SSHResultHandler) {
} }
func (m *Manager) RunOn(addr net.TCPAddr, cmd string) ([]byte, error) { func (m *Manager) RunOn(addr net.TCPAddr, cmd string) ([]byte, error) {
client := m.ConnectTo(addr) client, err := m.ConnectTo(addr)
if client != nil { if err != nil {
return m.run(addr.IP.String(), client, cmd) return nil, err
} }
return nil, errors.New("no connection for runOn") return m.run(addr.IP.String(), client, cmd)
} }
func (m *Manager) run(host string, client *ssh.Client, cmd string) ([]byte, error) { func (m *Manager) run(host string, client *ssh.Client, cmd string) ([]byte, error) {
@ -62,7 +61,6 @@ func (m *Manager) run(host string, client *ssh.Client, cmd string) ([]byte, erro
err = session.Run(cmd) err = session.Run(cmd)
if err != nil { if err != nil {
log.Log.Warnf("could not run %s on %s: %s", cmd, host, err) log.Log.Warnf("could not run %s on %s: %s", cmd, host, err)
delete(m.clients, host)
return nil, err return nil, err
} }
var result []byte var result []byte

View File

@ -10,18 +10,19 @@ import (
func TestRun(t *testing.T) { func TestRun(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
addr := net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}
mgmt := NewManager("~/.ssh/id_rsa") mgmt := NewManager("~/.ssh/id_rsa")
assert.NotNil(mgmt, "no new manager created") assert.NotNil(mgmt, "no new manager created")
mgmt.ConnectTo(net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}) _, err := mgmt.ConnectTo(addr)
assert.NoError(err)
mgmt.RunEverywhere("echo 13", SSHResultToStringHandler(func(result string, err error) { mgmt.RunEverywhere("echo 13", SSHResultToStringHandler(func(result string, err error) {
assert.NoError(err) assert.NoError(err)
assert.Equal("13", result) assert.Equal("13", result)
})) }))
result, err := mgmt.RunOn(net.TCPAddr{IP: net.ParseIP("2a06:8782:ffbb:1337::127"), Port: 22}, "echo 16") result, err := mgmt.RunOn(addr, "echo 16")
assert.NoError(err) assert.NoError(err)
str := SSHResultToString(result) str := SSHResultToString(result)