104 lines
2.3 KiB
Go
104 lines
2.3 KiB
Go
package ssh
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// the SSH Connection Manager for multiple connections
|
|
type Manager struct {
|
|
config *ssh.ClientConfig
|
|
clientsBlacklist map[string]time.Time
|
|
clientsMUX sync.Mutex
|
|
timeout time.Duration
|
|
}
|
|
|
|
// create a new SSH Connection Manager by ssh file
|
|
func NewManager(file string, timeout time.Duration) *Manager {
|
|
var auths []ssh.AuthMethod
|
|
if auth := SSHAgent(); auth != nil {
|
|
auths = append(auths, auth)
|
|
}
|
|
if auth := PublicKeyFile(file); auth != nil {
|
|
auths = append(auths, auth)
|
|
}
|
|
|
|
sshConfig := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: auths,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
}
|
|
return &Manager{
|
|
config: sshConfig,
|
|
clientsBlacklist: make(map[string]time.Time),
|
|
timeout: timeout,
|
|
}
|
|
}
|
|
|
|
// Conn wraps a net.Conn, and sets a deadline for every read
|
|
// and write operation.
|
|
type Conn struct {
|
|
net.Conn
|
|
ReadTimeout time.Duration
|
|
WriteTimeout time.Duration
|
|
}
|
|
|
|
func (c *Conn) Read(b []byte) (int, error) {
|
|
err := c.Conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return c.Conn.Read(b)
|
|
}
|
|
|
|
func (c *Conn) Write(b []byte) (int, error) {
|
|
err := c.Conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return c.Conn.Write(b)
|
|
}
|
|
|
|
func (m *Manager) ConnectTo(addr net.TCPAddr) (*ssh.Client, error) {
|
|
m.clientsMUX.Lock()
|
|
defer m.clientsMUX.Unlock()
|
|
if t, ok := m.clientsBlacklist[addr.IP.String()]; ok {
|
|
if time.Now().Add(-time.Hour * 24).Before(t) {
|
|
return nil, errors.New("node on blacklist")
|
|
} else {
|
|
delete(m.clientsBlacklist, addr.IP.String())
|
|
}
|
|
}
|
|
|
|
addrString := addr.String()
|
|
|
|
conn, err := net.DialTimeout("tcp", addrString, m.timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
timeoutConn := &Conn{conn, m.timeout, m.timeout}
|
|
c, chans, reqs, err := ssh.NewClientConn(timeoutConn, addrString, m.config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := ssh.NewClient(c, chans, reqs)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "no supported methods remain") {
|
|
m.clientsBlacklist[addr.IP.String()] = time.Now()
|
|
log.Warnf("node was set on the blacklist: %s", err)
|
|
return nil, errors.New("node on blacklist")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return client, nil
|
|
}
|