2019-09-05 18:55:53 +00:00
|
|
|
// Copyright (c) 2017 Gorillalabs. All rights reserved.
|
|
|
|
|
|
|
|
package backend
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"regexp"
|
|
|
|
"strings"
|
|
|
|
|
2020-07-20 20:18:46 +00:00
|
|
|
"github.com/pkg/errors"
|
2019-09-05 18:55:53 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
// sshSession exists so we don't create a hard dependency on crypto/ssh.
|
|
|
|
type sshSession interface {
|
|
|
|
Waiter
|
|
|
|
|
|
|
|
StdinPipe() (io.WriteCloser, error)
|
|
|
|
StdoutPipe() (io.Reader, error)
|
|
|
|
StderrPipe() (io.Reader, error)
|
|
|
|
Start(string) error
|
|
|
|
}
|
|
|
|
|
|
|
|
type SSH struct {
|
|
|
|
Session sshSession
|
|
|
|
}
|
|
|
|
|
|
|
|
func (b *SSH) StartProcess(cmd string, args ...string) (Waiter, io.Writer, io.Reader, io.Reader, error) {
|
|
|
|
stdin, err := b.Session.StdinPipe()
|
|
|
|
if err != nil {
|
2020-07-20 20:18:46 +00:00
|
|
|
return nil, nil, nil, nil, errors.Wrap(err, "Could not get hold of the SSH session's stdin stream")
|
2019-09-05 18:55:53 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
stdout, err := b.Session.StdoutPipe()
|
|
|
|
if err != nil {
|
2020-07-20 20:18:46 +00:00
|
|
|
return nil, nil, nil, nil, errors.Wrap(err, "Could not get hold of the SSH session's stdout stream")
|
2019-09-05 18:55:53 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
stderr, err := b.Session.StderrPipe()
|
|
|
|
if err != nil {
|
2020-07-20 20:18:46 +00:00
|
|
|
return nil, nil, nil, nil, errors.Wrap(err, "Could not get hold of the SSH session's stderr stream")
|
2019-09-05 18:55:53 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
err = b.Session.Start(b.createCmd(cmd, args))
|
|
|
|
if err != nil {
|
2020-07-20 20:18:46 +00:00
|
|
|
return nil, nil, nil, nil, errors.Wrap(err, "Could not spawn process via SSH")
|
2019-09-05 18:55:53 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
return b.Session, stdin, stdout, stderr, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (b *SSH) createCmd(cmd string, args []string) string {
|
|
|
|
parts := []string{cmd}
|
|
|
|
simple := regexp.MustCompile(`^[a-z0-9_/.~+-]+$`)
|
|
|
|
|
|
|
|
for _, arg := range args {
|
|
|
|
if !simple.MatchString(arg) {
|
|
|
|
arg = b.quote(arg)
|
|
|
|
}
|
|
|
|
|
|
|
|
parts = append(parts, arg)
|
|
|
|
}
|
|
|
|
|
|
|
|
return strings.Join(parts, " ")
|
|
|
|
}
|
|
|
|
|
|
|
|
func (b *SSH) quote(s string) string {
|
|
|
|
return fmt.Sprintf(`"%s"`, s)
|
|
|
|
}
|