237 lines
7.6 KiB
Go
237 lines
7.6 KiB
Go
package tsacmd
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"io/ioutil"
|
|
yaml "gopkg.in/yaml.v2"
|
|
|
|
"code.cloudfoundry.org/lager"
|
|
"github.com/concourse/concourse/atc"
|
|
"github.com/concourse/concourse/tsa"
|
|
"github.com/concourse/flag"
|
|
"github.com/tedsuo/ifrit"
|
|
"github.com/tedsuo/ifrit/grouper"
|
|
"github.com/tedsuo/ifrit/http_server"
|
|
"github.com/tedsuo/ifrit/sigmon"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type TSACommand struct {
|
|
Logger flag.Lager
|
|
|
|
BindIP flag.IP `long:"bind-ip" default:"0.0.0.0" description:"IP address on which to listen for SSH."`
|
|
PeerAddress string `long:"peer-address" default:"127.0.0.1" description:"Network address of this web node, reachable by other web nodes. Used for forwarded worker addresses."`
|
|
BindPort uint16 `long:"bind-port" default:"2222" description:"Port on which to listen for SSH."`
|
|
|
|
DebugBindIP flag.IP `long:"debug-bind-ip" default:"127.0.0.1" description:"IP address on which to listen for the pprof debugger endpoints."`
|
|
DebugBindPort uint16 `long:"debug-bind-port" default:"2221" description:"Port on which to listen for the pprof debugger endpoints."`
|
|
|
|
HostKey *flag.PrivateKey `long:"host-key" required:"true" description:"Path to private key to use for the SSH server."`
|
|
AuthorizedKeys flag.AuthorizedKeys `long:"authorized-keys" description:"Path to file containing keys to authorize, in SSH authorized_keys format (one public key per line)."`
|
|
TeamAuthorizedKeys map[string]flag.AuthorizedKeys `long:"team-authorized-keys" value-name:"NAME:PATH" description:"Path to file containing keys to authorize, in SSH authorized_keys format (one public key per line)."`
|
|
TeamAuthorizedKeysFile flag.File `long:"team-authorized-keys-file" description:"Path to file containing a YAML array of teams and their authorized SSH keys, e.g. [{team:foo,ssh_keys:[key1,key2]}]."`
|
|
|
|
ATCURLs []flag.URL `long:"atc-url" required:"true" description:"ATC API endpoints to which workers will be registered."`
|
|
|
|
SessionSigningKey *flag.PrivateKey `long:"session-signing-key" required:"true" description:"Path to private key to use when signing tokens in reqests to the ATC during registration."`
|
|
|
|
HeartbeatInterval time.Duration `long:"heartbeat-interval" default:"30s" description:"interval on which to heartbeat workers to the ATC"`
|
|
|
|
ClusterName string `long:"cluster-name" description:"A name for this Concourse cluster, to be displayed on the dashboard page."`
|
|
LogClusterName bool `long:"log-cluster-name" description:"Log cluster name."`
|
|
}
|
|
|
|
type TeamAuthKeys struct {
|
|
Team string
|
|
AuthKeys []ssh.PublicKey
|
|
}
|
|
|
|
type yamlTeamAuthorizedKey struct {
|
|
Team string `yaml:"team"`
|
|
Keys []string `yaml:"ssh_keys,flow"`
|
|
}
|
|
|
|
func (cmd *TSACommand) Execute(args []string) error {
|
|
runner, err := cmd.Runner(args)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tsaServerMember := grouper.Member{
|
|
Name: "tsa-server",
|
|
Runner: sigmon.New(runner),
|
|
}
|
|
|
|
tsaDebugMember := grouper.Member{
|
|
Name: "debug-server",
|
|
Runner: http_server.New(
|
|
cmd.debugBindAddr(),
|
|
http.DefaultServeMux,
|
|
)}
|
|
|
|
members := []grouper.Member{
|
|
tsaDebugMember,
|
|
tsaServerMember,
|
|
}
|
|
|
|
group := grouper.NewParallel(os.Interrupt, members)
|
|
return <-ifrit.Invoke(group).Wait()
|
|
}
|
|
|
|
func (cmd *TSACommand) Runner(args []string) (ifrit.Runner, error) {
|
|
logger, _ := cmd.constructLogger()
|
|
|
|
atcEndpointPicker := tsa.NewRandomATCEndpointPicker(cmd.ATCURLs)
|
|
|
|
teamAuthorizedKeys, err := cmd.loadTeamAuthorizedKeys()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load team authorized keys: %s", err)
|
|
}
|
|
|
|
if len(cmd.AuthorizedKeys.Keys)+len(cmd.TeamAuthorizedKeys) == 0 {
|
|
logger.Info("starting-tsa-without-authorized-keys")
|
|
}
|
|
|
|
sessionAuthTeam := &sessionTeam{
|
|
sessionTeams: make(map[string]string),
|
|
lock: &sync.RWMutex{},
|
|
}
|
|
|
|
config, err := cmd.configureSSHServer(sessionAuthTeam, cmd.AuthorizedKeys.Keys, teamAuthorizedKeys)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to configure SSH server: %s", err)
|
|
}
|
|
|
|
listenAddr := fmt.Sprintf("%s:%d", cmd.BindIP, cmd.BindPort)
|
|
|
|
if cmd.SessionSigningKey == nil {
|
|
return nil, fmt.Errorf("missing session signing key")
|
|
}
|
|
|
|
tokenGenerator := tsa.NewTokenGenerator(cmd.SessionSigningKey.PrivateKey)
|
|
|
|
server := &server{
|
|
logger: logger,
|
|
heartbeatInterval: cmd.HeartbeatInterval,
|
|
cprInterval: 1 * time.Second,
|
|
atcEndpointPicker: atcEndpointPicker,
|
|
tokenGenerator: tokenGenerator,
|
|
forwardHost: cmd.PeerAddress,
|
|
config: config,
|
|
httpClient: http.DefaultClient,
|
|
sessionTeam: sessionAuthTeam,
|
|
}
|
|
|
|
return serverRunner{logger, server, listenAddr}, nil
|
|
}
|
|
|
|
func (cmd *TSACommand) constructLogger() (lager.Logger, *lager.ReconfigurableSink) {
|
|
logger, reconfigurableSink := cmd.Logger.Logger("tsa")
|
|
if cmd.LogClusterName {
|
|
logger = logger.WithData(lager.Data{
|
|
"cluster": cmd.ClusterName,
|
|
})
|
|
}
|
|
|
|
return logger, reconfigurableSink
|
|
}
|
|
|
|
func (cmd *TSACommand) loadTeamAuthorizedKeys() ([]TeamAuthKeys, error) {
|
|
var teamKeys []TeamAuthKeys
|
|
|
|
for teamName, keys := range cmd.TeamAuthorizedKeys {
|
|
teamKeys = append(teamKeys, TeamAuthKeys{
|
|
Team: teamName,
|
|
AuthKeys: keys.Keys,
|
|
})
|
|
}
|
|
|
|
// load TeamAuthorizedKeysFile
|
|
if cmd.TeamAuthorizedKeysFile != "" {
|
|
logger, _ := cmd.constructLogger()
|
|
var rawTeamAuthorizedKeys []yamlTeamAuthorizedKey
|
|
|
|
authorizedKeysBytes, err := ioutil.ReadFile(cmd.TeamAuthorizedKeysFile.Path())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read yaml authorized keys file: %s", err)
|
|
}
|
|
err = yaml.Unmarshal([]byte(authorizedKeysBytes), &rawTeamAuthorizedKeys)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse yaml authorized keys file: %s", err)
|
|
}
|
|
|
|
for _, t := range rawTeamAuthorizedKeys {
|
|
var teamAuthorizedKeys []ssh.PublicKey
|
|
for _, k := range t.Keys {
|
|
key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k))
|
|
if err != nil {
|
|
logger.Error("load-team-authorized-keys-parse", fmt.Errorf("Invalid format, ignoring (%s): %s", k, err.Error()))
|
|
continue
|
|
}
|
|
logger.Info("load-team-authorized-keys-loaded", lager.Data{"team": t.Team, "key": k})
|
|
teamAuthorizedKeys = append(teamAuthorizedKeys, key)
|
|
}
|
|
teamKeys = append(teamKeys, TeamAuthKeys{Team: t.Team, AuthKeys: teamAuthorizedKeys})
|
|
}
|
|
}
|
|
|
|
return teamKeys, nil
|
|
}
|
|
|
|
func (cmd *TSACommand) configureSSHServer(sessionAuthTeam *sessionTeam, authorizedKeys []ssh.PublicKey, teamAuthorizedKeys []TeamAuthKeys) (*ssh.ServerConfig, error) {
|
|
certChecker := &ssh.CertChecker{
|
|
IsUserAuthority: func(key ssh.PublicKey) bool {
|
|
return false
|
|
},
|
|
|
|
IsHostAuthority: func(key ssh.PublicKey, address string) bool {
|
|
return false
|
|
},
|
|
|
|
UserKeyFallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
for _, k := range authorizedKeys {
|
|
if bytes.Equal(k.Marshal(), key.Marshal()) {
|
|
return nil, nil
|
|
}
|
|
}
|
|
|
|
for _, teamKeys := range teamAuthorizedKeys {
|
|
for _, k := range teamKeys.AuthKeys {
|
|
if bytes.Equal(k.Marshal(), key.Marshal()) {
|
|
sessionAuthTeam.AuthorizeTeam(string(conn.SessionID()), teamKeys.Team)
|
|
return nil, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("unknown public key")
|
|
},
|
|
}
|
|
|
|
config := &ssh.ServerConfig{
|
|
Config: atc.DefaultSSHConfig(),
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
return certChecker.Authenticate(conn, key)
|
|
},
|
|
}
|
|
|
|
signer, err := ssh.NewSignerFromKey(cmd.HostKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create signer from host key: %s", err)
|
|
}
|
|
|
|
config.AddHostKey(signer)
|
|
|
|
return config, nil
|
|
}
|
|
|
|
func (cmd *TSACommand) debugBindAddr() string {
|
|
return fmt.Sprintf("%s:%d", cmd.DebugBindIP, cmd.DebugBindPort)
|
|
}
|