summaryrefslogtreecommitdiff
path: root/internal/sshd
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd')
-rw-r--r--internal/sshd/connection.go10
-rw-r--r--internal/sshd/server_config.go94
-rw-r--r--internal/sshd/server_config_test.go105
-rw-r--r--internal/sshd/session.go80
-rw-r--r--internal/sshd/session_test.go189
-rw-r--r--internal/sshd/sshd.go99
-rw-r--r--internal/sshd/sshd_test.go18
7 files changed, 492 insertions, 103 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 0e0da93..f6d8fb5 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -29,21 +29,26 @@ func newConnection(maxSessions int64, remoteAddr string) *connection {
}
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+
defer metrics.SshdConnectionDuration.Observe(time.Since(c.begin).Seconds())
for newChannel := range chans {
+ ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
+ ctxlog.Info("connection: handle: unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !c.concurrentSessions.TryAcquire(1) {
+ ctxlog.Info("connection: handle: too many concurrent sessions")
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
metrics.SshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
- log.WithError(err).Info("could not accept channel")
+ ctxlog.WithError(err).Error("connection: handle: accepting channel failed")
c.concurrentSessions.Release(1)
continue
}
@@ -54,11 +59,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
// Prevent a panic in a single session from taking out the whole server
defer func() {
if err := recover(); err != nil {
- log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", c.remoteAddr)
+ ctxlog.WithField("recovered_error", err).Warn("panic handling session")
}
}()
handler(ctx, channel, requests)
+ ctxlog.Info("connection: handle: done")
}()
}
}
diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go
new file mode 100644
index 0000000..68210f8
--- /dev/null
+++ b/internal/sshd/server_config.go
@@ -0,0 +1,94 @@
+package sshd
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "os"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+type serverConfig struct {
+ cfg *config.Config
+ hostKeys []ssh.Signer
+ authorizedKeysClient *authorizedkeys.Client
+}
+
+func newServerConfig(cfg *config.Config) (*serverConfig, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
+ var hostKeys []ssh.Signer
+ for _, filename := range cfg.Server.HostKeyFiles {
+ keyRaw, err := os.ReadFile(filename)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to read host key")
+ continue
+ }
+ key, err := ssh.ParsePrivateKey(keyRaw)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to parse host key")
+ continue
+ }
+
+ hostKeys = append(hostKeys, key)
+ }
+ if len(hostKeys) == 0 {
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
+ }
+
+ return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+}
+
+func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) {
+ if user != s.cfg.User {
+ return nil, fmt.Errorf("unknown user")
+ }
+ if key.Type() == ssh.KeyAlgoDSA {
+ return nil, fmt.Errorf("DSA is prohibited")
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig {
+ sshCfg := &ssh.ServerConfig{
+ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+ res, err := s.getAuthKey(ctx, conn.User(), key)
+ if err != nil {
+ return nil, err
+ }
+
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "key-id": strconv.FormatInt(res.Id, 10),
+ },
+ }, nil
+ },
+ }
+
+ for _, key := range s.hostKeys {
+ sshCfg.AddHostKey(key)
+ }
+
+ return sshCfg
+}
diff --git a/internal/sshd/server_config_test.go b/internal/sshd/server_config_test.go
new file mode 100644
index 0000000..58bd3e1
--- /dev/null
+++ b/internal/sshd/server_config_test.go
@@ -0,0 +1,105 @@
+package sshd
+
+import (
+ "context"
+ "crypto/dsa"
+ "crypto/rand"
+ "crypto/rsa"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestNewServerConfigWithoutHosts(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "http://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "No host keys could be loaded, aborting", err.Error())
+}
+
+func TestFailedAuthorizedKeysClient(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "ftp://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "failed to initialize GitLab client: Error creating http client: unknown GitLab URL prefix", err.Error())
+}
+
+func TestFailedGetAuthKey(t *testing.T) {
+ testhelper.PrepareTestRootDir(t)
+
+ srvCfg := config.ServerConfig{
+ Listen: "127.0.0.1",
+ ConcurrentSessionsLimit: 1,
+ HostKeyFiles: []string{
+ path.Join(testhelper.TestRoot, "certs/valid/server.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid-path.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
+ },
+ }
+
+ cfg, err := newServerConfig(
+ &config.Config{GitlabUrl: "http://localhost", User: "user", Server: srvCfg},
+ )
+ require.NoError(t, err)
+
+ testCases := []struct {
+ desc string
+ user string
+ key ssh.PublicKey
+ expectedError string
+ }{
+ {
+ desc: "wrong user",
+ user: "wrong-user",
+ key: rsaPublicKey(t),
+ expectedError: "unknown user",
+ }, {
+ desc: "prohibited dsa key",
+ user: "user",
+ key: dsaPublicKey(t),
+ expectedError: "DSA is prohibited",
+ }, {
+ desc: "API error",
+ user: "user",
+ key: rsaPublicKey(t),
+ expectedError: "Internal API unreachable",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err = cfg.getAuthKey(context.Background(), tc.user, tc.key)
+ require.Error(t, err)
+ require.Equal(t, tc.expectedError, err.Error())
+ })
+ }
+}
+
+func rsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
+
+func dsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey := new(dsa.PrivateKey)
+ params := new(dsa.Parameters)
+ require.NoError(t, dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160))
+
+ privateKey.PublicKey.Parameters = *params
+ require.NoError(t, dsa.GenerateKey(privateKey, rand.Reader))
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index 22cb715..b8e8625 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -2,13 +2,16 @@ package sshd
import (
"context"
+ "errors"
"fmt"
+ "reflect"
+ "gitlab.com/gitlab-org/labkit/log"
"golang.org/x/crypto/ssh"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ shellCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
)
@@ -39,16 +42,27 @@ type exitStatusReq struct {
}
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
+ ctxlog := log.ContextLogger(ctx)
+
+ ctxlog.Debug("session: handle: entering request loop")
+
for req := range requests {
+ sessionLog := ctxlog.WithFields(log.Fields{
+ "bytesize": len(req.Payload),
+ "type": req.Type,
+ "want_reply": req.WantReply,
+ })
+ sessionLog.Debug("session: handle: request received")
+
var shouldContinue bool
switch req.Type {
case "env":
- shouldContinue = s.handleEnv(req)
+ shouldContinue = s.handleEnv(ctx, req)
case "exec":
shouldContinue = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
- s.exit(s.handleShell(ctx, req))
+ s.exit(ctx, s.handleShell(ctx, req))
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
@@ -57,18 +71,23 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
}
+ sessionLog.WithField("should_continue", shouldContinue).Debug("session: handle: request processed")
+
if !shouldContinue {
s.channel.Close()
break
}
}
+
+ ctxlog.Debug("session: handle: exiting request loop")
}
-func (s *session) handleEnv(req *ssh.Request) bool {
+func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
+ log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
return false
}
@@ -84,6 +103,10 @@ func (s *session) handleEnv(req *ssh.Request) bool {
req.Reply(accepted, []byte{})
}
+ log.WithContextFields(
+ ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
+ ).Debug("session: handleEnv: processed")
+
return true
}
@@ -95,7 +118,8 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
s.execCmd = execRequest.Command
- s.exit(s.handleShell(ctx, req))
+ s.exit(ctx, s.handleShell(ctx, req))
+
return false
}
@@ -104,19 +128,11 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
req.Reply(true, []byte{})
}
- args := &commandargs.Shell{
- GitlabKeyId: s.gitlabKeyId,
- Env: sshenv.Env{
- IsSSHConnection: true,
- OriginalCommand: s.execCmd,
- GitProtocolVersion: s.gitProtocolVersion,
- RemoteAddr: s.remoteAddr,
- },
- }
-
- if err := args.ParseCommand(s.execCmd); err != nil {
- s.toStderr("Failed to parse command: %v\n", err.Error())
- return 128
+ env := sshenv.Env{
+ IsSSHConnection: true,
+ OriginalCommand: s.execCmd,
+ GitProtocolVersion: s.gitProtocolVersion,
+ RemoteAddr: s.remoteAddr,
}
rw := &readwriter.ReadWriter{
@@ -125,25 +141,37 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
ErrOut: s.channel.Stderr(),
}
- cmd := command.BuildShellCommand(args, s.cfg, rw)
- if cmd == nil {
- s.toStderr("Unknown command: %v\n", args.CommandType)
+ cmd, err := shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
+ if err != nil {
+ if !errors.Is(err, disallowedcommand.Error) {
+ s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
+ }
+ s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
return 128
}
+ cmdName := reflect.TypeOf(cmd).String()
+ ctxlog := log.ContextLogger(ctx)
+ ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("session: handleShell: executing command")
+
if err := cmd.Execute(ctx); err != nil {
- s.toStderr("remote: ERROR: %v\n", err.Error())
+ s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
return 1
}
+ ctxlog.Info("session: handleShell: command executed successfully")
+
return 0
}
-func (s *session) toStderr(format string, args ...interface{}) {
- fmt.Fprintf(s.channel.Stderr(), format, args...)
+func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
+ out := fmt.Sprintf(format, args...)
+ log.WithContextFields(ctx, log.Fields{"stderr": out}).Debug("session: toStderr: output")
+ fmt.Fprint(s.channel.Stderr(), out)
}
-func (s *session) exit(status uint32) {
+func (s *session) exit(ctx context.Context, status uint32) {
+ log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
s.channel.CloseWrite()
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go
new file mode 100644
index 0000000..f135825
--- /dev/null
+++ b/internal/sshd/session_test.go
@@ -0,0 +1,189 @@
+package sshd
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+type fakeChannel struct {
+ stdErr io.ReadWriter
+ sentRequestName string
+ sentRequestPayload []byte
+}
+
+func (f *fakeChannel) Read(data []byte) (int, error) {
+ return 0, nil
+}
+
+func (f *fakeChannel) Write(data []byte) (int, error) {
+ return 0, nil
+}
+
+func (f *fakeChannel) Close() error {
+ return nil
+}
+
+func (f *fakeChannel) CloseWrite() error {
+ return nil
+}
+
+func (f *fakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+ f.sentRequestName = name
+ f.sentRequestPayload = payload
+
+ return true, nil
+}
+
+func (f *fakeChannel) Stderr() io.ReadWriter {
+ return f.stdErr
+}
+
+var requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(`{"id": 1000, "name": "Test User", "username": "test-user"}`))
+ },
+ },
+}
+
+func TestHandleEnv(t *testing.T) {
+ testCases := []struct {
+ desc string
+ payload []byte
+ expectedProtocolVersion string
+ expectedResult bool
+ }{
+ {
+ desc: "invalid payload",
+ payload: []byte("invalid"),
+ expectedProtocolVersion: "1",
+ expectedResult: false,
+ }, {
+ desc: "valid payload",
+ payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
+ expectedProtocolVersion: "2",
+ expectedResult: true,
+ }, {
+ desc: "valid payload with forbidden env var",
+ payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
+ expectedProtocolVersion: "1",
+ expectedResult: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ s := &session{gitProtocolVersion: "1"}
+ r := &ssh.Request{Payload: tc.payload}
+
+ require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
+ require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
+ })
+ }
+}
+
+func TestHandleExec(t *testing.T) {
+ testCases := []struct {
+ desc string
+ payload []byte
+ expectedExecCmd string
+ sentRequestName string
+ sentRequestPayload []byte
+ }{
+ {
+ desc: "invalid payload",
+ payload: []byte("invalid"),
+ expectedExecCmd: "",
+ sentRequestName: "",
+ }, {
+ desc: "valid payload",
+ payload: ssh.Marshal(execRequest{Command: "discover"}),
+ expectedExecCmd: "discover",
+ sentRequestName: "exit-status",
+ sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
+ },
+ }
+
+ url := testserver.StartHttpServer(t, requests)
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ out := &bytes.Buffer{}
+ f := &fakeChannel{stdErr: out}
+ s := &session{
+ gitlabKeyId: "root",
+ channel: f,
+ cfg: &config.Config{GitlabUrl: url},
+ }
+ r := &ssh.Request{Payload: tc.payload}
+
+ require.Equal(t, false, s.handleExec(context.Background(), r))
+ require.Equal(t, tc.sentRequestName, f.sentRequestName)
+ require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
+ })
+ }
+}
+
+func TestHandleShell(t *testing.T) {
+ testCases := []struct {
+ desc string
+ cmd string
+ errMsg string
+ gitlabKeyId string
+ expectedExitCode uint32
+ }{
+ {
+ desc: "fails to parse command",
+ cmd: `\`,
+ errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
+ gitlabKeyId: "root",
+ expectedExitCode: 128,
+ }, {
+ desc: "specified command is unknown",
+ cmd: "unknown-command",
+ errMsg: "Unknown command: unknown-command\n",
+ gitlabKeyId: "root",
+ expectedExitCode: 128,
+ }, {
+ desc: "fails to parse command",
+ cmd: "discover",
+ gitlabKeyId: "",
+ errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
+ expectedExitCode: 1,
+ }, {
+ desc: "fails to parse command",
+ cmd: "discover",
+ errMsg: "",
+ gitlabKeyId: "root",
+ expectedExitCode: 0,
+ },
+ }
+
+ url := testserver.StartHttpServer(t, requests)
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ out := &bytes.Buffer{}
+ s := &session{
+ gitlabKeyId: tc.gitlabKeyId,
+ execCmd: tc.cmd,
+ channel: &fakeChannel{stdErr: out},
+ cfg: &config.Config{GitlabUrl: url},
+ }
+ r := &ssh.Request{}
+
+ require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
+ require.Equal(t, tc.errMsg, out.String())
+ })
+ }
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index de5fbd4..19fa661 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -2,13 +2,9 @@ package sshd
import (
"context"
- "encoding/base64"
- "errors"
"fmt"
"net"
"net/http"
- "os"
- "strconv"
"sync"
"time"
@@ -16,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
@@ -35,44 +30,24 @@ const (
type Server struct {
Config *config.Config
- status status
- statusMu sync.Mutex
- wg sync.WaitGroup
- listener net.Listener
- hostKeys []ssh.Signer
- authorizedKeysClient *authorizedkeys.Client
+ status status
+ statusMu sync.Mutex
+ wg sync.WaitGroup
+ listener net.Listener
+ serverConfig *serverConfig
}
func NewServer(cfg *config.Config) (*Server, error) {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ serverConfig, err := newServerConfig(cfg)
if err != nil {
- return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ return nil, err
}
- var hostKeys []ssh.Signer
- for _, filename := range cfg.Server.HostKeyFiles {
- keyRaw, err := os.ReadFile(filename)
- if err != nil {
- log.WithError(err).Warnf("Failed to read host key %v", filename)
- continue
- }
- key, err := ssh.ParsePrivateKey(keyRaw)
- if err != nil {
- log.WithError(err).Warnf("Failed to parse host key %v", filename)
- continue
- }
-
- hostKeys = append(hostKeys, key)
- }
- if len(hostKeys) == 0 {
- return nil, fmt.Errorf("No host keys could be loaded, aborting")
- }
-
- return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+ return &Server{Config: cfg, serverConfig: serverConfig}, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
- if err := s.listen(); err != nil {
+ if err := s.listen(ctx); err != nil {
return err
}
defer s.listener.Close()
@@ -110,7 +85,7 @@ func (s *Server) MonitoringServeMux() *http.ServeMux {
return mux
}
-func (s *Server) listen() error {
+func (s *Server) listen(ctx context.Context) error {
sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
if err != nil {
return fmt.Errorf("failed to listen for connection: %w", err)
@@ -122,10 +97,10 @@ func (s *Server) listen() error {
ReadHeaderTimeout: ProxyHeaderTimeout,
}
- log.Info("Proxy protocol is enabled")
+ log.ContextLogger(ctx).Info("Proxy protocol is enabled")
}
- log.WithFields(log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections")
+ log.WithContextFields(ctx, log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections")
s.listener = sshListener
@@ -142,7 +117,7 @@ func (s *Server) serve(ctx context.Context) {
break
}
- log.WithError(err).Warn("Failed to accept connection")
+ log.ContextLogger(ctx).WithError(err).Warn("Failed to accept connection")
continue
}
@@ -168,57 +143,29 @@ func (s *Server) getStatus() status {
return s.status
}
-func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig {
- sshCfg := &ssh.ServerConfig{
- PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != s.Config.User {
- return nil, errors.New("unknown user")
- }
- if key.Type() == ssh.KeyAlgoDSA {
- return nil, errors.New("DSA is prohibited")
- }
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
- res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
- if err != nil {
- return nil, err
- }
-
- return &ssh.Permissions{
- // Record the public key used for authentication.
- Extensions: map[string]string{
- "key-id": strconv.FormatInt(res.Id, 10),
- },
- }, nil
- },
- }
-
- for _, key := range s.hostKeys {
- sshCfg.AddHostKey(key)
- }
-
- return sshCfg
-}
-
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
defer s.wg.Done()
defer nconn.Close()
+ ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
+ defer cancel()
+
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
+
// Prevent a panic in a single connection from taking out the whole server
defer func() {
if err := recover(); err != nil {
- log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", remoteAddr)
+ ctxlog.Warn("panic handling session")
}
}()
- ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
- defer cancel()
+ ctxlog.Info("server: handleConn: start")
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx))
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
if err != nil {
- log.WithError(err).Info("Failed to initialize SSH connection")
+ ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection")
return
}
@@ -235,4 +182,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
session.handle(ctx, requests)
})
+
+ ctxlog.Info("server: handleConn: done")
}
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index 32946af..71f7733 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -104,6 +104,24 @@ func TestLivenessProbe(t *testing.T) {
require.Equal(t, 200, r.Result().StatusCode)
}
+func TestInvalidClientConfig(t *testing.T) {
+ setupServer(t)
+
+ cfg := clientConfig(t)
+ cfg.User = "unknown"
+ _, err := ssh.Dial("tcp", serverUrl, cfg)
+ require.Error(t, err)
+}
+
+func TestInvalidServerConfig(t *testing.T) {
+ s := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}}
+ err := s.ListenAndServe(context.Background())
+
+ require.Error(t, err)
+ require.Equal(t, "failed to listen for connection: listen tcp: address invalid: missing port in address", err.Error())
+ require.Nil(t, s.Shutdown())
+}
+
func setupServer(t *testing.T) *Server {
t.Helper()