diff options
Diffstat (limited to 'internal/sshd')
-rw-r--r-- | internal/sshd/connection.go | 10 | ||||
-rw-r--r-- | internal/sshd/server_config.go | 94 | ||||
-rw-r--r-- | internal/sshd/server_config_test.go | 105 | ||||
-rw-r--r-- | internal/sshd/session.go | 80 | ||||
-rw-r--r-- | internal/sshd/session_test.go | 189 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 99 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 18 |
7 files changed, 492 insertions, 103 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 0e0da93..f6d8fb5 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -29,21 +29,26 @@ func newConnection(maxSessions int64, remoteAddr string) *connection { } func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) + defer metrics.SshdConnectionDuration.Observe(time.Since(c.begin).Seconds()) for newChannel := range chans { + ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested") if newChannel.ChannelType() != "session" { + ctxlog.Info("connection: handle: unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } if !c.concurrentSessions.TryAcquire(1) { + ctxlog.Info("connection: handle: too many concurrent sessions") newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") metrics.SshdHitMaxSessions.Inc() continue } channel, requests, err := newChannel.Accept() if err != nil { - log.WithError(err).Info("could not accept channel") + ctxlog.WithError(err).Error("connection: handle: accepting channel failed") c.concurrentSessions.Release(1) continue } @@ -54,11 +59,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha // Prevent a panic in a single session from taking out the whole server defer func() { if err := recover(); err != nil { - log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", c.remoteAddr) + ctxlog.WithField("recovered_error", err).Warn("panic handling session") } }() handler(ctx, channel, requests) + ctxlog.Info("connection: handle: done") }() } } diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go new file mode 100644 index 0000000..68210f8 --- /dev/null +++ b/internal/sshd/server_config.go @@ -0,0 +1,94 @@ +package sshd + +import ( + "context" + "encoding/base64" + "fmt" + "os" + "strconv" + "time" + + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" + + "gitlab.com/gitlab-org/labkit/log" +) + +type serverConfig struct { + cfg *config.Config + hostKeys []ssh.Signer + authorizedKeysClient *authorizedkeys.Client +} + +func newServerConfig(cfg *config.Config) (*serverConfig, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + + var hostKeys []ssh.Signer + for _, filename := range cfg.Server.HostKeyFiles { + keyRaw, err := os.ReadFile(filename) + if err != nil { + log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to read host key") + continue + } + key, err := ssh.ParsePrivateKey(keyRaw) + if err != nil { + log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to parse host key") + continue + } + + hostKeys = append(hostKeys, key) + } + if len(hostKeys) == 0 { + return nil, fmt.Errorf("No host keys could be loaded, aborting") + } + + return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil +} + +func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) { + if user != s.cfg.User { + return nil, fmt.Errorf("unknown user") + } + if key.Type() == ssh.KeyAlgoDSA { + return nil, fmt.Errorf("DSA is prohibited") + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) + if err != nil { + return nil, err + } + + return res, nil +} + +func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig { + sshCfg := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + res, err := s.getAuthKey(ctx, conn.User(), key) + if err != nil { + return nil, err + } + + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "key-id": strconv.FormatInt(res.Id, 10), + }, + }, nil + }, + } + + for _, key := range s.hostKeys { + sshCfg.AddHostKey(key) + } + + return sshCfg +} diff --git a/internal/sshd/server_config_test.go b/internal/sshd/server_config_test.go new file mode 100644 index 0000000..58bd3e1 --- /dev/null +++ b/internal/sshd/server_config_test.go @@ -0,0 +1,105 @@ +package sshd + +import ( + "context" + "crypto/dsa" + "crypto/rand" + "crypto/rsa" + "path" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestNewServerConfigWithoutHosts(t *testing.T) { + _, err := newServerConfig(&config.Config{GitlabUrl: "http://localhost"}) + + require.Error(t, err) + require.Equal(t, "No host keys could be loaded, aborting", err.Error()) +} + +func TestFailedAuthorizedKeysClient(t *testing.T) { + _, err := newServerConfig(&config.Config{GitlabUrl: "ftp://localhost"}) + + require.Error(t, err) + require.Equal(t, "failed to initialize GitLab client: Error creating http client: unknown GitLab URL prefix", err.Error()) +} + +func TestFailedGetAuthKey(t *testing.T) { + testhelper.PrepareTestRootDir(t) + + srvCfg := config.ServerConfig{ + Listen: "127.0.0.1", + ConcurrentSessionsLimit: 1, + HostKeyFiles: []string{ + path.Join(testhelper.TestRoot, "certs/valid/server.key"), + path.Join(testhelper.TestRoot, "certs/invalid-path.key"), + path.Join(testhelper.TestRoot, "certs/invalid/server.crt"), + }, + } + + cfg, err := newServerConfig( + &config.Config{GitlabUrl: "http://localhost", User: "user", Server: srvCfg}, + ) + require.NoError(t, err) + + testCases := []struct { + desc string + user string + key ssh.PublicKey + expectedError string + }{ + { + desc: "wrong user", + user: "wrong-user", + key: rsaPublicKey(t), + expectedError: "unknown user", + }, { + desc: "prohibited dsa key", + user: "user", + key: dsaPublicKey(t), + expectedError: "DSA is prohibited", + }, { + desc: "API error", + user: "user", + key: rsaPublicKey(t), + expectedError: "Internal API unreachable", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + _, err = cfg.getAuthKey(context.Background(), tc.user, tc.key) + require.Error(t, err) + require.Equal(t, tc.expectedError, err.Error()) + }) + } +} + +func rsaPublicKey(t *testing.T) ssh.PublicKey { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + return publicKey +} + +func dsaPublicKey(t *testing.T) ssh.PublicKey { + privateKey := new(dsa.PrivateKey) + params := new(dsa.Parameters) + require.NoError(t, dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160)) + + privateKey.PublicKey.Parameters = *params + require.NoError(t, dsa.GenerateKey(privateKey, rand.Reader)) + + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + return publicKey +} diff --git a/internal/sshd/session.go b/internal/sshd/session.go index 22cb715..b8e8625 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -2,13 +2,16 @@ package sshd import ( "context" + "errors" "fmt" + "reflect" + "gitlab.com/gitlab-org/labkit/log" "golang.org/x/crypto/ssh" - "gitlab.com/gitlab-org/gitlab-shell/internal/command" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + shellCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" ) @@ -39,16 +42,27 @@ type exitStatusReq struct { } func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { + ctxlog := log.ContextLogger(ctx) + + ctxlog.Debug("session: handle: entering request loop") + for req := range requests { + sessionLog := ctxlog.WithFields(log.Fields{ + "bytesize": len(req.Payload), + "type": req.Type, + "want_reply": req.WantReply, + }) + sessionLog.Debug("session: handle: request received") + var shouldContinue bool switch req.Type { case "env": - shouldContinue = s.handleEnv(req) + shouldContinue = s.handleEnv(ctx, req) case "exec": shouldContinue = s.handleExec(ctx, req) case "shell": shouldContinue = false - s.exit(s.handleShell(ctx, req)) + s.exit(ctx, s.handleShell(ctx, req)) default: // Ignore unknown requests but don't terminate the session shouldContinue = true @@ -57,18 +71,23 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { } } + sessionLog.WithField("should_continue", shouldContinue).Debug("session: handle: request processed") + if !shouldContinue { s.channel.Close() break } } + + ctxlog.Debug("session: handle: exiting request loop") } -func (s *session) handleEnv(req *ssh.Request) bool { +func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { var accepted bool var envRequest envRequest if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { + log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request") return false } @@ -84,6 +103,10 @@ func (s *session) handleEnv(req *ssh.Request) bool { req.Reply(accepted, []byte{}) } + log.WithContextFields( + ctx, log.Fields{"accepted": accepted, "env_request": envRequest}, + ).Debug("session: handleEnv: processed") + return true } @@ -95,7 +118,8 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool { s.execCmd = execRequest.Command - s.exit(s.handleShell(ctx, req)) + s.exit(ctx, s.handleShell(ctx, req)) + return false } @@ -104,19 +128,11 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { req.Reply(true, []byte{}) } - args := &commandargs.Shell{ - GitlabKeyId: s.gitlabKeyId, - Env: sshenv.Env{ - IsSSHConnection: true, - OriginalCommand: s.execCmd, - GitProtocolVersion: s.gitProtocolVersion, - RemoteAddr: s.remoteAddr, - }, - } - - if err := args.ParseCommand(s.execCmd); err != nil { - s.toStderr("Failed to parse command: %v\n", err.Error()) - return 128 + env := sshenv.Env{ + IsSSHConnection: true, + OriginalCommand: s.execCmd, + GitProtocolVersion: s.gitProtocolVersion, + RemoteAddr: s.remoteAddr, } rw := &readwriter.ReadWriter{ @@ -125,25 +141,37 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { ErrOut: s.channel.Stderr(), } - cmd := command.BuildShellCommand(args, s.cfg, rw) - if cmd == nil { - s.toStderr("Unknown command: %v\n", args.CommandType) + cmd, err := shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw) + if err != nil { + if !errors.Is(err, disallowedcommand.Error) { + s.toStderr(ctx, "Failed to parse command: %v\n", err.Error()) + } + s.toStderr(ctx, "Unknown command: %v\n", s.execCmd) return 128 } + cmdName := reflect.TypeOf(cmd).String() + ctxlog := log.ContextLogger(ctx) + ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("session: handleShell: executing command") + if err := cmd.Execute(ctx); err != nil { - s.toStderr("remote: ERROR: %v\n", err.Error()) + s.toStderr(ctx, "remote: ERROR: %v\n", err.Error()) return 1 } + ctxlog.Info("session: handleShell: command executed successfully") + return 0 } -func (s *session) toStderr(format string, args ...interface{}) { - fmt.Fprintf(s.channel.Stderr(), format, args...) +func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { + out := fmt.Sprintf(format, args...) + log.WithContextFields(ctx, log.Fields{"stderr": out}).Debug("session: toStderr: output") + fmt.Fprint(s.channel.Stderr(), out) } -func (s *session) exit(status uint32) { +func (s *session) exit(ctx context.Context, status uint32) { + log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting") req := exitStatusReq{ExitStatus: status} s.channel.CloseWrite() diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go new file mode 100644 index 0000000..f135825 --- /dev/null +++ b/internal/sshd/session_test.go @@ -0,0 +1,189 @@ +package sshd + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +type fakeChannel struct { + stdErr io.ReadWriter + sentRequestName string + sentRequestPayload []byte +} + +func (f *fakeChannel) Read(data []byte) (int, error) { + return 0, nil +} + +func (f *fakeChannel) Write(data []byte) (int, error) { + return 0, nil +} + +func (f *fakeChannel) Close() error { + return nil +} + +func (f *fakeChannel) CloseWrite() error { + return nil +} + +func (f *fakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + f.sentRequestName = name + f.sentRequestPayload = payload + + return true, nil +} + +func (f *fakeChannel) Stderr() io.ReadWriter { + return f.stdErr +} + +var requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"id": 1000, "name": "Test User", "username": "test-user"}`)) + }, + }, +} + +func TestHandleEnv(t *testing.T) { + testCases := []struct { + desc string + payload []byte + expectedProtocolVersion string + expectedResult bool + }{ + { + desc: "invalid payload", + payload: []byte("invalid"), + expectedProtocolVersion: "1", + expectedResult: false, + }, { + desc: "valid payload", + payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}), + expectedProtocolVersion: "2", + expectedResult: true, + }, { + desc: "valid payload with forbidden env var", + payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}), + expectedProtocolVersion: "1", + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + s := &session{gitProtocolVersion: "1"} + r := &ssh.Request{Payload: tc.payload} + + require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult) + require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion) + }) + } +} + +func TestHandleExec(t *testing.T) { + testCases := []struct { + desc string + payload []byte + expectedExecCmd string + sentRequestName string + sentRequestPayload []byte + }{ + { + desc: "invalid payload", + payload: []byte("invalid"), + expectedExecCmd: "", + sentRequestName: "", + }, { + desc: "valid payload", + payload: ssh.Marshal(execRequest{Command: "discover"}), + expectedExecCmd: "discover", + sentRequestName: "exit-status", + sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}), + }, + } + + url := testserver.StartHttpServer(t, requests) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := &bytes.Buffer{} + f := &fakeChannel{stdErr: out} + s := &session{ + gitlabKeyId: "root", + channel: f, + cfg: &config.Config{GitlabUrl: url}, + } + r := &ssh.Request{Payload: tc.payload} + + require.Equal(t, false, s.handleExec(context.Background(), r)) + require.Equal(t, tc.sentRequestName, f.sentRequestName) + require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload) + }) + } +} + +func TestHandleShell(t *testing.T) { + testCases := []struct { + desc string + cmd string + errMsg string + gitlabKeyId string + expectedExitCode uint32 + }{ + { + desc: "fails to parse command", + cmd: `\`, + errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n", + gitlabKeyId: "root", + expectedExitCode: 128, + }, { + desc: "specified command is unknown", + cmd: "unknown-command", + errMsg: "Unknown command: unknown-command\n", + gitlabKeyId: "root", + expectedExitCode: 128, + }, { + desc: "fails to parse command", + cmd: "discover", + gitlabKeyId: "", + errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", + expectedExitCode: 1, + }, { + desc: "fails to parse command", + cmd: "discover", + errMsg: "", + gitlabKeyId: "root", + expectedExitCode: 0, + }, + } + + url := testserver.StartHttpServer(t, requests) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := &bytes.Buffer{} + s := &session{ + gitlabKeyId: tc.gitlabKeyId, + execCmd: tc.cmd, + channel: &fakeChannel{stdErr: out}, + cfg: &config.Config{GitlabUrl: url}, + } + r := &ssh.Request{} + + require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r)) + require.Equal(t, tc.errMsg, out.String()) + }) + } +} diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index de5fbd4..19fa661 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -2,13 +2,9 @@ package sshd import ( "context" - "encoding/base64" - "errors" "fmt" "net" "net/http" - "os" - "strconv" "sync" "time" @@ -16,7 +12,6 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" @@ -35,44 +30,24 @@ const ( type Server struct { Config *config.Config - status status - statusMu sync.Mutex - wg sync.WaitGroup - listener net.Listener - hostKeys []ssh.Signer - authorizedKeysClient *authorizedkeys.Client + status status + statusMu sync.Mutex + wg sync.WaitGroup + listener net.Listener + serverConfig *serverConfig } func NewServer(cfg *config.Config) (*Server, error) { - authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + serverConfig, err := newServerConfig(cfg) if err != nil { - return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + return nil, err } - var hostKeys []ssh.Signer - for _, filename := range cfg.Server.HostKeyFiles { - keyRaw, err := os.ReadFile(filename) - if err != nil { - log.WithError(err).Warnf("Failed to read host key %v", filename) - continue - } - key, err := ssh.ParsePrivateKey(keyRaw) - if err != nil { - log.WithError(err).Warnf("Failed to parse host key %v", filename) - continue - } - - hostKeys = append(hostKeys, key) - } - if len(hostKeys) == 0 { - return nil, fmt.Errorf("No host keys could be loaded, aborting") - } - - return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil + return &Server{Config: cfg, serverConfig: serverConfig}, nil } func (s *Server) ListenAndServe(ctx context.Context) error { - if err := s.listen(); err != nil { + if err := s.listen(ctx); err != nil { return err } defer s.listener.Close() @@ -110,7 +85,7 @@ func (s *Server) MonitoringServeMux() *http.ServeMux { return mux } -func (s *Server) listen() error { +func (s *Server) listen(ctx context.Context) error { sshListener, err := net.Listen("tcp", s.Config.Server.Listen) if err != nil { return fmt.Errorf("failed to listen for connection: %w", err) @@ -122,10 +97,10 @@ func (s *Server) listen() error { ReadHeaderTimeout: ProxyHeaderTimeout, } - log.Info("Proxy protocol is enabled") + log.ContextLogger(ctx).Info("Proxy protocol is enabled") } - log.WithFields(log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections") + log.WithContextFields(ctx, log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections") s.listener = sshListener @@ -142,7 +117,7 @@ func (s *Server) serve(ctx context.Context) { break } - log.WithError(err).Warn("Failed to accept connection") + log.ContextLogger(ctx).WithError(err).Warn("Failed to accept connection") continue } @@ -168,57 +143,29 @@ func (s *Server) getStatus() status { return s.status } -func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig { - sshCfg := &ssh.ServerConfig{ - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if conn.User() != s.Config.User { - return nil, errors.New("unknown user") - } - if key.Type() == ssh.KeyAlgoDSA { - return nil, errors.New("DSA is prohibited") - } - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) - if err != nil { - return nil, err - } - - return &ssh.Permissions{ - // Record the public key used for authentication. - Extensions: map[string]string{ - "key-id": strconv.FormatInt(res.Id, 10), - }, - }, nil - }, - } - - for _, key := range s.hostKeys { - sshCfg.AddHostKey(key) - } - - return sshCfg -} - func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() defer s.wg.Done() defer nconn.Close() + ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) + defer cancel() + + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr}) + // Prevent a panic in a single connection from taking out the whole server defer func() { if err := recover(); err != nil { - log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", remoteAddr) + ctxlog.Warn("panic handling session") } }() - ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) - defer cancel() + ctxlog.Info("server: handleConn: start") - sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx)) + sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx)) if err != nil { - log.WithError(err).Info("Failed to initialize SSH connection") + ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection") return } @@ -235,4 +182,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { session.handle(ctx, requests) }) + + ctxlog.Info("server: handleConn: done") } diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index 32946af..71f7733 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -104,6 +104,24 @@ func TestLivenessProbe(t *testing.T) { require.Equal(t, 200, r.Result().StatusCode) } +func TestInvalidClientConfig(t *testing.T) { + setupServer(t) + + cfg := clientConfig(t) + cfg.User = "unknown" + _, err := ssh.Dial("tcp", serverUrl, cfg) + require.Error(t, err) +} + +func TestInvalidServerConfig(t *testing.T) { + s := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}} + err := s.ListenAndServe(context.Background()) + + require.Error(t, err) + require.Equal(t, "failed to listen for connection: listen tcp: address invalid: missing port in address", err.Error()) + require.Nil(t, s.Shutdown()) +} + func setupServer(t *testing.T) *Server { t.Helper() |