diff options
41 files changed, 1373 insertions, 621 deletions
diff --git a/.ruby-version b/.ruby-version index 37c2961..a4dd9db 100644 --- a/.ruby-version +++ b/.ruby-version @@ -1 +1 @@ -2.7.2 +2.7.4 diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000..7320730 --- /dev/null +++ b/.tool-versions @@ -0,0 +1,2 @@ +ruby 2.7.4 +go 1.16.8 @@ -1,13 +1,15 @@ .PHONY: validate verify verify_ruby verify_golang test test_ruby test_golang coverage coverage_golang setup _script_install build compile check clean install GO_SOURCES := $(shell find . -name '*.go') -VERSION_STRING := $(shell git describe --match v* 2>/dev/null || awk '$0="v"$0' VERSION 2>/dev/null || echo unknown) +VERSION_STRING := $(shell git describe --match v* 2>/dev/null || awk '$$0="v"$$0' VERSION 2>/dev/null || echo unknown) BUILD_TIME := $(shell date -u +%Y%m%d.%H%M%S) BUILD_TAGS := tracer_static tracer_static_jaeger continuous_profiler_stackdriver GOBUILD_FLAGS := -ldflags "-X main.Version=$(VERSION_STRING) -X main.BuildTime=$(BUILD_TIME)" -tags "$(BUILD_TAGS)" PREFIX ?= /usr/local +build: bin/gitlab-shell + validate: verify test verify: verify_golang @@ -40,7 +42,6 @@ setup: _script_install bin/gitlab-shell _script_install: bin/install -build: bin/gitlab-shell compile: bin/gitlab-shell bin/gitlab-shell: $(GO_SOURCES) GOBIN="$(CURDIR)/bin" go install $(GOBUILD_FLAGS) ./cmd/... @@ -98,6 +98,21 @@ environment. Starting with GitLab 8.12, GitLab supports Git LFS authentication through SSH. +## Logging Guidelines + +In general, it should be possible to determine the structure, but not content, +of a gitlab-shell or gitlab-sshd session just from inspecting the logs. Some +guidelines: + +- We use [`gitlab.com/gitlab-org/labkit/log`](https://pkg.go.dev/gitlab.com/gitlab-org/labkit/log) + for logging functionality +- **Always** include a correlation ID +- Log messages should be invariant and unique. Include accessory information in + fields, using `log.WithField`, `log.WithFields`, or `log.WithError`. +- Log success cases as well as error cases +- Logging too much is better than not logging enough. If a message seems too + verbose, consider reducing the log level before removing the message. + ## Releasing See [PROCESS.md](./PROCESS.md) diff --git a/client/httpclient.go b/client/httpclient.go index 9238824..7b8a35c 100644 --- a/client/httpclient.go +++ b/client/httpclient.go @@ -54,6 +54,22 @@ func WithClientCert(certPath, keyPath string) HTTPClientOpt { } } +func validateCaFile(filename string) error { + if filename == "" { + return nil + } + + if _, err := os.Stat(filename); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("cannot find cafile '%s': %w", filename, ErrCafileNotFound) + } + + return err + } + + return nil +} + // Deprecated: use NewHTTPClientWithOpts - https://gitlab.com/gitlab-org/gitlab-shell/-/issues/484 func NewHTTPClient(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64) *HttpClient { c, err := NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath, selfSignedCert, readTimeoutSeconds, nil) @@ -73,10 +89,8 @@ func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath stri } else if strings.HasPrefix(gitlabURL, httpProtocol) { transport, host = buildHttpTransport(gitlabURL) } else if strings.HasPrefix(gitlabURL, httpsProtocol) { - if _, err := os.Stat(caFile); err != nil { - if os.IsNotExist(err) { - return nil, fmt.Errorf("cannot find cafile '%s': %w", caFile, ErrCafileNotFound) - } + err = validateCaFile(caFile) + if err != nil { return nil, err } diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go index d2c2293..debe1bd 100644 --- a/client/httpsclient_test.go +++ b/client/httpsclient_test.go @@ -66,10 +66,11 @@ func TestSuccessfulRequests(t *testing.T) { func TestFailedRequests(t *testing.T) { testCases := []struct { - desc string - caFile string - caPath string - expectedError string + desc string + caFile string + caPath string + expectedCaFileNotFound bool + expectedError string }{ { desc: "Invalid CaFile", @@ -77,18 +78,25 @@ func TestFailedRequests(t *testing.T) { expectedError: "Internal API unreachable", }, { - desc: "Invalid CaPath", - caPath: path.Join(testhelper.TestRoot, "certs/invalid"), + desc: "Missing CaFile", + caFile: path.Join(testhelper.TestRoot, "certs/invalid/missing.crt"), + expectedCaFileNotFound: true, }, { - desc: "Empty config", + desc: "Invalid CaPath", + caPath: path.Join(testhelper.TestRoot, "certs/invalid"), + expectedError: "Internal API unreachable", + }, + { + desc: "Empty config", + expectedError: "Internal API unreachable", }, } for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { client, err := setupWithRequests(t, tc.caFile, tc.caPath, "", "", "", false) - if tc.caFile == "" { + if tc.expectedCaFileNotFound { require.Error(t, err) require.ErrorIs(t, err, ErrCafileNotFound) } else { diff --git a/cmd/check/command/command.go b/cmd/check/command/command.go new file mode 100644 index 0000000..f260681 --- /dev/null +++ b/cmd/check/command/command.go @@ -0,0 +1,21 @@ +package command + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +func New(config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) { + if cmd := build(config, readWriter); cmd != nil { + return cmd, nil + } + + return nil, disallowedcommand.Error +} + +func build(config *config.Config, readWriter *readwriter.ReadWriter) command.Command { + return &healthcheck.Command{Config: config, ReadWriter: readWriter} +} diff --git a/cmd/check/command/command_test.go b/cmd/check/command/command_test.go new file mode 100644 index 0000000..cd06456 --- /dev/null +++ b/cmd/check/command/command_test.go @@ -0,0 +1,42 @@ +package command_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/cmd/check/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" +) + +var ( + basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"} +) + +func TestNew(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + config *config.Config + expectedType interface{} + }{ + { + desc: "it returns a Healthcheck command", + config: basicConfig, + expectedType: &healthcheck.Command{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + command, err := command.New(tc.config, nil) + + require.NoError(t, err) + require.IsType(t, tc.expectedType, command) + }) + } +} diff --git a/cmd/check/main.go b/cmd/check/main.go index 44d8175..e4bcdf2 100644 --- a/cmd/check/main.go +++ b/cmd/check/main.go @@ -4,12 +4,12 @@ import ( "fmt" "os" + checkCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/check/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/logger" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" ) func main() { @@ -19,7 +19,7 @@ func main() { ErrOut: os.Stderr, } - executable, err := executable.New(executable.Healthcheck, false) + executable, err := executable.New(executable.Healthcheck) if err != nil { fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting") os.Exit(1) @@ -34,7 +34,7 @@ func main() { logCloser := logger.Configure(config) defer logCloser.Close() - cmd, err := command.New(executable, os.Args[1:], sshenv.Env{}, config, readWriter) + cmd, err := checkCmd.New(config, readWriter) if err != nil { fmt.Fprintf(readWriter.ErrOut, "%v\n", err) os.Exit(1) diff --git a/cmd/gitlab-shell-authorized-keys-check/command/command.go b/cmd/gitlab-shell-authorized-keys-check/command/command.go new file mode 100644 index 0000000..8cf309b --- /dev/null +++ b/cmd/gitlab-shell-authorized-keys-check/command/command.go @@ -0,0 +1,37 @@ +package command + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +func New(arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) { + args, err := Parse(arguments) + if err != nil { + return nil, err + } + + if cmd := build(args, config, readWriter); cmd != nil { + return cmd, nil + } + + return nil, disallowedcommand.Error +} + +func Parse(arguments []string) (*commandargs.AuthorizedKeys, error) { + args := &commandargs.AuthorizedKeys{Arguments: arguments} + + if err := args.Parse(); err != nil { + return nil, err + } + + return args, nil +} + +func build(args *commandargs.AuthorizedKeys, config *config.Config, readWriter *readwriter.ReadWriter) command.Command { + return &authorizedkeys.Command{Config: config, Args: args, ReadWriter: readWriter} +} diff --git a/cmd/gitlab-shell-authorized-keys-check/command/command_test.go b/cmd/gitlab-shell-authorized-keys-check/command/command_test.go new file mode 100644 index 0000000..3343e1c --- /dev/null +++ b/cmd/gitlab-shell-authorized-keys-check/command/command_test.go @@ -0,0 +1,120 @@ +package command_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-keys-check/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" +) + +var ( + authorizedKeysExec = &executable.Executable{Name: executable.AuthorizedKeysCheck} + basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"} +) + +func TestNew(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + config *config.Config + expectedType interface{} + }{ + { + desc: "it returns a AuthorizedKeys command", + executable: authorizedKeysExec, + arguments: []string{"git", "git", "key"}, + config: basicConfig, + expectedType: &authorizedkeys.Command{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + command, err := command.New(tc.arguments, tc.config, nil) + + require.NoError(t, err) + require.IsType(t, tc.expectedType, command) + }) + } +} + +func TestParseSuccess(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + expectedArgs commandargs.CommandArgs + expectError bool + }{ + { + desc: "It parses authorized-keys command", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"git", "git", "key"}, + expectedArgs: &commandargs.AuthorizedKeys{Arguments: []string{"git", "git", "key"}, ExpectedUser: "git", ActualUser: "git", Key: "key"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := command.Parse(tc.arguments) + + if !tc.expectError { + require.NoError(t, err) + require.Equal(t, tc.expectedArgs, result) + } else { + require.Error(t, err) + } + }) + } +} + +func TestParseFailure(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + expectedError string + }{ + { + desc: "With not enough arguments for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user"}, + expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", + }, + { + desc: "With too many arguments for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user", "user", "key", "something-else"}, + expectedError: "# Insufficient arguments. 4. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", + }, + { + desc: "With missing username for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user", "", "key"}, + expectedError: "# No username provided", + }, + { + desc: "With missing key for the AuthorizedKeysCheck", + executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, + arguments: []string{"user", "user", ""}, + expectedError: "# No key provided", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + _, err := command.Parse(tc.arguments) + + require.EqualError(t, err, tc.expectedError) + }) + } +} diff --git a/cmd/gitlab-shell-authorized-keys-check/main.go b/cmd/gitlab-shell-authorized-keys-check/main.go index cda3e0b..ebe6da9 100644 --- a/cmd/gitlab-shell-authorized-keys-check/main.go +++ b/cmd/gitlab-shell-authorized-keys-check/main.go @@ -4,13 +4,13 @@ import ( "fmt" "os" + cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-keys-check/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/console" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/logger" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" ) func main() { @@ -20,7 +20,7 @@ func main() { ErrOut: os.Stderr, } - executable, err := executable.New(executable.AuthorizedKeysCheck, true) + executable, err := executable.New(executable.AuthorizedKeysCheck) if err != nil { fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting") os.Exit(1) @@ -35,7 +35,7 @@ func main() { logCloser := logger.Configure(config) defer logCloser.Close() - cmd, err := command.New(executable, os.Args[1:], sshenv.Env{}, config, readWriter) + cmd, err := cmd.New(os.Args[1:], config, readWriter) if err != nil { // For now this could happen if `SSH_CONNECTION` is not set on // the environment diff --git a/cmd/gitlab-shell-authorized-principals-check/command/command.go b/cmd/gitlab-shell-authorized-principals-check/command/command.go new file mode 100644 index 0000000..9418dad --- /dev/null +++ b/cmd/gitlab-shell-authorized-principals-check/command/command.go @@ -0,0 +1,37 @@ +package command + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +func New(arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) { + args, err := Parse(arguments) + if err != nil { + return nil, err + } + + if cmd := build(args, config, readWriter); cmd != nil { + return cmd, nil + } + + return nil, disallowedcommand.Error +} + +func Parse(arguments []string) (*commandargs.AuthorizedPrincipals, error) { + args := &commandargs.AuthorizedPrincipals{Arguments: arguments} + + if err := args.Parse(); err != nil { + return nil, err + } + + return args, nil +} + +func build(args *commandargs.AuthorizedPrincipals, config *config.Config, readWriter *readwriter.ReadWriter) command.Command { + return &authorizedprincipals.Command{Config: config, Args: args, ReadWriter: readWriter} +} diff --git a/cmd/gitlab-shell-authorized-principals-check/command/command_test.go b/cmd/gitlab-shell-authorized-principals-check/command/command_test.go new file mode 100644 index 0000000..2ca2125 --- /dev/null +++ b/cmd/gitlab-shell-authorized-principals-check/command/command_test.go @@ -0,0 +1,114 @@ +package command_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-principals-check/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" +) + +var ( + authorizedPrincipalsExec = &executable.Executable{Name: executable.AuthorizedPrincipalsCheck} + basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"} +) + +func TestNew(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + config *config.Config + expectedType interface{} + }{ + { + desc: "it returns a AuthorizedPrincipals command", + executable: authorizedPrincipalsExec, + arguments: []string{"key", "principal"}, + config: basicConfig, + expectedType: &authorizedprincipals.Command{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + command, err := cmd.New(tc.arguments, tc.config, nil) + + require.NoError(t, err) + require.IsType(t, tc.expectedType, command) + }) + } +} + +func TestParseSuccess(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + expectedArgs commandargs.CommandArgs + expectError bool + }{ + { + desc: "It parses authorized-principals command", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"key", "principal-1", "principal-2"}, + expectedArgs: &commandargs.AuthorizedPrincipals{Arguments: []string{"key", "principal-1", "principal-2"}, KeyId: "key", Principals: []string{"principal-1", "principal-2"}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := cmd.Parse(tc.arguments) + + if !tc.expectError { + require.NoError(t, err) + require.Equal(t, tc.expectedArgs, result) + } else { + require.Error(t, err) + } + }) + } +} + +func TestParseFailure(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + expectedError string + }{ + { + desc: "With not enough arguments for the AuthorizedPrincipalsCheck", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"key"}, + expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]", + }, + { + desc: "With missing key_id for the AuthorizedPrincipalsCheck", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"", "principal"}, + expectedError: "# No key_id provided", + }, + { + desc: "With blank principal for the AuthorizedPrincipalsCheck", + executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, + arguments: []string{"key", "principal", ""}, + expectedError: "# An invalid principal was provided", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + _, err := cmd.Parse(tc.arguments) + + require.EqualError(t, err, tc.expectedError) + }) + } +} diff --git a/cmd/gitlab-shell-authorized-principals-check/main.go b/cmd/gitlab-shell-authorized-principals-check/main.go index 87f7fa3..3e18b9d 100644 --- a/cmd/gitlab-shell-authorized-principals-check/main.go +++ b/cmd/gitlab-shell-authorized-principals-check/main.go @@ -4,13 +4,13 @@ import ( "fmt" "os" + cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-principals-check/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/console" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/logger" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" ) func main() { @@ -20,7 +20,7 @@ func main() { ErrOut: os.Stderr, } - executable, err := executable.New(executable.AuthorizedPrincipalsCheck, true) + executable, err := executable.New(executable.AuthorizedPrincipalsCheck) if err != nil { fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting") os.Exit(1) @@ -35,7 +35,7 @@ func main() { logCloser := logger.Configure(config) defer logCloser.Close() - cmd, err := command.New(executable, os.Args[1:], sshenv.Env{}, config, readWriter) + cmd, err := cmd.New(os.Args[1:], config, readWriter) if err != nil { // For now this could happen if `SSH_CONNECTION` is not set on // the environment diff --git a/cmd/gitlab-shell/command/command.go b/cmd/gitlab-shell/command/command.go new file mode 100644 index 0000000..5f828cd --- /dev/null +++ b/cmd/gitlab-shell/command/command.go @@ -0,0 +1,78 @@ +package command + +import ( + "gitlab.com/gitlab-org/gitlab-shell/internal/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" +) + +func New(arguments []string, env sshenv.Env, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) { + args, err := Parse(arguments, env) + if err != nil { + return nil, err + } + + if cmd := Build(args, config, readWriter); cmd != nil { + return cmd, nil + } + + return nil, disallowedcommand.Error +} + +func NewWithKey(gitlabKeyId string, env sshenv.Env, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) { + args, err := Parse(nil, env) + if err != nil { + return nil, err + } + + args.GitlabKeyId = gitlabKeyId + if cmd := Build(args, config, readWriter); cmd != nil { + return cmd, nil + } + + return nil, disallowedcommand.Error +} + +func Parse(arguments []string, env sshenv.Env) (*commandargs.Shell, error) { + args := &commandargs.Shell{Arguments: arguments, Env: env} + + if err := args.Parse(); err != nil { + return nil, err + } + + return args, nil +} + +func Build(args *commandargs.Shell, config *config.Config, readWriter *readwriter.ReadWriter) command.Command { + switch args.CommandType { + case commandargs.Discover: + return &discover.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.TwoFactorRecover: + return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.TwoFactorVerify: + return &twofactorverify.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.LfsAuthenticate: + return &lfsauthenticate.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.ReceivePack: + return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.UploadPack: + return &uploadpack.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.UploadArchive: + return &uploadarchive.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.PersonalAccessToken: + return &personalaccesstoken.Command{Config: config, Args: args, ReadWriter: readWriter} + } + + return nil +} diff --git a/cmd/gitlab-shell/command/command_test.go b/cmd/gitlab-shell/command/command_test.go new file mode 100644 index 0000000..2aeee59 --- /dev/null +++ b/cmd/gitlab-shell/command/command_test.go @@ -0,0 +1,281 @@ +package command_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" +) + +var ( + gitlabShellExec = &executable.Executable{Name: executable.GitlabShell} + basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"} +) + +func TestNew(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + config *config.Config + expectedType interface{} + }{ + { + desc: "it returns a Discover command", + executable: gitlabShellExec, + env: buildEnv(""), + config: basicConfig, + expectedType: &discover.Command{}, + }, + { + desc: "it returns a TwoFactorRecover command", + executable: gitlabShellExec, + env: buildEnv("2fa_recovery_codes"), + config: basicConfig, + expectedType: &twofactorrecover.Command{}, + }, + { + desc: "it returns a TwoFactorVerify command", + executable: gitlabShellExec, + env: buildEnv("2fa_verify"), + config: basicConfig, + expectedType: &twofactorverify.Command{}, + }, + { + desc: "it returns an LfsAuthenticate command", + executable: gitlabShellExec, + env: buildEnv("git-lfs-authenticate"), + config: basicConfig, + expectedType: &lfsauthenticate.Command{}, + }, + { + desc: "it returns a ReceivePack command", + executable: gitlabShellExec, + env: buildEnv("git-receive-pack"), + config: basicConfig, + expectedType: &receivepack.Command{}, + }, + { + desc: "it returns an UploadPack command", + executable: gitlabShellExec, + env: buildEnv("git-upload-pack"), + config: basicConfig, + expectedType: &uploadpack.Command{}, + }, + { + desc: "it returns an UploadArchive command", + executable: gitlabShellExec, + env: buildEnv("git-upload-archive"), + config: basicConfig, + expectedType: &uploadarchive.Command{}, + }, + { + desc: "it returns a PersonalAccessToken command", + executable: gitlabShellExec, + env: buildEnv("personal_access_token"), + config: basicConfig, + expectedType: &personalaccesstoken.Command{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + command, err := cmd.New(tc.arguments, tc.env, tc.config, nil) + + require.NoError(t, err) + require.IsType(t, tc.expectedType, command) + }) + } +} + +func TestFailingNew(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + expectedError error + }{ + { + desc: "Parsing environment failed", + executable: gitlabShellExec, + expectedError: errors.New("Only SSH allowed"), + }, + { + desc: "Unknown command given", + executable: gitlabShellExec, + env: buildEnv("unknown"), + expectedError: disallowedcommand.Error, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + command, err := cmd.New([]string{}, tc.env, basicConfig, nil) + require.Nil(t, command) + require.Equal(t, tc.expectedError, err) + }) + } +} + +func buildEnv(command string) sshenv.Env { + return sshenv.Env{ + IsSSHConnection: true, + OriginalCommand: command, + } +} + +func TestParseSuccess(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + expectedArgs commandargs.CommandArgs + expectError bool + }{ + { + desc: "It sets discover as the command when the command string was empty", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{}, CommandType: commandargs.Discover, Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, + }, + { + desc: "It finds the key id in any passed arguments", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, + arguments: []string{"hello", "key-123"}, + expectedArgs: &commandargs.Shell{Arguments: []string{"hello", "key-123"}, SshArgs: []string{}, CommandType: commandargs.Discover, GitlabKeyId: "123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, + }, + { + desc: "It finds the key id only if the argument is of <key-id> format", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, + arguments: []string{"hello", "username-key-123"}, + expectedArgs: &commandargs.Shell{Arguments: []string{"hello", "username-key-123"}, SshArgs: []string{}, CommandType: commandargs.Discover, GitlabUsername: "key-123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, + }, + { + desc: "It finds the username in any passed arguments", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, + arguments: []string{"hello", "username-jane-doe"}, + expectedArgs: &commandargs.Shell{Arguments: []string{"hello", "username-jane-doe"}, SshArgs: []string{}, CommandType: commandargs.Discover, GitlabUsername: "jane-doe", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, + }, + { + desc: "It parses 2fa_recovery_codes command", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"2fa_recovery_codes"}, CommandType: commandargs.TwoFactorRecover, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"}}, + }, + { + desc: "It parses git-receive-pack command", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"}}, + }, + { + desc: "It parses git-receive-pack command and a project with single quotes", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"}}, + }, + { + desc: `It parses "git receive-pack" command`, + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`}}, + }, + { + desc: `It parses a command followed by control characters`, + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`}}, + }, + { + desc: "It parses git-upload-pack command", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-upload-pack", "group/repo"}, CommandType: commandargs.UploadPack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`}}, + }, + { + desc: "It parses git-upload-archive command", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-upload-archive", "group/repo"}, CommandType: commandargs.UploadArchive, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"}}, + }, + { + desc: "It parses git-lfs-authenticate command", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"}, + arguments: []string{}, + expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}, CommandType: commandargs.LfsAuthenticate, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := cmd.Parse(tc.arguments, tc.env) + + if !tc.expectError { + require.NoError(t, err) + require.Equal(t, tc.expectedArgs, result) + } else { + require.Error(t, err) + } + }) + } +} + +func TestParseFailure(t *testing.T) { + testCases := []struct { + desc string + executable *executable.Executable + env sshenv.Env + arguments []string + expectedError string + }{ + { + desc: "It fails if SSH connection is not set", + executable: &executable.Executable{Name: executable.GitlabShell}, + arguments: []string{}, + expectedError: "Only SSH allowed", + }, + { + desc: "It fails if SSH command is invalid", + executable: &executable.Executable{Name: executable.GitlabShell}, + env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git receive-pack "`}, + arguments: []string{}, + expectedError: "Invalid SSH command: invalid command line string", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + _, err := cmd.Parse(tc.arguments, tc.env) + + require.EqualError(t, err, tc.expectedError) + }) + } +} diff --git a/cmd/gitlab-shell/main.go b/cmd/gitlab-shell/main.go index fe52bfc..a945d0c 100644 --- a/cmd/gitlab-shell/main.go +++ b/cmd/gitlab-shell/main.go @@ -3,7 +3,11 @@ package main import ( "fmt" "os" + "reflect" + "gitlab.com/gitlab-org/labkit/log" + + shellCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/config" @@ -34,7 +38,7 @@ func main() { ErrOut: os.Stderr, } - executable, err := executable.New(executable.GitlabShell, true) + executable, err := executable.New(executable.GitlabShell) if err != nil { fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting") os.Exit(1) @@ -50,7 +54,7 @@ func main() { defer logCloser.Close() env := sshenv.NewFromEnv() - cmd, err := command.New(executable, os.Args[1:], env, config, readWriter) + cmd, err := shellCmd.New(os.Args[1:], env, config, readWriter) if err != nil { // For now this could happen if `SSH_CONNECTION` is not set on // the environment @@ -61,8 +65,15 @@ func main() { ctx, finished := command.Setup(executable.Name, config) defer finished() - if err = cmd.Execute(ctx); err != nil { + cmdName := reflect.TypeOf(cmd).String() + ctxlog := log.ContextLogger(ctx) + ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("gitlab-shell: main: executing command") + + if err := cmd.Execute(ctx); err != nil { + ctxlog.WithError(err).Warn("gitlab-shell: main: command execution failed") console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } + + ctxlog.Info("gitlab-shell: main: command executed successfully") } diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go index 5bbf221..165c7a5 100644 --- a/cmd/gitlab-sshd/main.go +++ b/cmd/gitlab-sshd/main.go @@ -97,7 +97,7 @@ func main() { sig := <-done signal.Reset(syscall.SIGINT, syscall.SIGTERM) - log.WithFields(log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Info("Shutdown initiated") + log.WithContextFields(ctx, log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Info("Shutdown initiated") server.Shutdown() diff --git a/internal/command/command.go b/internal/command/command.go index dadf41a..4ee568e 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -3,23 +3,7 @@ package command import ( "context" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/executable" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/tracing" ) @@ -28,23 +12,6 @@ type Command interface { Execute(ctx context.Context) error } -func New(e *executable.Executable, arguments []string, env sshenv.Env, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) { - var args commandargs.CommandArgs - if e.AcceptArgs { - var err error - args, err = commandargs.Parse(e, arguments, env) - if err != nil { - return nil, err - } - } - - if cmd := buildCommand(e, args, config, readWriter); cmd != nil { - return cmd, nil - } - - return nil, disallowedcommand.Error -} - // Setup() initializes tracing from the configuration file and generates a // background context from which all other contexts in the process should derive // from, as it has a service name and initial correlation ID set. @@ -80,53 +47,3 @@ func Setup(serviceName string, config *config.Config) (context.Context, func()) closer.Close() } } - -func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command { - switch e.Name { - case executable.GitlabShell: - return BuildShellCommand(args.(*commandargs.Shell), config, readWriter) - case executable.AuthorizedKeysCheck: - return buildAuthorizedKeysCommand(args.(*commandargs.AuthorizedKeys), config, readWriter) - case executable.AuthorizedPrincipalsCheck: - return buildAuthorizedPrincipalsCommand(args.(*commandargs.AuthorizedPrincipals), config, readWriter) - case executable.Healthcheck: - return buildHealthcheckCommand(config, readWriter) - } - - return nil -} - -func BuildShellCommand(args *commandargs.Shell, config *config.Config, readWriter *readwriter.ReadWriter) Command { - switch args.CommandType { - case commandargs.Discover: - return &discover.Command{Config: config, Args: args, ReadWriter: readWriter} - case commandargs.TwoFactorRecover: - return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter} - case commandargs.TwoFactorVerify: - return &twofactorverify.Command{Config: config, Args: args, ReadWriter: readWriter} - case commandargs.LfsAuthenticate: - return &lfsauthenticate.Command{Config: config, Args: args, ReadWriter: readWriter} - case commandargs.ReceivePack: - return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter} - case commandargs.UploadPack: - return &uploadpack.Command{Config: config, Args: args, ReadWriter: readWriter} - case commandargs.UploadArchive: - return &uploadarchive.Command{Config: config, Args: args, ReadWriter: readWriter} - case commandargs.PersonalAccessToken: - return &personalaccesstoken.Command{Config: config, Args: args, ReadWriter: readWriter} - } - - return nil -} - -func buildAuthorizedKeysCommand(args *commandargs.AuthorizedKeys, config *config.Config, readWriter *readwriter.ReadWriter) Command { - return &authorizedkeys.Command{Config: config, Args: args, ReadWriter: readWriter} -} - -func buildAuthorizedPrincipalsCommand(args *commandargs.AuthorizedPrincipals, config *config.Config, readWriter *readwriter.ReadWriter) Command { - return &authorizedprincipals.Command{Config: config, Args: args, ReadWriter: readWriter} -} - -func buildHealthcheckCommand(config *config.Config, readWriter *readwriter.ReadWriter) Command { - return &healthcheck.Command{Config: config, ReadWriter: readWriter} -} diff --git a/internal/command/command_test.go b/internal/command/command_test.go index a538745..2fc6655 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -1,173 +1,15 @@ package command import ( - "errors" "os" "testing" "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/executable" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" "gitlab.com/gitlab-org/labkit/correlation" ) -var ( - authorizedKeysExec = &executable.Executable{Name: executable.AuthorizedKeysCheck, AcceptArgs: true} - authorizedPrincipalsExec = &executable.Executable{Name: executable.AuthorizedPrincipalsCheck, AcceptArgs: true} - checkExec = &executable.Executable{Name: executable.Healthcheck, AcceptArgs: false} - gitlabShellExec = &executable.Executable{Name: executable.GitlabShell, AcceptArgs: true} - - basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"} - advancedConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket", SslCertDir: "/tmp/certs"} -) - -func buildEnv(command string) sshenv.Env { - return sshenv.Env{ - IsSSHConnection: true, - OriginalCommand: command, - } -} - -func TestNew(t *testing.T) { - testCases := []struct { - desc string - executable *executable.Executable - env sshenv.Env - arguments []string - config *config.Config - expectedType interface{} - }{ - { - desc: "it returns a Discover command", - executable: gitlabShellExec, - env: buildEnv(""), - config: basicConfig, - expectedType: &discover.Command{}, - }, - { - desc: "it returns a TwoFactorRecover command", - executable: gitlabShellExec, - env: buildEnv("2fa_recovery_codes"), - config: basicConfig, - expectedType: &twofactorrecover.Command{}, - }, - { - desc: "it returns a TwoFactorVerify command", - executable: gitlabShellExec, - env: buildEnv("2fa_verify"), - config: basicConfig, - expectedType: &twofactorverify.Command{}, - }, - { - desc: "it returns an LfsAuthenticate command", - executable: gitlabShellExec, - env: buildEnv("git-lfs-authenticate"), - config: basicConfig, - expectedType: &lfsauthenticate.Command{}, - }, - { - desc: "it returns a ReceivePack command", - executable: gitlabShellExec, - env: buildEnv("git-receive-pack"), - config: basicConfig, - expectedType: &receivepack.Command{}, - }, - { - desc: "it returns an UploadPack command", - executable: gitlabShellExec, - env: buildEnv("git-upload-pack"), - config: basicConfig, - expectedType: &uploadpack.Command{}, - }, - { - desc: "it returns an UploadArchive command", - executable: gitlabShellExec, - env: buildEnv("git-upload-archive"), - config: basicConfig, - expectedType: &uploadarchive.Command{}, - }, - { - desc: "it returns a Healthcheck command", - executable: checkExec, - config: basicConfig, - expectedType: &healthcheck.Command{}, - }, - { - desc: "it returns a AuthorizedKeys command", - executable: authorizedKeysExec, - arguments: []string{"git", "git", "key"}, - config: basicConfig, - expectedType: &authorizedkeys.Command{}, - }, - { - desc: "it returns a AuthorizedPrincipals command", - executable: authorizedPrincipalsExec, - arguments: []string{"key", "principal"}, - config: basicConfig, - expectedType: &authorizedprincipals.Command{}, - }, - { - desc: "it returns a PersonalAccessToken command", - executable: gitlabShellExec, - env: buildEnv("personal_access_token"), - config: basicConfig, - expectedType: &personalaccesstoken.Command{}, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - command, err := New(tc.executable, tc.arguments, tc.env, tc.config, nil) - - require.NoError(t, err) - require.IsType(t, tc.expectedType, command) - }) - } -} - -func TestFailingNew(t *testing.T) { - testCases := []struct { - desc string - executable *executable.Executable - env sshenv.Env - expectedError error - }{ - { - desc: "Parsing environment failed", - executable: gitlabShellExec, - expectedError: errors.New("Only SSH allowed"), - }, - { - desc: "Unknown command given", - executable: gitlabShellExec, - env: buildEnv("unknown"), - expectedError: disallowedcommand.Error, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - command, err := New(tc.executable, []string{}, tc.env, basicConfig, nil) - require.Nil(t, command) - require.Equal(t, tc.expectedError, err) - }) - } -} - func TestSetup(t *testing.T) { testCases := []struct { name string diff --git a/internal/command/commandargs/command_args.go b/internal/command/commandargs/command_args.go index a01b8b2..f23ba18 100644 --- a/internal/command/commandargs/command_args.go +++ b/internal/command/commandargs/command_args.go @@ -1,37 +1,8 @@ package commandargs -import ( - "errors" - "fmt" - - "gitlab.com/gitlab-org/gitlab-shell/internal/executable" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" -) - type CommandType string type CommandArgs interface { Parse() error GetArguments() []string } - -func Parse(e *executable.Executable, arguments []string, env sshenv.Env) (CommandArgs, error) { - var args CommandArgs - - switch e.Name { - case executable.GitlabShell: - args = &Shell{Arguments: arguments, Env: env} - case executable.AuthorizedKeysCheck: - args = &AuthorizedKeys{Arguments: arguments} - case executable.AuthorizedPrincipalsCheck: - args = &AuthorizedPrincipals{Arguments: arguments} - default: - return nil, errors.New(fmt.Sprintf("unknown executable: %s", e.Name)) - } - - if err := args.Parse(); err != nil { - return nil, err - } - - return args, nil -} diff --git a/internal/command/commandargs/command_args_test.go b/internal/command/commandargs/command_args_test.go deleted file mode 100644 index 119ecd4..0000000 --- a/internal/command/commandargs/command_args_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package commandargs - -import ( - "testing" - - "gitlab.com/gitlab-org/gitlab-shell/internal/executable" - "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" - - "github.com/stretchr/testify/require" -) - -func TestParseSuccess(t *testing.T) { - testCases := []struct { - desc string - executable *executable.Executable - env sshenv.Env - arguments []string - expectedArgs CommandArgs - expectError bool - }{ - { - desc: "It sets discover as the command when the command string was empty", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{}, CommandType: Discover, Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, - }, { - desc: "It finds the key id in any passed arguments", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, - arguments: []string{"hello", "key-123"}, - expectedArgs: &Shell{Arguments: []string{"hello", "key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, - }, { - desc: "It finds the key id only if the argument is of <key-id> format", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, - arguments: []string{"hello", "username-key-123"}, - expectedArgs: &Shell{Arguments: []string{"hello", "username-key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "key-123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, - }, { - desc: "It finds the username in any passed arguments", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}, - arguments: []string{"hello", "username-jane-doe"}, - expectedArgs: &Shell{Arguments: []string{"hello", "username-jane-doe"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}}, - }, { - desc: "It parses 2fa_recovery_codes command", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"}}, - }, { - desc: "It parses git-receive-pack command", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"}}, - }, { - desc: "It parses git-receive-pack command and a project with single quotes", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"}}, - }, { - desc: `It parses "git receive-pack" command`, - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`}}, - }, { - desc: `It parses a command followed by control characters`, - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`}}, - }, { - desc: "It parses git-upload-pack command", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-pack", "group/repo"}, CommandType: UploadPack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`}}, - }, { - desc: "It parses git-upload-archive command", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-archive", "group/repo"}, CommandType: UploadArchive, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"}}, - }, { - desc: "It parses git-lfs-authenticate command", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"}, - arguments: []string{}, - expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}, CommandType: LfsAuthenticate, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"}}, - }, { - desc: "It parses authorized-keys command", - executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, - arguments: []string{"git", "git", "key"}, - expectedArgs: &AuthorizedKeys{Arguments: []string{"git", "git", "key"}, ExpectedUser: "git", ActualUser: "git", Key: "key"}, - }, { - desc: "It parses authorized-principals command", - executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, - arguments: []string{"key", "principal-1", "principal-2"}, - expectedArgs: &AuthorizedPrincipals{Arguments: []string{"key", "principal-1", "principal-2"}, KeyId: "key", Principals: []string{"principal-1", "principal-2"}}, - }, { - desc: "Unknown executable", - executable: &executable.Executable{Name: "unknown"}, - arguments: []string{}, - expectError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - result, err := Parse(tc.executable, tc.arguments, tc.env) - - if !tc.expectError { - require.NoError(t, err) - require.Equal(t, tc.expectedArgs, result) - } else { - require.Error(t, err) - } - }) - } -} - -func TestParseFailure(t *testing.T) { - testCases := []struct { - desc string - executable *executable.Executable - env sshenv.Env - arguments []string - expectedError string - }{ - { - desc: "It fails if SSH connection is not set", - executable: &executable.Executable{Name: executable.GitlabShell}, - arguments: []string{}, - expectedError: "Only SSH allowed", - }, - { - desc: "It fails if SSH command is invalid", - executable: &executable.Executable{Name: executable.GitlabShell}, - env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git receive-pack "`}, - arguments: []string{}, - expectedError: "Invalid SSH command", - }, - { - desc: "With not enough arguments for the AuthorizedKeysCheck", - executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, - arguments: []string{"user"}, - expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", - }, - { - desc: "With too many arguments for the AuthorizedKeysCheck", - executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, - arguments: []string{"user", "user", "key", "something-else"}, - expectedError: "# Insufficient arguments. 4. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>", - }, - { - desc: "With missing username for the AuthorizedKeysCheck", - executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, - arguments: []string{"user", "", "key"}, - expectedError: "# No username provided", - }, - { - desc: "With missing key for the AuthorizedKeysCheck", - executable: &executable.Executable{Name: executable.AuthorizedKeysCheck}, - arguments: []string{"user", "user", ""}, - expectedError: "# No key provided", - }, - { - desc: "With not enough arguments for the AuthorizedPrincipalsCheck", - executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, - arguments: []string{"key"}, - expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]", - }, - { - desc: "With missing key_id for the AuthorizedPrincipalsCheck", - executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, - arguments: []string{"", "principal"}, - expectedError: "# No key_id provided", - }, - { - desc: "With blank principal for the AuthorizedPrincipalsCheck", - executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}, - arguments: []string{"key", "principal", ""}, - expectedError: "# An invalid principal was provided", - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - _, err := Parse(tc.executable, tc.arguments, tc.env) - - require.EqualError(t, err, tc.expectedError) - }) - } -} diff --git a/internal/command/commandargs/shell.go b/internal/command/commandargs/shell.go index 589f58d..7a76be5 100644 --- a/internal/command/commandargs/shell.go +++ b/internal/command/commandargs/shell.go @@ -1,7 +1,7 @@ package commandargs import ( - "errors" + "fmt" "regexp" "github.com/mattn/go-shellwords" @@ -49,21 +49,16 @@ func (s *Shell) GetArguments() []string { func (s *Shell) validate() error { if !s.Env.IsSSHConnection { - return errors.New("Only SSH allowed") + return fmt.Errorf("Only SSH allowed") } - if !s.isValidSSHCommand() { - return errors.New("Invalid SSH command") + if err := s.ParseCommand(s.Env.OriginalCommand); err != nil { + return fmt.Errorf("Invalid SSH command: %w", err) } return nil } -func (s *Shell) isValidSSHCommand() bool { - err := s.ParseCommand(s.Env.OriginalCommand) - return err == nil -} - func (s *Shell) parseWho() { for _, argument := range s.Arguments { if keyId := tryParseKeyId(argument); keyId != "" { diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index dab69ab..ac3aafc 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -6,6 +6,8 @@ import ( "encoding/json" "fmt" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -58,6 +60,11 @@ func (c *Command) Execute(ctx context.Context) error { payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) if err != nil { // return nothing just like Ruby's GitlabShell#lfs_authenticate does + log.WithContextFields( + ctx, + log.Fields{"operation": operation, "repo": repo, "user_id": accessResponse.UserId}, + ).WithError(err).Debug("lfsauthenticate: execute: LFS authentication failed") + return nil } diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go index 34086fb..73d2ce4 100644 --- a/internal/command/shared/customaction/customaction.go +++ b/internal/command/shared/customaction/customaction.go @@ -64,7 +64,7 @@ func (c *Command) processApiEndpoints(ctx context.Context, response *accessverif "endpoint": endpoint, } - log.WithFields(fields).Info("Performing custom action") + log.WithContextFields(ctx, fields).Info("Performing custom action") response, err := c.performRequest(ctx, client, endpoint, request) if err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index c67f60a..5185736 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -46,6 +46,7 @@ type Config struct { RootDir string LogFile string `yaml:"log_file,omitempty"` LogFormat string `yaml:"log_format,omitempty"` + LogLevel string `yaml:"log_level,omitempty"` GitlabUrl string `yaml:"gitlab_url"` GitlabRelativeURLRoot string `yaml:"gitlab_relative_url_root"` GitlabTracing string `yaml:"gitlab_tracing"` @@ -66,6 +67,7 @@ var ( DefaultConfig = Config{ LogFile: "gitlab-shell.log", LogFormat: "json", + LogLevel: "info", Server: DefaultServerConfig, User: "git", } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 699a261..78b2ed4 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -32,7 +32,7 @@ func TestHttpClient(t *testing.T) { client, err := config.HttpClient() require.NoError(t, err) - _, err = client.Get("http://host.com/path") + _, err = client.Get(url) require.NoError(t, err) ms, err := prometheus.DefaultGatherer.Gather() diff --git a/internal/executable/executable.go b/internal/executable/executable.go index 8b6b586..c6355b9 100644 --- a/internal/executable/executable.go +++ b/internal/executable/executable.go @@ -14,9 +14,8 @@ const ( ) type Executable struct { - Name string - RootDir string - AcceptArgs bool + Name string + RootDir string } var ( @@ -24,7 +23,7 @@ var ( osExecutable = os.Executable ) -func New(name string, acceptArgs bool) (*Executable, error) { +func New(name string) (*Executable, error) { path, err := osExecutable() if err != nil { return nil, err @@ -36,9 +35,8 @@ func New(name string, acceptArgs bool) (*Executable, error) { } executable := &Executable{ - Name: name, - RootDir: rootDir, - AcceptArgs: acceptArgs, + Name: name, + RootDir: rootDir, } return executable, nil diff --git a/internal/executable/executable_test.go b/internal/executable/executable_test.go index 71984c3..3915f1a 100644 --- a/internal/executable/executable_test.go +++ b/internal/executable/executable_test.go @@ -59,7 +59,7 @@ func TestNewSuccess(t *testing.T) { fake.Setup() defer fake.Cleanup() - result, err := New("gitlab-shell", true) + result, err := New("gitlab-shell") require.NoError(t, err) require.Equal(t, result.Name, "gitlab-shell") @@ -96,7 +96,7 @@ func TestNewFailure(t *testing.T) { fake.Setup() defer fake.Cleanup() - _, err := New("gitlab-shell", true) + _, err := New("gitlab-shell") require.Error(t, err) }) diff --git a/internal/handler/exec.go b/internal/handler/exec.go index 27031b1..172736d 100644 --- a/internal/handler/exec.go +++ b/internal/handler/exec.go @@ -9,7 +9,9 @@ import ( grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "google.golang.org/grpc" + grpccodes "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + grpcstatus "google.golang.org/grpc/status" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" @@ -50,6 +52,12 @@ func (gc *GitalyCommand) RunGitalyCommand(ctx context.Context, handler GitalyHan childCtx := withOutgoingMetadata(ctx, gc.Features) _, err = handler(childCtx, conn) + if err != nil && grpcstatus.Convert(err).Code() == grpccodes.Unavailable { + log.WithError(err).Error("Gitaly is unavailable") + + return fmt.Errorf("Git service is temporarily unavailable") + } + return err } @@ -110,7 +118,7 @@ func getConn(ctx context.Context, gc *GitalyCommand) (*grpc.ClientConn, error) { if serviceName == "" { serviceName = "gitlab-shell-unknown" - log.WithFields(log.Fields{"service_name": serviceName}).Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown") + log.WithContextFields(ctx, log.Fields{"service_name": serviceName}).Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown") } serviceName = fmt.Sprintf("%s-%s", serviceName, gc.ServiceName) diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go index 5ad0675..1d714ef 100644 --- a/internal/handler/exec_test.go +++ b/internal/handler/exec_test.go @@ -7,7 +7,9 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" + grpccodes "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + grpcstatus "google.golang.org/grpc/status" pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" "gitlab.com/gitlab-org/gitlab-shell/internal/config" @@ -45,6 +47,18 @@ func TestMissingGitalyAddress(t *testing.T) { require.EqualError(t, err, "no gitaly_address given") } +func TestUnavailableGitalyErr(t *testing.T) { + cmd := GitalyCommand{ + Config: &config.Config{}, + Address: "tcp://localhost:9999", + } + + expectedErr := grpcstatus.Error(grpccodes.Unavailable, "error") + err := cmd.RunGitalyCommand(context.Background(), makeHandler(t, expectedErr)) + + require.EqualError(t, err, "Git service is temporarily unavailable") +} + func TestRunGitalyCommandMetadata(t *testing.T) { tests := []struct { name string diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 1165680..748fce0 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -35,6 +35,7 @@ func buildOpts(cfg *config.Config) []log.LoggerOption { log.WithFormatter(logFmt(cfg.LogFormat)), log.WithOutputName(logFile(cfg.LogFile)), log.WithTimezone(time.UTC), + log.WithLogLevel(cfg.LogLevel), } } diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index bda36d9..4ea8c1f 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -3,7 +3,6 @@ package logger import ( "os" "regexp" - "strings" "testing" "github.com/stretchr/testify/require" @@ -26,12 +25,37 @@ func TestConfigure(t *testing.T) { defer closer.Close() log.Info("this is a test") + log.WithFields(log.Fields{}).Debug("debug log message") tmpFile.Close() data, err := os.ReadFile(tmpFile.Name()) require.NoError(t, err) - require.True(t, strings.Contains(string(data), `msg":"this is a test"`)) + require.Contains(t, string(data), `msg":"this is a test"`) + require.NotContains(t, string(data), `msg:":"debug log message"`) +} + +func TestConfigureWithDebugLogLevel(t *testing.T) { + tmpFile, err := os.CreateTemp(os.TempDir(), "logtest-") + require.NoError(t, err) + defer tmpFile.Close() + + config := config.Config{ + LogFile: tmpFile.Name(), + LogFormat: "json", + LogLevel: "debug", + } + + closer := Configure(&config) + defer closer.Close() + + log.WithFields(log.Fields{}).Debug("debug log message") + + tmpFile.Close() + + data, err := os.ReadFile(tmpFile.Name()) + require.NoError(t, err) + require.Contains(t, string(data), `msg":"debug log message"`) } func TestConfigureWithPermissionError(t *testing.T) { diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 0e0da93..f6d8fb5 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -29,21 +29,26 @@ func newConnection(maxSessions int64, remoteAddr string) *connection { } func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) { + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) + defer metrics.SshdConnectionDuration.Observe(time.Since(c.begin).Seconds()) for newChannel := range chans { + ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested") if newChannel.ChannelType() != "session" { + ctxlog.Info("connection: handle: unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } if !c.concurrentSessions.TryAcquire(1) { + ctxlog.Info("connection: handle: too many concurrent sessions") newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") metrics.SshdHitMaxSessions.Inc() continue } channel, requests, err := newChannel.Accept() if err != nil { - log.WithError(err).Info("could not accept channel") + ctxlog.WithError(err).Error("connection: handle: accepting channel failed") c.concurrentSessions.Release(1) continue } @@ -54,11 +59,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha // Prevent a panic in a single session from taking out the whole server defer func() { if err := recover(); err != nil { - log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", c.remoteAddr) + ctxlog.WithField("recovered_error", err).Warn("panic handling session") } }() handler(ctx, channel, requests) + ctxlog.Info("connection: handle: done") }() } } diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go new file mode 100644 index 0000000..68210f8 --- /dev/null +++ b/internal/sshd/server_config.go @@ -0,0 +1,94 @@ +package sshd + +import ( + "context" + "encoding/base64" + "fmt" + "os" + "strconv" + "time" + + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" + + "gitlab.com/gitlab-org/labkit/log" +) + +type serverConfig struct { + cfg *config.Config + hostKeys []ssh.Signer + authorizedKeysClient *authorizedkeys.Client +} + +func newServerConfig(cfg *config.Config) (*serverConfig, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + + var hostKeys []ssh.Signer + for _, filename := range cfg.Server.HostKeyFiles { + keyRaw, err := os.ReadFile(filename) + if err != nil { + log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to read host key") + continue + } + key, err := ssh.ParsePrivateKey(keyRaw) + if err != nil { + log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to parse host key") + continue + } + + hostKeys = append(hostKeys, key) + } + if len(hostKeys) == 0 { + return nil, fmt.Errorf("No host keys could be loaded, aborting") + } + + return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil +} + +func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) { + if user != s.cfg.User { + return nil, fmt.Errorf("unknown user") + } + if key.Type() == ssh.KeyAlgoDSA { + return nil, fmt.Errorf("DSA is prohibited") + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) + if err != nil { + return nil, err + } + + return res, nil +} + +func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig { + sshCfg := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + res, err := s.getAuthKey(ctx, conn.User(), key) + if err != nil { + return nil, err + } + + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "key-id": strconv.FormatInt(res.Id, 10), + }, + }, nil + }, + } + + for _, key := range s.hostKeys { + sshCfg.AddHostKey(key) + } + + return sshCfg +} diff --git a/internal/sshd/server_config_test.go b/internal/sshd/server_config_test.go new file mode 100644 index 0000000..58bd3e1 --- /dev/null +++ b/internal/sshd/server_config_test.go @@ -0,0 +1,105 @@ +package sshd + +import ( + "context" + "crypto/dsa" + "crypto/rand" + "crypto/rsa" + "path" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestNewServerConfigWithoutHosts(t *testing.T) { + _, err := newServerConfig(&config.Config{GitlabUrl: "http://localhost"}) + + require.Error(t, err) + require.Equal(t, "No host keys could be loaded, aborting", err.Error()) +} + +func TestFailedAuthorizedKeysClient(t *testing.T) { + _, err := newServerConfig(&config.Config{GitlabUrl: "ftp://localhost"}) + + require.Error(t, err) + require.Equal(t, "failed to initialize GitLab client: Error creating http client: unknown GitLab URL prefix", err.Error()) +} + +func TestFailedGetAuthKey(t *testing.T) { + testhelper.PrepareTestRootDir(t) + + srvCfg := config.ServerConfig{ + Listen: "127.0.0.1", + ConcurrentSessionsLimit: 1, + HostKeyFiles: []string{ + path.Join(testhelper.TestRoot, "certs/valid/server.key"), + path.Join(testhelper.TestRoot, "certs/invalid-path.key"), + path.Join(testhelper.TestRoot, "certs/invalid/server.crt"), + }, + } + + cfg, err := newServerConfig( + &config.Config{GitlabUrl: "http://localhost", User: "user", Server: srvCfg}, + ) + require.NoError(t, err) + + testCases := []struct { + desc string + user string + key ssh.PublicKey + expectedError string + }{ + { + desc: "wrong user", + user: "wrong-user", + key: rsaPublicKey(t), + expectedError: "unknown user", + }, { + desc: "prohibited dsa key", + user: "user", + key: dsaPublicKey(t), + expectedError: "DSA is prohibited", + }, { + desc: "API error", + user: "user", + key: rsaPublicKey(t), + expectedError: "Internal API unreachable", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + _, err = cfg.getAuthKey(context.Background(), tc.user, tc.key) + require.Error(t, err) + require.Equal(t, tc.expectedError, err.Error()) + }) + } +} + +func rsaPublicKey(t *testing.T) ssh.PublicKey { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + return publicKey +} + +func dsaPublicKey(t *testing.T) ssh.PublicKey { + privateKey := new(dsa.PrivateKey) + params := new(dsa.Parameters) + require.NoError(t, dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160)) + + privateKey.PublicKey.Parameters = *params + require.NoError(t, dsa.GenerateKey(privateKey, rand.Reader)) + + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + require.NoError(t, err) + + return publicKey +} diff --git a/internal/sshd/session.go b/internal/sshd/session.go index 22cb715..b8e8625 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -2,13 +2,16 @@ package sshd import ( "context" + "errors" "fmt" + "reflect" + "gitlab.com/gitlab-org/labkit/log" "golang.org/x/crypto/ssh" - "gitlab.com/gitlab-org/gitlab-shell/internal/command" - "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" + shellCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" ) @@ -39,16 +42,27 @@ type exitStatusReq struct { } func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { + ctxlog := log.ContextLogger(ctx) + + ctxlog.Debug("session: handle: entering request loop") + for req := range requests { + sessionLog := ctxlog.WithFields(log.Fields{ + "bytesize": len(req.Payload), + "type": req.Type, + "want_reply": req.WantReply, + }) + sessionLog.Debug("session: handle: request received") + var shouldContinue bool switch req.Type { case "env": - shouldContinue = s.handleEnv(req) + shouldContinue = s.handleEnv(ctx, req) case "exec": shouldContinue = s.handleExec(ctx, req) case "shell": shouldContinue = false - s.exit(s.handleShell(ctx, req)) + s.exit(ctx, s.handleShell(ctx, req)) default: // Ignore unknown requests but don't terminate the session shouldContinue = true @@ -57,18 +71,23 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { } } + sessionLog.WithField("should_continue", shouldContinue).Debug("session: handle: request processed") + if !shouldContinue { s.channel.Close() break } } + + ctxlog.Debug("session: handle: exiting request loop") } -func (s *session) handleEnv(req *ssh.Request) bool { +func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { var accepted bool var envRequest envRequest if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil { + log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request") return false } @@ -84,6 +103,10 @@ func (s *session) handleEnv(req *ssh.Request) bool { req.Reply(accepted, []byte{}) } + log.WithContextFields( + ctx, log.Fields{"accepted": accepted, "env_request": envRequest}, + ).Debug("session: handleEnv: processed") + return true } @@ -95,7 +118,8 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool { s.execCmd = execRequest.Command - s.exit(s.handleShell(ctx, req)) + s.exit(ctx, s.handleShell(ctx, req)) + return false } @@ -104,19 +128,11 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { req.Reply(true, []byte{}) } - args := &commandargs.Shell{ - GitlabKeyId: s.gitlabKeyId, - Env: sshenv.Env{ - IsSSHConnection: true, - OriginalCommand: s.execCmd, - GitProtocolVersion: s.gitProtocolVersion, - RemoteAddr: s.remoteAddr, - }, - } - - if err := args.ParseCommand(s.execCmd); err != nil { - s.toStderr("Failed to parse command: %v\n", err.Error()) - return 128 + env := sshenv.Env{ + IsSSHConnection: true, + OriginalCommand: s.execCmd, + GitProtocolVersion: s.gitProtocolVersion, + RemoteAddr: s.remoteAddr, } rw := &readwriter.ReadWriter{ @@ -125,25 +141,37 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { ErrOut: s.channel.Stderr(), } - cmd := command.BuildShellCommand(args, s.cfg, rw) - if cmd == nil { - s.toStderr("Unknown command: %v\n", args.CommandType) + cmd, err := shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw) + if err != nil { + if !errors.Is(err, disallowedcommand.Error) { + s.toStderr(ctx, "Failed to parse command: %v\n", err.Error()) + } + s.toStderr(ctx, "Unknown command: %v\n", s.execCmd) return 128 } + cmdName := reflect.TypeOf(cmd).String() + ctxlog := log.ContextLogger(ctx) + ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("session: handleShell: executing command") + if err := cmd.Execute(ctx); err != nil { - s.toStderr("remote: ERROR: %v\n", err.Error()) + s.toStderr(ctx, "remote: ERROR: %v\n", err.Error()) return 1 } + ctxlog.Info("session: handleShell: command executed successfully") + return 0 } -func (s *session) toStderr(format string, args ...interface{}) { - fmt.Fprintf(s.channel.Stderr(), format, args...) +func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { + out := fmt.Sprintf(format, args...) + log.WithContextFields(ctx, log.Fields{"stderr": out}).Debug("session: toStderr: output") + fmt.Fprint(s.channel.Stderr(), out) } -func (s *session) exit(status uint32) { +func (s *session) exit(ctx context.Context, status uint32) { + log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting") req := exitStatusReq{ExitStatus: status} s.channel.CloseWrite() diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go new file mode 100644 index 0000000..f135825 --- /dev/null +++ b/internal/sshd/session_test.go @@ -0,0 +1,189 @@ +package sshd + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/config" +) + +type fakeChannel struct { + stdErr io.ReadWriter + sentRequestName string + sentRequestPayload []byte +} + +func (f *fakeChannel) Read(data []byte) (int, error) { + return 0, nil +} + +func (f *fakeChannel) Write(data []byte) (int, error) { + return 0, nil +} + +func (f *fakeChannel) Close() error { + return nil +} + +func (f *fakeChannel) CloseWrite() error { + return nil +} + +func (f *fakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + f.sentRequestName = name + f.sentRequestPayload = payload + + return true, nil +} + +func (f *fakeChannel) Stderr() io.ReadWriter { + return f.stdErr +} + +var requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(`{"id": 1000, "name": "Test User", "username": "test-user"}`)) + }, + }, +} + +func TestHandleEnv(t *testing.T) { + testCases := []struct { + desc string + payload []byte + expectedProtocolVersion string + expectedResult bool + }{ + { + desc: "invalid payload", + payload: []byte("invalid"), + expectedProtocolVersion: "1", + expectedResult: false, + }, { + desc: "valid payload", + payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}), + expectedProtocolVersion: "2", + expectedResult: true, + }, { + desc: "valid payload with forbidden env var", + payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}), + expectedProtocolVersion: "1", + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + s := &session{gitProtocolVersion: "1"} + r := &ssh.Request{Payload: tc.payload} + + require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult) + require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion) + }) + } +} + +func TestHandleExec(t *testing.T) { + testCases := []struct { + desc string + payload []byte + expectedExecCmd string + sentRequestName string + sentRequestPayload []byte + }{ + { + desc: "invalid payload", + payload: []byte("invalid"), + expectedExecCmd: "", + sentRequestName: "", + }, { + desc: "valid payload", + payload: ssh.Marshal(execRequest{Command: "discover"}), + expectedExecCmd: "discover", + sentRequestName: "exit-status", + sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}), + }, + } + + url := testserver.StartHttpServer(t, requests) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := &bytes.Buffer{} + f := &fakeChannel{stdErr: out} + s := &session{ + gitlabKeyId: "root", + channel: f, + cfg: &config.Config{GitlabUrl: url}, + } + r := &ssh.Request{Payload: tc.payload} + + require.Equal(t, false, s.handleExec(context.Background(), r)) + require.Equal(t, tc.sentRequestName, f.sentRequestName) + require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload) + }) + } +} + +func TestHandleShell(t *testing.T) { + testCases := []struct { + desc string + cmd string + errMsg string + gitlabKeyId string + expectedExitCode uint32 + }{ + { + desc: "fails to parse command", + cmd: `\`, + errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n", + gitlabKeyId: "root", + expectedExitCode: 128, + }, { + desc: "specified command is unknown", + cmd: "unknown-command", + errMsg: "Unknown command: unknown-command\n", + gitlabKeyId: "root", + expectedExitCode: 128, + }, { + desc: "fails to parse command", + cmd: "discover", + gitlabKeyId: "", + errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n", + expectedExitCode: 1, + }, { + desc: "fails to parse command", + cmd: "discover", + errMsg: "", + gitlabKeyId: "root", + expectedExitCode: 0, + }, + } + + url := testserver.StartHttpServer(t, requests) + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := &bytes.Buffer{} + s := &session{ + gitlabKeyId: tc.gitlabKeyId, + execCmd: tc.cmd, + channel: &fakeChannel{stdErr: out}, + cfg: &config.Config{GitlabUrl: url}, + } + r := &ssh.Request{} + + require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r)) + require.Equal(t, tc.errMsg, out.String()) + }) + } +} diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index de5fbd4..19fa661 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -2,13 +2,9 @@ package sshd import ( "context" - "encoding/base64" - "errors" "fmt" "net" "net/http" - "os" - "strconv" "sync" "time" @@ -16,7 +12,6 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" @@ -35,44 +30,24 @@ const ( type Server struct { Config *config.Config - status status - statusMu sync.Mutex - wg sync.WaitGroup - listener net.Listener - hostKeys []ssh.Signer - authorizedKeysClient *authorizedkeys.Client + status status + statusMu sync.Mutex + wg sync.WaitGroup + listener net.Listener + serverConfig *serverConfig } func NewServer(cfg *config.Config) (*Server, error) { - authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + serverConfig, err := newServerConfig(cfg) if err != nil { - return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + return nil, err } - var hostKeys []ssh.Signer - for _, filename := range cfg.Server.HostKeyFiles { - keyRaw, err := os.ReadFile(filename) - if err != nil { - log.WithError(err).Warnf("Failed to read host key %v", filename) - continue - } - key, err := ssh.ParsePrivateKey(keyRaw) - if err != nil { - log.WithError(err).Warnf("Failed to parse host key %v", filename) - continue - } - - hostKeys = append(hostKeys, key) - } - if len(hostKeys) == 0 { - return nil, fmt.Errorf("No host keys could be loaded, aborting") - } - - return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil + return &Server{Config: cfg, serverConfig: serverConfig}, nil } func (s *Server) ListenAndServe(ctx context.Context) error { - if err := s.listen(); err != nil { + if err := s.listen(ctx); err != nil { return err } defer s.listener.Close() @@ -110,7 +85,7 @@ func (s *Server) MonitoringServeMux() *http.ServeMux { return mux } -func (s *Server) listen() error { +func (s *Server) listen(ctx context.Context) error { sshListener, err := net.Listen("tcp", s.Config.Server.Listen) if err != nil { return fmt.Errorf("failed to listen for connection: %w", err) @@ -122,10 +97,10 @@ func (s *Server) listen() error { ReadHeaderTimeout: ProxyHeaderTimeout, } - log.Info("Proxy protocol is enabled") + log.ContextLogger(ctx).Info("Proxy protocol is enabled") } - log.WithFields(log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections") + log.WithContextFields(ctx, log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections") s.listener = sshListener @@ -142,7 +117,7 @@ func (s *Server) serve(ctx context.Context) { break } - log.WithError(err).Warn("Failed to accept connection") + log.ContextLogger(ctx).WithError(err).Warn("Failed to accept connection") continue } @@ -168,57 +143,29 @@ func (s *Server) getStatus() status { return s.status } -func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig { - sshCfg := &ssh.ServerConfig{ - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if conn.User() != s.Config.User { - return nil, errors.New("unknown user") - } - if key.Type() == ssh.KeyAlgoDSA { - return nil, errors.New("DSA is prohibited") - } - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) - if err != nil { - return nil, err - } - - return &ssh.Permissions{ - // Record the public key used for authentication. - Extensions: map[string]string{ - "key-id": strconv.FormatInt(res.Id, 10), - }, - }, nil - }, - } - - for _, key := range s.hostKeys { - sshCfg.AddHostKey(key) - } - - return sshCfg -} - func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() defer s.wg.Done() defer nconn.Close() + ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) + defer cancel() + + ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr}) + // Prevent a panic in a single connection from taking out the whole server defer func() { if err := recover(); err != nil { - log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", remoteAddr) + ctxlog.Warn("panic handling session") } }() - ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) - defer cancel() + ctxlog.Info("server: handleConn: start") - sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx)) + sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx)) if err != nil { - log.WithError(err).Info("Failed to initialize SSH connection") + ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection") return } @@ -235,4 +182,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { session.handle(ctx, requests) }) + + ctxlog.Info("server: handleConn: done") } diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index 32946af..71f7733 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -104,6 +104,24 @@ func TestLivenessProbe(t *testing.T) { require.Equal(t, 200, r.Result().StatusCode) } +func TestInvalidClientConfig(t *testing.T) { + setupServer(t) + + cfg := clientConfig(t) + cfg.User = "unknown" + _, err := ssh.Dial("tcp", serverUrl, cfg) + require.Error(t, err) +} + +func TestInvalidServerConfig(t *testing.T) { + s := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}} + err := s.ListenAndServe(context.Background()) + + require.Error(t, err) + require.Equal(t, "failed to listen for connection: listen tcp: address invalid: missing port in address", err.Error()) + require.Nil(t, s.Shutdown()) +} + func setupServer(t *testing.T) *Server { t.Helper() |