summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.ruby-version2
-rw-r--r--.tool-versions2
-rw-r--r--Makefile5
-rw-r--r--README.md15
-rw-r--r--client/httpclient.go22
-rw-r--r--client/httpsclient_test.go24
-rw-r--r--cmd/check/command/command.go21
-rw-r--r--cmd/check/command/command_test.go42
-rw-r--r--cmd/check/main.go6
-rw-r--r--cmd/gitlab-shell-authorized-keys-check/command/command.go37
-rw-r--r--cmd/gitlab-shell-authorized-keys-check/command/command_test.go120
-rw-r--r--cmd/gitlab-shell-authorized-keys-check/main.go6
-rw-r--r--cmd/gitlab-shell-authorized-principals-check/command/command.go37
-rw-r--r--cmd/gitlab-shell-authorized-principals-check/command/command_test.go114
-rw-r--r--cmd/gitlab-shell-authorized-principals-check/main.go6
-rw-r--r--cmd/gitlab-shell/command/command.go78
-rw-r--r--cmd/gitlab-shell/command/command_test.go281
-rw-r--r--cmd/gitlab-shell/main.go17
-rw-r--r--cmd/gitlab-sshd/main.go2
-rw-r--r--internal/command/command.go83
-rw-r--r--internal/command/command_test.go158
-rw-r--r--internal/command/commandargs/command_args.go29
-rw-r--r--internal/command/commandargs/command_args_test.go197
-rw-r--r--internal/command/commandargs/shell.go13
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate.go7
-rw-r--r--internal/command/shared/customaction/customaction.go2
-rw-r--r--internal/config/config.go2
-rw-r--r--internal/config/config_test.go2
-rw-r--r--internal/executable/executable.go12
-rw-r--r--internal/executable/executable_test.go4
-rw-r--r--internal/handler/exec.go10
-rw-r--r--internal/handler/exec_test.go14
-rw-r--r--internal/logger/logger.go1
-rw-r--r--internal/logger/logger_test.go28
-rw-r--r--internal/sshd/connection.go10
-rw-r--r--internal/sshd/server_config.go94
-rw-r--r--internal/sshd/server_config_test.go105
-rw-r--r--internal/sshd/session.go80
-rw-r--r--internal/sshd/session_test.go189
-rw-r--r--internal/sshd/sshd.go99
-rw-r--r--internal/sshd/sshd_test.go18
41 files changed, 1373 insertions, 621 deletions
diff --git a/.ruby-version b/.ruby-version
index 37c2961..a4dd9db 100644
--- a/.ruby-version
+++ b/.ruby-version
@@ -1 +1 @@
-2.7.2
+2.7.4
diff --git a/.tool-versions b/.tool-versions
new file mode 100644
index 0000000..7320730
--- /dev/null
+++ b/.tool-versions
@@ -0,0 +1,2 @@
+ruby 2.7.4
+go 1.16.8
diff --git a/Makefile b/Makefile
index 632c324..ae83d7c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,15 @@
.PHONY: validate verify verify_ruby verify_golang test test_ruby test_golang coverage coverage_golang setup _script_install build compile check clean install
GO_SOURCES := $(shell find . -name '*.go')
-VERSION_STRING := $(shell git describe --match v* 2>/dev/null || awk '$0="v"$0' VERSION 2>/dev/null || echo unknown)
+VERSION_STRING := $(shell git describe --match v* 2>/dev/null || awk '$$0="v"$$0' VERSION 2>/dev/null || echo unknown)
BUILD_TIME := $(shell date -u +%Y%m%d.%H%M%S)
BUILD_TAGS := tracer_static tracer_static_jaeger continuous_profiler_stackdriver
GOBUILD_FLAGS := -ldflags "-X main.Version=$(VERSION_STRING) -X main.BuildTime=$(BUILD_TIME)" -tags "$(BUILD_TAGS)"
PREFIX ?= /usr/local
+build: bin/gitlab-shell
+
validate: verify test
verify: verify_golang
@@ -40,7 +42,6 @@ setup: _script_install bin/gitlab-shell
_script_install:
bin/install
-build: bin/gitlab-shell
compile: bin/gitlab-shell
bin/gitlab-shell: $(GO_SOURCES)
GOBIN="$(CURDIR)/bin" go install $(GOBUILD_FLAGS) ./cmd/...
diff --git a/README.md b/README.md
index 7847377..a45b30d 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,21 @@ environment.
Starting with GitLab 8.12, GitLab supports Git LFS authentication through SSH.
+## Logging Guidelines
+
+In general, it should be possible to determine the structure, but not content,
+of a gitlab-shell or gitlab-sshd session just from inspecting the logs. Some
+guidelines:
+
+- We use [`gitlab.com/gitlab-org/labkit/log`](https://pkg.go.dev/gitlab.com/gitlab-org/labkit/log)
+ for logging functionality
+- **Always** include a correlation ID
+- Log messages should be invariant and unique. Include accessory information in
+ fields, using `log.WithField`, `log.WithFields`, or `log.WithError`.
+- Log success cases as well as error cases
+- Logging too much is better than not logging enough. If a message seems too
+ verbose, consider reducing the log level before removing the message.
+
## Releasing
See [PROCESS.md](./PROCESS.md)
diff --git a/client/httpclient.go b/client/httpclient.go
index 9238824..7b8a35c 100644
--- a/client/httpclient.go
+++ b/client/httpclient.go
@@ -54,6 +54,22 @@ func WithClientCert(certPath, keyPath string) HTTPClientOpt {
}
}
+func validateCaFile(filename string) error {
+ if filename == "" {
+ return nil
+ }
+
+ if _, err := os.Stat(filename); err != nil {
+ if os.IsNotExist(err) {
+ return fmt.Errorf("cannot find cafile '%s': %w", filename, ErrCafileNotFound)
+ }
+
+ return err
+ }
+
+ return nil
+}
+
// Deprecated: use NewHTTPClientWithOpts - https://gitlab.com/gitlab-org/gitlab-shell/-/issues/484
func NewHTTPClient(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64) *HttpClient {
c, err := NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath, selfSignedCert, readTimeoutSeconds, nil)
@@ -73,10 +89,8 @@ func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath stri
} else if strings.HasPrefix(gitlabURL, httpProtocol) {
transport, host = buildHttpTransport(gitlabURL)
} else if strings.HasPrefix(gitlabURL, httpsProtocol) {
- if _, err := os.Stat(caFile); err != nil {
- if os.IsNotExist(err) {
- return nil, fmt.Errorf("cannot find cafile '%s': %w", caFile, ErrCafileNotFound)
- }
+ err = validateCaFile(caFile)
+ if err != nil {
return nil, err
}
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
index d2c2293..debe1bd 100644
--- a/client/httpsclient_test.go
+++ b/client/httpsclient_test.go
@@ -66,10 +66,11 @@ func TestSuccessfulRequests(t *testing.T) {
func TestFailedRequests(t *testing.T) {
testCases := []struct {
- desc string
- caFile string
- caPath string
- expectedError string
+ desc string
+ caFile string
+ caPath string
+ expectedCaFileNotFound bool
+ expectedError string
}{
{
desc: "Invalid CaFile",
@@ -77,18 +78,25 @@ func TestFailedRequests(t *testing.T) {
expectedError: "Internal API unreachable",
},
{
- desc: "Invalid CaPath",
- caPath: path.Join(testhelper.TestRoot, "certs/invalid"),
+ desc: "Missing CaFile",
+ caFile: path.Join(testhelper.TestRoot, "certs/invalid/missing.crt"),
+ expectedCaFileNotFound: true,
},
{
- desc: "Empty config",
+ desc: "Invalid CaPath",
+ caPath: path.Join(testhelper.TestRoot, "certs/invalid"),
+ expectedError: "Internal API unreachable",
+ },
+ {
+ desc: "Empty config",
+ expectedError: "Internal API unreachable",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
client, err := setupWithRequests(t, tc.caFile, tc.caPath, "", "", "", false)
- if tc.caFile == "" {
+ if tc.expectedCaFileNotFound {
require.Error(t, err)
require.ErrorIs(t, err, ErrCafileNotFound)
} else {
diff --git a/cmd/check/command/command.go b/cmd/check/command/command.go
new file mode 100644
index 0000000..f260681
--- /dev/null
+++ b/cmd/check/command/command.go
@@ -0,0 +1,21 @@
+package command
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+func New(config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) {
+ if cmd := build(config, readWriter); cmd != nil {
+ return cmd, nil
+ }
+
+ return nil, disallowedcommand.Error
+}
+
+func build(config *config.Config, readWriter *readwriter.ReadWriter) command.Command {
+ return &healthcheck.Command{Config: config, ReadWriter: readWriter}
+}
diff --git a/cmd/check/command/command_test.go b/cmd/check/command/command_test.go
new file mode 100644
index 0000000..cd06456
--- /dev/null
+++ b/cmd/check/command/command_test.go
@@ -0,0 +1,42 @@
+package command_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/cmd/check/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
+)
+
+var (
+ basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"}
+)
+
+func TestNew(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ config *config.Config
+ expectedType interface{}
+ }{
+ {
+ desc: "it returns a Healthcheck command",
+ config: basicConfig,
+ expectedType: &healthcheck.Command{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ command, err := command.New(tc.config, nil)
+
+ require.NoError(t, err)
+ require.IsType(t, tc.expectedType, command)
+ })
+ }
+}
diff --git a/cmd/check/main.go b/cmd/check/main.go
index 44d8175..e4bcdf2 100644
--- a/cmd/check/main.go
+++ b/cmd/check/main.go
@@ -4,12 +4,12 @@ import (
"fmt"
"os"
+ checkCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/check/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/logger"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
)
func main() {
@@ -19,7 +19,7 @@ func main() {
ErrOut: os.Stderr,
}
- executable, err := executable.New(executable.Healthcheck, false)
+ executable, err := executable.New(executable.Healthcheck)
if err != nil {
fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting")
os.Exit(1)
@@ -34,7 +34,7 @@ func main() {
logCloser := logger.Configure(config)
defer logCloser.Close()
- cmd, err := command.New(executable, os.Args[1:], sshenv.Env{}, config, readWriter)
+ cmd, err := checkCmd.New(config, readWriter)
if err != nil {
fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
diff --git a/cmd/gitlab-shell-authorized-keys-check/command/command.go b/cmd/gitlab-shell-authorized-keys-check/command/command.go
new file mode 100644
index 0000000..8cf309b
--- /dev/null
+++ b/cmd/gitlab-shell-authorized-keys-check/command/command.go
@@ -0,0 +1,37 @@
+package command
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+func New(arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) {
+ args, err := Parse(arguments)
+ if err != nil {
+ return nil, err
+ }
+
+ if cmd := build(args, config, readWriter); cmd != nil {
+ return cmd, nil
+ }
+
+ return nil, disallowedcommand.Error
+}
+
+func Parse(arguments []string) (*commandargs.AuthorizedKeys, error) {
+ args := &commandargs.AuthorizedKeys{Arguments: arguments}
+
+ if err := args.Parse(); err != nil {
+ return nil, err
+ }
+
+ return args, nil
+}
+
+func build(args *commandargs.AuthorizedKeys, config *config.Config, readWriter *readwriter.ReadWriter) command.Command {
+ return &authorizedkeys.Command{Config: config, Args: args, ReadWriter: readWriter}
+}
diff --git a/cmd/gitlab-shell-authorized-keys-check/command/command_test.go b/cmd/gitlab-shell-authorized-keys-check/command/command_test.go
new file mode 100644
index 0000000..3343e1c
--- /dev/null
+++ b/cmd/gitlab-shell-authorized-keys-check/command/command_test.go
@@ -0,0 +1,120 @@
+package command_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-keys-check/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
+)
+
+var (
+ authorizedKeysExec = &executable.Executable{Name: executable.AuthorizedKeysCheck}
+ basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"}
+)
+
+func TestNew(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ config *config.Config
+ expectedType interface{}
+ }{
+ {
+ desc: "it returns a AuthorizedKeys command",
+ executable: authorizedKeysExec,
+ arguments: []string{"git", "git", "key"},
+ config: basicConfig,
+ expectedType: &authorizedkeys.Command{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ command, err := command.New(tc.arguments, tc.config, nil)
+
+ require.NoError(t, err)
+ require.IsType(t, tc.expectedType, command)
+ })
+ }
+}
+
+func TestParseSuccess(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ expectedArgs commandargs.CommandArgs
+ expectError bool
+ }{
+ {
+ desc: "It parses authorized-keys command",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"git", "git", "key"},
+ expectedArgs: &commandargs.AuthorizedKeys{Arguments: []string{"git", "git", "key"}, ExpectedUser: "git", ActualUser: "git", Key: "key"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ result, err := command.Parse(tc.arguments)
+
+ if !tc.expectError {
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedArgs, result)
+ } else {
+ require.Error(t, err)
+ }
+ })
+ }
+}
+
+func TestParseFailure(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ expectedError string
+ }{
+ {
+ desc: "With not enough arguments for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user"},
+ expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
+ },
+ {
+ desc: "With too many arguments for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user", "user", "key", "something-else"},
+ expectedError: "# Insufficient arguments. 4. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
+ },
+ {
+ desc: "With missing username for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user", "", "key"},
+ expectedError: "# No username provided",
+ },
+ {
+ desc: "With missing key for the AuthorizedKeysCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
+ arguments: []string{"user", "user", ""},
+ expectedError: "# No key provided",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err := command.Parse(tc.arguments)
+
+ require.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
diff --git a/cmd/gitlab-shell-authorized-keys-check/main.go b/cmd/gitlab-shell-authorized-keys-check/main.go
index cda3e0b..ebe6da9 100644
--- a/cmd/gitlab-shell-authorized-keys-check/main.go
+++ b/cmd/gitlab-shell-authorized-keys-check/main.go
@@ -4,13 +4,13 @@ import (
"fmt"
"os"
+ cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-keys-check/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/console"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/logger"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
)
func main() {
@@ -20,7 +20,7 @@ func main() {
ErrOut: os.Stderr,
}
- executable, err := executable.New(executable.AuthorizedKeysCheck, true)
+ executable, err := executable.New(executable.AuthorizedKeysCheck)
if err != nil {
fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting")
os.Exit(1)
@@ -35,7 +35,7 @@ func main() {
logCloser := logger.Configure(config)
defer logCloser.Close()
- cmd, err := command.New(executable, os.Args[1:], sshenv.Env{}, config, readWriter)
+ cmd, err := cmd.New(os.Args[1:], config, readWriter)
if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on
// the environment
diff --git a/cmd/gitlab-shell-authorized-principals-check/command/command.go b/cmd/gitlab-shell-authorized-principals-check/command/command.go
new file mode 100644
index 0000000..9418dad
--- /dev/null
+++ b/cmd/gitlab-shell-authorized-principals-check/command/command.go
@@ -0,0 +1,37 @@
+package command
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+func New(arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) {
+ args, err := Parse(arguments)
+ if err != nil {
+ return nil, err
+ }
+
+ if cmd := build(args, config, readWriter); cmd != nil {
+ return cmd, nil
+ }
+
+ return nil, disallowedcommand.Error
+}
+
+func Parse(arguments []string) (*commandargs.AuthorizedPrincipals, error) {
+ args := &commandargs.AuthorizedPrincipals{Arguments: arguments}
+
+ if err := args.Parse(); err != nil {
+ return nil, err
+ }
+
+ return args, nil
+}
+
+func build(args *commandargs.AuthorizedPrincipals, config *config.Config, readWriter *readwriter.ReadWriter) command.Command {
+ return &authorizedprincipals.Command{Config: config, Args: args, ReadWriter: readWriter}
+}
diff --git a/cmd/gitlab-shell-authorized-principals-check/command/command_test.go b/cmd/gitlab-shell-authorized-principals-check/command/command_test.go
new file mode 100644
index 0000000..2ca2125
--- /dev/null
+++ b/cmd/gitlab-shell-authorized-principals-check/command/command_test.go
@@ -0,0 +1,114 @@
+package command_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-principals-check/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
+)
+
+var (
+ authorizedPrincipalsExec = &executable.Executable{Name: executable.AuthorizedPrincipalsCheck}
+ basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"}
+)
+
+func TestNew(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ config *config.Config
+ expectedType interface{}
+ }{
+ {
+ desc: "it returns a AuthorizedPrincipals command",
+ executable: authorizedPrincipalsExec,
+ arguments: []string{"key", "principal"},
+ config: basicConfig,
+ expectedType: &authorizedprincipals.Command{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ command, err := cmd.New(tc.arguments, tc.config, nil)
+
+ require.NoError(t, err)
+ require.IsType(t, tc.expectedType, command)
+ })
+ }
+}
+
+func TestParseSuccess(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ expectedArgs commandargs.CommandArgs
+ expectError bool
+ }{
+ {
+ desc: "It parses authorized-principals command",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"key", "principal-1", "principal-2"},
+ expectedArgs: &commandargs.AuthorizedPrincipals{Arguments: []string{"key", "principal-1", "principal-2"}, KeyId: "key", Principals: []string{"principal-1", "principal-2"}},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ result, err := cmd.Parse(tc.arguments)
+
+ if !tc.expectError {
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedArgs, result)
+ } else {
+ require.Error(t, err)
+ }
+ })
+ }
+}
+
+func TestParseFailure(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ expectedError string
+ }{
+ {
+ desc: "With not enough arguments for the AuthorizedPrincipalsCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"key"},
+ expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]",
+ },
+ {
+ desc: "With missing key_id for the AuthorizedPrincipalsCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"", "principal"},
+ expectedError: "# No key_id provided",
+ },
+ {
+ desc: "With blank principal for the AuthorizedPrincipalsCheck",
+ executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
+ arguments: []string{"key", "principal", ""},
+ expectedError: "# An invalid principal was provided",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err := cmd.Parse(tc.arguments)
+
+ require.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
diff --git a/cmd/gitlab-shell-authorized-principals-check/main.go b/cmd/gitlab-shell-authorized-principals-check/main.go
index 87f7fa3..3e18b9d 100644
--- a/cmd/gitlab-shell-authorized-principals-check/main.go
+++ b/cmd/gitlab-shell-authorized-principals-check/main.go
@@ -4,13 +4,13 @@ import (
"fmt"
"os"
+ cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell-authorized-principals-check/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/console"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/logger"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
)
func main() {
@@ -20,7 +20,7 @@ func main() {
ErrOut: os.Stderr,
}
- executable, err := executable.New(executable.AuthorizedPrincipalsCheck, true)
+ executable, err := executable.New(executable.AuthorizedPrincipalsCheck)
if err != nil {
fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting")
os.Exit(1)
@@ -35,7 +35,7 @@ func main() {
logCloser := logger.Configure(config)
defer logCloser.Close()
- cmd, err := command.New(executable, os.Args[1:], sshenv.Env{}, config, readWriter)
+ cmd, err := cmd.New(os.Args[1:], config, readWriter)
if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on
// the environment
diff --git a/cmd/gitlab-shell/command/command.go b/cmd/gitlab-shell/command/command.go
new file mode 100644
index 0000000..5f828cd
--- /dev/null
+++ b/cmd/gitlab-shell/command/command.go
@@ -0,0 +1,78 @@
+package command
+
+import (
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
+)
+
+func New(arguments []string, env sshenv.Env, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) {
+ args, err := Parse(arguments, env)
+ if err != nil {
+ return nil, err
+ }
+
+ if cmd := Build(args, config, readWriter); cmd != nil {
+ return cmd, nil
+ }
+
+ return nil, disallowedcommand.Error
+}
+
+func NewWithKey(gitlabKeyId string, env sshenv.Env, config *config.Config, readWriter *readwriter.ReadWriter) (command.Command, error) {
+ args, err := Parse(nil, env)
+ if err != nil {
+ return nil, err
+ }
+
+ args.GitlabKeyId = gitlabKeyId
+ if cmd := Build(args, config, readWriter); cmd != nil {
+ return cmd, nil
+ }
+
+ return nil, disallowedcommand.Error
+}
+
+func Parse(arguments []string, env sshenv.Env) (*commandargs.Shell, error) {
+ args := &commandargs.Shell{Arguments: arguments, Env: env}
+
+ if err := args.Parse(); err != nil {
+ return nil, err
+ }
+
+ return args, nil
+}
+
+func Build(args *commandargs.Shell, config *config.Config, readWriter *readwriter.ReadWriter) command.Command {
+ switch args.CommandType {
+ case commandargs.Discover:
+ return &discover.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.TwoFactorRecover:
+ return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.TwoFactorVerify:
+ return &twofactorverify.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.LfsAuthenticate:
+ return &lfsauthenticate.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.ReceivePack:
+ return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.UploadPack:
+ return &uploadpack.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.UploadArchive:
+ return &uploadarchive.Command{Config: config, Args: args, ReadWriter: readWriter}
+ case commandargs.PersonalAccessToken:
+ return &personalaccesstoken.Command{Config: config, Args: args, ReadWriter: readWriter}
+ }
+
+ return nil
+}
diff --git a/cmd/gitlab-shell/command/command_test.go b/cmd/gitlab-shell/command/command_test.go
new file mode 100644
index 0000000..2aeee59
--- /dev/null
+++ b/cmd/gitlab-shell/command/command_test.go
@@ -0,0 +1,281 @@
+package command_test
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ cmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
+)
+
+var (
+ gitlabShellExec = &executable.Executable{Name: executable.GitlabShell}
+ basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"}
+)
+
+func TestNew(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ config *config.Config
+ expectedType interface{}
+ }{
+ {
+ desc: "it returns a Discover command",
+ executable: gitlabShellExec,
+ env: buildEnv(""),
+ config: basicConfig,
+ expectedType: &discover.Command{},
+ },
+ {
+ desc: "it returns a TwoFactorRecover command",
+ executable: gitlabShellExec,
+ env: buildEnv("2fa_recovery_codes"),
+ config: basicConfig,
+ expectedType: &twofactorrecover.Command{},
+ },
+ {
+ desc: "it returns a TwoFactorVerify command",
+ executable: gitlabShellExec,
+ env: buildEnv("2fa_verify"),
+ config: basicConfig,
+ expectedType: &twofactorverify.Command{},
+ },
+ {
+ desc: "it returns an LfsAuthenticate command",
+ executable: gitlabShellExec,
+ env: buildEnv("git-lfs-authenticate"),
+ config: basicConfig,
+ expectedType: &lfsauthenticate.Command{},
+ },
+ {
+ desc: "it returns a ReceivePack command",
+ executable: gitlabShellExec,
+ env: buildEnv("git-receive-pack"),
+ config: basicConfig,
+ expectedType: &receivepack.Command{},
+ },
+ {
+ desc: "it returns an UploadPack command",
+ executable: gitlabShellExec,
+ env: buildEnv("git-upload-pack"),
+ config: basicConfig,
+ expectedType: &uploadpack.Command{},
+ },
+ {
+ desc: "it returns an UploadArchive command",
+ executable: gitlabShellExec,
+ env: buildEnv("git-upload-archive"),
+ config: basicConfig,
+ expectedType: &uploadarchive.Command{},
+ },
+ {
+ desc: "it returns a PersonalAccessToken command",
+ executable: gitlabShellExec,
+ env: buildEnv("personal_access_token"),
+ config: basicConfig,
+ expectedType: &personalaccesstoken.Command{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ command, err := cmd.New(tc.arguments, tc.env, tc.config, nil)
+
+ require.NoError(t, err)
+ require.IsType(t, tc.expectedType, command)
+ })
+ }
+}
+
+func TestFailingNew(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ expectedError error
+ }{
+ {
+ desc: "Parsing environment failed",
+ executable: gitlabShellExec,
+ expectedError: errors.New("Only SSH allowed"),
+ },
+ {
+ desc: "Unknown command given",
+ executable: gitlabShellExec,
+ env: buildEnv("unknown"),
+ expectedError: disallowedcommand.Error,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ command, err := cmd.New([]string{}, tc.env, basicConfig, nil)
+ require.Nil(t, command)
+ require.Equal(t, tc.expectedError, err)
+ })
+ }
+}
+
+func buildEnv(command string) sshenv.Env {
+ return sshenv.Env{
+ IsSSHConnection: true,
+ OriginalCommand: command,
+ }
+}
+
+func TestParseSuccess(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ expectedArgs commandargs.CommandArgs
+ expectError bool
+ }{
+ {
+ desc: "It sets discover as the command when the command string was empty",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{}, CommandType: commandargs.Discover, Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
+ },
+ {
+ desc: "It finds the key id in any passed arguments",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
+ arguments: []string{"hello", "key-123"},
+ expectedArgs: &commandargs.Shell{Arguments: []string{"hello", "key-123"}, SshArgs: []string{}, CommandType: commandargs.Discover, GitlabKeyId: "123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
+ },
+ {
+ desc: "It finds the key id only if the argument is of <key-id> format",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
+ arguments: []string{"hello", "username-key-123"},
+ expectedArgs: &commandargs.Shell{Arguments: []string{"hello", "username-key-123"}, SshArgs: []string{}, CommandType: commandargs.Discover, GitlabUsername: "key-123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
+ },
+ {
+ desc: "It finds the username in any passed arguments",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
+ arguments: []string{"hello", "username-jane-doe"},
+ expectedArgs: &commandargs.Shell{Arguments: []string{"hello", "username-jane-doe"}, SshArgs: []string{}, CommandType: commandargs.Discover, GitlabUsername: "jane-doe", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
+ },
+ {
+ desc: "It parses 2fa_recovery_codes command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"2fa_recovery_codes"}, CommandType: commandargs.TwoFactorRecover, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"}},
+ },
+ {
+ desc: "It parses git-receive-pack command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"}},
+ },
+ {
+ desc: "It parses git-receive-pack command and a project with single quotes",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"}},
+ },
+ {
+ desc: `It parses "git receive-pack" command`,
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`}},
+ },
+ {
+ desc: `It parses a command followed by control characters`,
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: commandargs.ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`}},
+ },
+ {
+ desc: "It parses git-upload-pack command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-upload-pack", "group/repo"}, CommandType: commandargs.UploadPack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`}},
+ },
+ {
+ desc: "It parses git-upload-archive command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-upload-archive", "group/repo"}, CommandType: commandargs.UploadArchive, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"}},
+ },
+ {
+ desc: "It parses git-lfs-authenticate command",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"},
+ arguments: []string{},
+ expectedArgs: &commandargs.Shell{Arguments: []string{}, SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}, CommandType: commandargs.LfsAuthenticate, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"}},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ result, err := cmd.Parse(tc.arguments, tc.env)
+
+ if !tc.expectError {
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedArgs, result)
+ } else {
+ require.Error(t, err)
+ }
+ })
+ }
+}
+
+func TestParseFailure(t *testing.T) {
+ testCases := []struct {
+ desc string
+ executable *executable.Executable
+ env sshenv.Env
+ arguments []string
+ expectedError string
+ }{
+ {
+ desc: "It fails if SSH connection is not set",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ arguments: []string{},
+ expectedError: "Only SSH allowed",
+ },
+ {
+ desc: "It fails if SSH command is invalid",
+ executable: &executable.Executable{Name: executable.GitlabShell},
+ env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git receive-pack "`},
+ arguments: []string{},
+ expectedError: "Invalid SSH command: invalid command line string",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err := cmd.Parse(tc.arguments, tc.env)
+
+ require.EqualError(t, err, tc.expectedError)
+ })
+ }
+}
diff --git a/cmd/gitlab-shell/main.go b/cmd/gitlab-shell/main.go
index fe52bfc..a945d0c 100644
--- a/cmd/gitlab-shell/main.go
+++ b/cmd/gitlab-shell/main.go
@@ -3,7 +3,11 @@ package main
import (
"fmt"
"os"
+ "reflect"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ shellCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
@@ -34,7 +38,7 @@ func main() {
ErrOut: os.Stderr,
}
- executable, err := executable.New(executable.GitlabShell, true)
+ executable, err := executable.New(executable.GitlabShell)
if err != nil {
fmt.Fprintln(readWriter.ErrOut, "Failed to determine executable, exiting")
os.Exit(1)
@@ -50,7 +54,7 @@ func main() {
defer logCloser.Close()
env := sshenv.NewFromEnv()
- cmd, err := command.New(executable, os.Args[1:], env, config, readWriter)
+ cmd, err := shellCmd.New(os.Args[1:], env, config, readWriter)
if err != nil {
// For now this could happen if `SSH_CONNECTION` is not set on
// the environment
@@ -61,8 +65,15 @@ func main() {
ctx, finished := command.Setup(executable.Name, config)
defer finished()
- if err = cmd.Execute(ctx); err != nil {
+ cmdName := reflect.TypeOf(cmd).String()
+ ctxlog := log.ContextLogger(ctx)
+ ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("gitlab-shell: main: executing command")
+
+ if err := cmd.Execute(ctx); err != nil {
+ ctxlog.WithError(err).Warn("gitlab-shell: main: command execution failed")
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
+
+ ctxlog.Info("gitlab-shell: main: command executed successfully")
}
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go
index 5bbf221..165c7a5 100644
--- a/cmd/gitlab-sshd/main.go
+++ b/cmd/gitlab-sshd/main.go
@@ -97,7 +97,7 @@ func main() {
sig := <-done
signal.Reset(syscall.SIGINT, syscall.SIGTERM)
- log.WithFields(log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Info("Shutdown initiated")
+ log.WithContextFields(ctx, log.Fields{"shutdown_timeout_s": cfg.Server.GracePeriodSeconds, "signal": sig.String()}).Info("Shutdown initiated")
server.Shutdown()
diff --git a/internal/command/command.go b/internal/command/command.go
index dadf41a..4ee568e 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -3,23 +3,7 @@ package command
import (
"context"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/tracing"
)
@@ -28,23 +12,6 @@ type Command interface {
Execute(ctx context.Context) error
}
-func New(e *executable.Executable, arguments []string, env sshenv.Env, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
- var args commandargs.CommandArgs
- if e.AcceptArgs {
- var err error
- args, err = commandargs.Parse(e, arguments, env)
- if err != nil {
- return nil, err
- }
- }
-
- if cmd := buildCommand(e, args, config, readWriter); cmd != nil {
- return cmd, nil
- }
-
- return nil, disallowedcommand.Error
-}
-
// Setup() initializes tracing from the configuration file and generates a
// background context from which all other contexts in the process should derive
// from, as it has a service name and initial correlation ID set.
@@ -80,53 +47,3 @@ func Setup(serviceName string, config *config.Config) (context.Context, func())
closer.Close()
}
}
-
-func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- switch e.Name {
- case executable.GitlabShell:
- return BuildShellCommand(args.(*commandargs.Shell), config, readWriter)
- case executable.AuthorizedKeysCheck:
- return buildAuthorizedKeysCommand(args.(*commandargs.AuthorizedKeys), config, readWriter)
- case executable.AuthorizedPrincipalsCheck:
- return buildAuthorizedPrincipalsCommand(args.(*commandargs.AuthorizedPrincipals), config, readWriter)
- case executable.Healthcheck:
- return buildHealthcheckCommand(config, readWriter)
- }
-
- return nil
-}
-
-func BuildShellCommand(args *commandargs.Shell, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- switch args.CommandType {
- case commandargs.Discover:
- return &discover.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.TwoFactorRecover:
- return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.TwoFactorVerify:
- return &twofactorverify.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.LfsAuthenticate:
- return &lfsauthenticate.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.ReceivePack:
- return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.UploadPack:
- return &uploadpack.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.UploadArchive:
- return &uploadarchive.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.PersonalAccessToken:
- return &personalaccesstoken.Command{Config: config, Args: args, ReadWriter: readWriter}
- }
-
- return nil
-}
-
-func buildAuthorizedKeysCommand(args *commandargs.AuthorizedKeys, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- return &authorizedkeys.Command{Config: config, Args: args, ReadWriter: readWriter}
-}
-
-func buildAuthorizedPrincipalsCommand(args *commandargs.AuthorizedPrincipals, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- return &authorizedprincipals.Command{Config: config, Args: args, ReadWriter: readWriter}
-}
-
-func buildHealthcheckCommand(config *config.Config, readWriter *readwriter.ReadWriter) Command {
- return &healthcheck.Command{Config: config, ReadWriter: readWriter}
-}
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index a538745..2fc6655 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -1,173 +1,15 @@
package command
import (
- "errors"
"os"
"testing"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
"gitlab.com/gitlab-org/labkit/correlation"
)
-var (
- authorizedKeysExec = &executable.Executable{Name: executable.AuthorizedKeysCheck, AcceptArgs: true}
- authorizedPrincipalsExec = &executable.Executable{Name: executable.AuthorizedPrincipalsCheck, AcceptArgs: true}
- checkExec = &executable.Executable{Name: executable.Healthcheck, AcceptArgs: false}
- gitlabShellExec = &executable.Executable{Name: executable.GitlabShell, AcceptArgs: true}
-
- basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"}
- advancedConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket", SslCertDir: "/tmp/certs"}
-)
-
-func buildEnv(command string) sshenv.Env {
- return sshenv.Env{
- IsSSHConnection: true,
- OriginalCommand: command,
- }
-}
-
-func TestNew(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- arguments []string
- config *config.Config
- expectedType interface{}
- }{
- {
- desc: "it returns a Discover command",
- executable: gitlabShellExec,
- env: buildEnv(""),
- config: basicConfig,
- expectedType: &discover.Command{},
- },
- {
- desc: "it returns a TwoFactorRecover command",
- executable: gitlabShellExec,
- env: buildEnv("2fa_recovery_codes"),
- config: basicConfig,
- expectedType: &twofactorrecover.Command{},
- },
- {
- desc: "it returns a TwoFactorVerify command",
- executable: gitlabShellExec,
- env: buildEnv("2fa_verify"),
- config: basicConfig,
- expectedType: &twofactorverify.Command{},
- },
- {
- desc: "it returns an LfsAuthenticate command",
- executable: gitlabShellExec,
- env: buildEnv("git-lfs-authenticate"),
- config: basicConfig,
- expectedType: &lfsauthenticate.Command{},
- },
- {
- desc: "it returns a ReceivePack command",
- executable: gitlabShellExec,
- env: buildEnv("git-receive-pack"),
- config: basicConfig,
- expectedType: &receivepack.Command{},
- },
- {
- desc: "it returns an UploadPack command",
- executable: gitlabShellExec,
- env: buildEnv("git-upload-pack"),
- config: basicConfig,
- expectedType: &uploadpack.Command{},
- },
- {
- desc: "it returns an UploadArchive command",
- executable: gitlabShellExec,
- env: buildEnv("git-upload-archive"),
- config: basicConfig,
- expectedType: &uploadarchive.Command{},
- },
- {
- desc: "it returns a Healthcheck command",
- executable: checkExec,
- config: basicConfig,
- expectedType: &healthcheck.Command{},
- },
- {
- desc: "it returns a AuthorizedKeys command",
- executable: authorizedKeysExec,
- arguments: []string{"git", "git", "key"},
- config: basicConfig,
- expectedType: &authorizedkeys.Command{},
- },
- {
- desc: "it returns a AuthorizedPrincipals command",
- executable: authorizedPrincipalsExec,
- arguments: []string{"key", "principal"},
- config: basicConfig,
- expectedType: &authorizedprincipals.Command{},
- },
- {
- desc: "it returns a PersonalAccessToken command",
- executable: gitlabShellExec,
- env: buildEnv("personal_access_token"),
- config: basicConfig,
- expectedType: &personalaccesstoken.Command{},
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- command, err := New(tc.executable, tc.arguments, tc.env, tc.config, nil)
-
- require.NoError(t, err)
- require.IsType(t, tc.expectedType, command)
- })
- }
-}
-
-func TestFailingNew(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- expectedError error
- }{
- {
- desc: "Parsing environment failed",
- executable: gitlabShellExec,
- expectedError: errors.New("Only SSH allowed"),
- },
- {
- desc: "Unknown command given",
- executable: gitlabShellExec,
- env: buildEnv("unknown"),
- expectedError: disallowedcommand.Error,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- command, err := New(tc.executable, []string{}, tc.env, basicConfig, nil)
- require.Nil(t, command)
- require.Equal(t, tc.expectedError, err)
- })
- }
-}
-
func TestSetup(t *testing.T) {
testCases := []struct {
name string
diff --git a/internal/command/commandargs/command_args.go b/internal/command/commandargs/command_args.go
index a01b8b2..f23ba18 100644
--- a/internal/command/commandargs/command_args.go
+++ b/internal/command/commandargs/command_args.go
@@ -1,37 +1,8 @@
package commandargs
-import (
- "errors"
- "fmt"
-
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
-)
-
type CommandType string
type CommandArgs interface {
Parse() error
GetArguments() []string
}
-
-func Parse(e *executable.Executable, arguments []string, env sshenv.Env) (CommandArgs, error) {
- var args CommandArgs
-
- switch e.Name {
- case executable.GitlabShell:
- args = &Shell{Arguments: arguments, Env: env}
- case executable.AuthorizedKeysCheck:
- args = &AuthorizedKeys{Arguments: arguments}
- case executable.AuthorizedPrincipalsCheck:
- args = &AuthorizedPrincipals{Arguments: arguments}
- default:
- return nil, errors.New(fmt.Sprintf("unknown executable: %s", e.Name))
- }
-
- if err := args.Parse(); err != nil {
- return nil, err
- }
-
- return args, nil
-}
diff --git a/internal/command/commandargs/command_args_test.go b/internal/command/commandargs/command_args_test.go
deleted file mode 100644
index 119ecd4..0000000
--- a/internal/command/commandargs/command_args_test.go
+++ /dev/null
@@ -1,197 +0,0 @@
-package commandargs
-
-import (
- "testing"
-
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
-
- "github.com/stretchr/testify/require"
-)
-
-func TestParseSuccess(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- arguments []string
- expectedArgs CommandArgs
- expectError bool
- }{
- {
- desc: "It sets discover as the command when the command string was empty",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{}, CommandType: Discover, Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It finds the key id in any passed arguments",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{"hello", "key-123"},
- expectedArgs: &Shell{Arguments: []string{"hello", "key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It finds the key id only if the argument is of <key-id> format",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{"hello", "username-key-123"},
- expectedArgs: &Shell{Arguments: []string{"hello", "username-key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "key-123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It finds the username in any passed arguments",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{"hello", "username-jane-doe"},
- expectedArgs: &Shell{Arguments: []string{"hello", "username-jane-doe"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It parses 2fa_recovery_codes command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"}},
- }, {
- desc: "It parses git-receive-pack command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"}},
- }, {
- desc: "It parses git-receive-pack command and a project with single quotes",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"}},
- }, {
- desc: `It parses "git receive-pack" command`,
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`}},
- }, {
- desc: `It parses a command followed by control characters`,
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`}},
- }, {
- desc: "It parses git-upload-pack command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-pack", "group/repo"}, CommandType: UploadPack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`}},
- }, {
- desc: "It parses git-upload-archive command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-archive", "group/repo"}, CommandType: UploadArchive, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"}},
- }, {
- desc: "It parses git-lfs-authenticate command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}, CommandType: LfsAuthenticate, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"}},
- }, {
- desc: "It parses authorized-keys command",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"git", "git", "key"},
- expectedArgs: &AuthorizedKeys{Arguments: []string{"git", "git", "key"}, ExpectedUser: "git", ActualUser: "git", Key: "key"},
- }, {
- desc: "It parses authorized-principals command",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"key", "principal-1", "principal-2"},
- expectedArgs: &AuthorizedPrincipals{Arguments: []string{"key", "principal-1", "principal-2"}, KeyId: "key", Principals: []string{"principal-1", "principal-2"}},
- }, {
- desc: "Unknown executable",
- executable: &executable.Executable{Name: "unknown"},
- arguments: []string{},
- expectError: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- result, err := Parse(tc.executable, tc.arguments, tc.env)
-
- if !tc.expectError {
- require.NoError(t, err)
- require.Equal(t, tc.expectedArgs, result)
- } else {
- require.Error(t, err)
- }
- })
- }
-}
-
-func TestParseFailure(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- arguments []string
- expectedError string
- }{
- {
- desc: "It fails if SSH connection is not set",
- executable: &executable.Executable{Name: executable.GitlabShell},
- arguments: []string{},
- expectedError: "Only SSH allowed",
- },
- {
- desc: "It fails if SSH command is invalid",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git receive-pack "`},
- arguments: []string{},
- expectedError: "Invalid SSH command",
- },
- {
- desc: "With not enough arguments for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user"},
- expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
- },
- {
- desc: "With too many arguments for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user", "user", "key", "something-else"},
- expectedError: "# Insufficient arguments. 4. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
- },
- {
- desc: "With missing username for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user", "", "key"},
- expectedError: "# No username provided",
- },
- {
- desc: "With missing key for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user", "user", ""},
- expectedError: "# No key provided",
- },
- {
- desc: "With not enough arguments for the AuthorizedPrincipalsCheck",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"key"},
- expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]",
- },
- {
- desc: "With missing key_id for the AuthorizedPrincipalsCheck",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"", "principal"},
- expectedError: "# No key_id provided",
- },
- {
- desc: "With blank principal for the AuthorizedPrincipalsCheck",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"key", "principal", ""},
- expectedError: "# An invalid principal was provided",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- _, err := Parse(tc.executable, tc.arguments, tc.env)
-
- require.EqualError(t, err, tc.expectedError)
- })
- }
-}
diff --git a/internal/command/commandargs/shell.go b/internal/command/commandargs/shell.go
index 589f58d..7a76be5 100644
--- a/internal/command/commandargs/shell.go
+++ b/internal/command/commandargs/shell.go
@@ -1,7 +1,7 @@
package commandargs
import (
- "errors"
+ "fmt"
"regexp"
"github.com/mattn/go-shellwords"
@@ -49,21 +49,16 @@ func (s *Shell) GetArguments() []string {
func (s *Shell) validate() error {
if !s.Env.IsSSHConnection {
- return errors.New("Only SSH allowed")
+ return fmt.Errorf("Only SSH allowed")
}
- if !s.isValidSSHCommand() {
- return errors.New("Invalid SSH command")
+ if err := s.ParseCommand(s.Env.OriginalCommand); err != nil {
+ return fmt.Errorf("Invalid SSH command: %w", err)
}
return nil
}
-func (s *Shell) isValidSSHCommand() bool {
- err := s.ParseCommand(s.Env.OriginalCommand)
- return err == nil
-}
-
func (s *Shell) parseWho() {
for _, argument := range s.Arguments {
if keyId := tryParseKeyId(argument); keyId != "" {
diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go
index dab69ab..ac3aafc 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate.go
@@ -6,6 +6,8 @@ import (
"encoding/json"
"fmt"
+ "gitlab.com/gitlab-org/labkit/log"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -58,6 +60,11 @@ func (c *Command) Execute(ctx context.Context) error {
payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
+ log.WithContextFields(
+ ctx,
+ log.Fields{"operation": operation, "repo": repo, "user_id": accessResponse.UserId},
+ ).WithError(err).Debug("lfsauthenticate: execute: LFS authentication failed")
+
return nil
}
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go
index 34086fb..73d2ce4 100644
--- a/internal/command/shared/customaction/customaction.go
+++ b/internal/command/shared/customaction/customaction.go
@@ -64,7 +64,7 @@ func (c *Command) processApiEndpoints(ctx context.Context, response *accessverif
"endpoint": endpoint,
}
- log.WithFields(fields).Info("Performing custom action")
+ log.WithContextFields(ctx, fields).Info("Performing custom action")
response, err := c.performRequest(ctx, client, endpoint, request)
if err != nil {
diff --git a/internal/config/config.go b/internal/config/config.go
index c67f60a..5185736 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -46,6 +46,7 @@ type Config struct {
RootDir string
LogFile string `yaml:"log_file,omitempty"`
LogFormat string `yaml:"log_format,omitempty"`
+ LogLevel string `yaml:"log_level,omitempty"`
GitlabUrl string `yaml:"gitlab_url"`
GitlabRelativeURLRoot string `yaml:"gitlab_relative_url_root"`
GitlabTracing string `yaml:"gitlab_tracing"`
@@ -66,6 +67,7 @@ var (
DefaultConfig = Config{
LogFile: "gitlab-shell.log",
LogFormat: "json",
+ LogLevel: "info",
Server: DefaultServerConfig,
User: "git",
}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index 699a261..78b2ed4 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -32,7 +32,7 @@ func TestHttpClient(t *testing.T) {
client, err := config.HttpClient()
require.NoError(t, err)
- _, err = client.Get("http://host.com/path")
+ _, err = client.Get(url)
require.NoError(t, err)
ms, err := prometheus.DefaultGatherer.Gather()
diff --git a/internal/executable/executable.go b/internal/executable/executable.go
index 8b6b586..c6355b9 100644
--- a/internal/executable/executable.go
+++ b/internal/executable/executable.go
@@ -14,9 +14,8 @@ const (
)
type Executable struct {
- Name string
- RootDir string
- AcceptArgs bool
+ Name string
+ RootDir string
}
var (
@@ -24,7 +23,7 @@ var (
osExecutable = os.Executable
)
-func New(name string, acceptArgs bool) (*Executable, error) {
+func New(name string) (*Executable, error) {
path, err := osExecutable()
if err != nil {
return nil, err
@@ -36,9 +35,8 @@ func New(name string, acceptArgs bool) (*Executable, error) {
}
executable := &Executable{
- Name: name,
- RootDir: rootDir,
- AcceptArgs: acceptArgs,
+ Name: name,
+ RootDir: rootDir,
}
return executable, nil
diff --git a/internal/executable/executable_test.go b/internal/executable/executable_test.go
index 71984c3..3915f1a 100644
--- a/internal/executable/executable_test.go
+++ b/internal/executable/executable_test.go
@@ -59,7 +59,7 @@ func TestNewSuccess(t *testing.T) {
fake.Setup()
defer fake.Cleanup()
- result, err := New("gitlab-shell", true)
+ result, err := New("gitlab-shell")
require.NoError(t, err)
require.Equal(t, result.Name, "gitlab-shell")
@@ -96,7 +96,7 @@ func TestNewFailure(t *testing.T) {
fake.Setup()
defer fake.Cleanup()
- _, err := New("gitlab-shell", true)
+ _, err := New("gitlab-shell")
require.Error(t, err)
})
diff --git a/internal/handler/exec.go b/internal/handler/exec.go
index 27031b1..172736d 100644
--- a/internal/handler/exec.go
+++ b/internal/handler/exec.go
@@ -9,7 +9,9 @@ import (
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"google.golang.org/grpc"
+ grpccodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ grpcstatus "google.golang.org/grpc/status"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
@@ -50,6 +52,12 @@ func (gc *GitalyCommand) RunGitalyCommand(ctx context.Context, handler GitalyHan
childCtx := withOutgoingMetadata(ctx, gc.Features)
_, err = handler(childCtx, conn)
+ if err != nil && grpcstatus.Convert(err).Code() == grpccodes.Unavailable {
+ log.WithError(err).Error("Gitaly is unavailable")
+
+ return fmt.Errorf("Git service is temporarily unavailable")
+ }
+
return err
}
@@ -110,7 +118,7 @@ func getConn(ctx context.Context, gc *GitalyCommand) (*grpc.ClientConn, error) {
if serviceName == "" {
serviceName = "gitlab-shell-unknown"
- log.WithFields(log.Fields{"service_name": serviceName}).Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown")
+ log.WithContextFields(ctx, log.Fields{"service_name": serviceName}).Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown")
}
serviceName = fmt.Sprintf("%s-%s", serviceName, gc.ServiceName)
diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go
index 5ad0675..1d714ef 100644
--- a/internal/handler/exec_test.go
+++ b/internal/handler/exec_test.go
@@ -7,7 +7,9 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
+ grpccodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ grpcstatus "google.golang.org/grpc/status"
pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
@@ -45,6 +47,18 @@ func TestMissingGitalyAddress(t *testing.T) {
require.EqualError(t, err, "no gitaly_address given")
}
+func TestUnavailableGitalyErr(t *testing.T) {
+ cmd := GitalyCommand{
+ Config: &config.Config{},
+ Address: "tcp://localhost:9999",
+ }
+
+ expectedErr := grpcstatus.Error(grpccodes.Unavailable, "error")
+ err := cmd.RunGitalyCommand(context.Background(), makeHandler(t, expectedErr))
+
+ require.EqualError(t, err, "Git service is temporarily unavailable")
+}
+
func TestRunGitalyCommandMetadata(t *testing.T) {
tests := []struct {
name string
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
index 1165680..748fce0 100644
--- a/internal/logger/logger.go
+++ b/internal/logger/logger.go
@@ -35,6 +35,7 @@ func buildOpts(cfg *config.Config) []log.LoggerOption {
log.WithFormatter(logFmt(cfg.LogFormat)),
log.WithOutputName(logFile(cfg.LogFile)),
log.WithTimezone(time.UTC),
+ log.WithLogLevel(cfg.LogLevel),
}
}
diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go
index bda36d9..4ea8c1f 100644
--- a/internal/logger/logger_test.go
+++ b/internal/logger/logger_test.go
@@ -3,7 +3,6 @@ package logger
import (
"os"
"regexp"
- "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -26,12 +25,37 @@ func TestConfigure(t *testing.T) {
defer closer.Close()
log.Info("this is a test")
+ log.WithFields(log.Fields{}).Debug("debug log message")
tmpFile.Close()
data, err := os.ReadFile(tmpFile.Name())
require.NoError(t, err)
- require.True(t, strings.Contains(string(data), `msg":"this is a test"`))
+ require.Contains(t, string(data), `msg":"this is a test"`)
+ require.NotContains(t, string(data), `msg:":"debug log message"`)
+}
+
+func TestConfigureWithDebugLogLevel(t *testing.T) {
+ tmpFile, err := os.CreateTemp(os.TempDir(), "logtest-")
+ require.NoError(t, err)
+ defer tmpFile.Close()
+
+ config := config.Config{
+ LogFile: tmpFile.Name(),
+ LogFormat: "json",
+ LogLevel: "debug",
+ }
+
+ closer := Configure(&config)
+ defer closer.Close()
+
+ log.WithFields(log.Fields{}).Debug("debug log message")
+
+ tmpFile.Close()
+
+ data, err := os.ReadFile(tmpFile.Name())
+ require.NoError(t, err)
+ require.Contains(t, string(data), `msg":"debug log message"`)
}
func TestConfigureWithPermissionError(t *testing.T) {
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 0e0da93..f6d8fb5 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -29,21 +29,26 @@ func newConnection(maxSessions int64, remoteAddr string) *connection {
}
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+
defer metrics.SshdConnectionDuration.Observe(time.Since(c.begin).Seconds())
for newChannel := range chans {
+ ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
+ ctxlog.Info("connection: handle: unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !c.concurrentSessions.TryAcquire(1) {
+ ctxlog.Info("connection: handle: too many concurrent sessions")
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
metrics.SshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
- log.WithError(err).Info("could not accept channel")
+ ctxlog.WithError(err).Error("connection: handle: accepting channel failed")
c.concurrentSessions.Release(1)
continue
}
@@ -54,11 +59,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
// Prevent a panic in a single session from taking out the whole server
defer func() {
if err := recover(); err != nil {
- log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", c.remoteAddr)
+ ctxlog.WithField("recovered_error", err).Warn("panic handling session")
}
}()
handler(ctx, channel, requests)
+ ctxlog.Info("connection: handle: done")
}()
}
}
diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go
new file mode 100644
index 0000000..68210f8
--- /dev/null
+++ b/internal/sshd/server_config.go
@@ -0,0 +1,94 @@
+package sshd
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "os"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+type serverConfig struct {
+ cfg *config.Config
+ hostKeys []ssh.Signer
+ authorizedKeysClient *authorizedkeys.Client
+}
+
+func newServerConfig(cfg *config.Config) (*serverConfig, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
+ var hostKeys []ssh.Signer
+ for _, filename := range cfg.Server.HostKeyFiles {
+ keyRaw, err := os.ReadFile(filename)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to read host key")
+ continue
+ }
+ key, err := ssh.ParsePrivateKey(keyRaw)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to parse host key")
+ continue
+ }
+
+ hostKeys = append(hostKeys, key)
+ }
+ if len(hostKeys) == 0 {
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
+ }
+
+ return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+}
+
+func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) {
+ if user != s.cfg.User {
+ return nil, fmt.Errorf("unknown user")
+ }
+ if key.Type() == ssh.KeyAlgoDSA {
+ return nil, fmt.Errorf("DSA is prohibited")
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig {
+ sshCfg := &ssh.ServerConfig{
+ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+ res, err := s.getAuthKey(ctx, conn.User(), key)
+ if err != nil {
+ return nil, err
+ }
+
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "key-id": strconv.FormatInt(res.Id, 10),
+ },
+ }, nil
+ },
+ }
+
+ for _, key := range s.hostKeys {
+ sshCfg.AddHostKey(key)
+ }
+
+ return sshCfg
+}
diff --git a/internal/sshd/server_config_test.go b/internal/sshd/server_config_test.go
new file mode 100644
index 0000000..58bd3e1
--- /dev/null
+++ b/internal/sshd/server_config_test.go
@@ -0,0 +1,105 @@
+package sshd
+
+import (
+ "context"
+ "crypto/dsa"
+ "crypto/rand"
+ "crypto/rsa"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestNewServerConfigWithoutHosts(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "http://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "No host keys could be loaded, aborting", err.Error())
+}
+
+func TestFailedAuthorizedKeysClient(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "ftp://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "failed to initialize GitLab client: Error creating http client: unknown GitLab URL prefix", err.Error())
+}
+
+func TestFailedGetAuthKey(t *testing.T) {
+ testhelper.PrepareTestRootDir(t)
+
+ srvCfg := config.ServerConfig{
+ Listen: "127.0.0.1",
+ ConcurrentSessionsLimit: 1,
+ HostKeyFiles: []string{
+ path.Join(testhelper.TestRoot, "certs/valid/server.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid-path.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
+ },
+ }
+
+ cfg, err := newServerConfig(
+ &config.Config{GitlabUrl: "http://localhost", User: "user", Server: srvCfg},
+ )
+ require.NoError(t, err)
+
+ testCases := []struct {
+ desc string
+ user string
+ key ssh.PublicKey
+ expectedError string
+ }{
+ {
+ desc: "wrong user",
+ user: "wrong-user",
+ key: rsaPublicKey(t),
+ expectedError: "unknown user",
+ }, {
+ desc: "prohibited dsa key",
+ user: "user",
+ key: dsaPublicKey(t),
+ expectedError: "DSA is prohibited",
+ }, {
+ desc: "API error",
+ user: "user",
+ key: rsaPublicKey(t),
+ expectedError: "Internal API unreachable",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err = cfg.getAuthKey(context.Background(), tc.user, tc.key)
+ require.Error(t, err)
+ require.Equal(t, tc.expectedError, err.Error())
+ })
+ }
+}
+
+func rsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
+
+func dsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey := new(dsa.PrivateKey)
+ params := new(dsa.Parameters)
+ require.NoError(t, dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160))
+
+ privateKey.PublicKey.Parameters = *params
+ require.NoError(t, dsa.GenerateKey(privateKey, rand.Reader))
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index 22cb715..b8e8625 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -2,13 +2,16 @@ package sshd
import (
"context"
+ "errors"
"fmt"
+ "reflect"
+ "gitlab.com/gitlab-org/labkit/log"
"golang.org/x/crypto/ssh"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ shellCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
)
@@ -39,16 +42,27 @@ type exitStatusReq struct {
}
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
+ ctxlog := log.ContextLogger(ctx)
+
+ ctxlog.Debug("session: handle: entering request loop")
+
for req := range requests {
+ sessionLog := ctxlog.WithFields(log.Fields{
+ "bytesize": len(req.Payload),
+ "type": req.Type,
+ "want_reply": req.WantReply,
+ })
+ sessionLog.Debug("session: handle: request received")
+
var shouldContinue bool
switch req.Type {
case "env":
- shouldContinue = s.handleEnv(req)
+ shouldContinue = s.handleEnv(ctx, req)
case "exec":
shouldContinue = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
- s.exit(s.handleShell(ctx, req))
+ s.exit(ctx, s.handleShell(ctx, req))
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
@@ -57,18 +71,23 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
}
+ sessionLog.WithField("should_continue", shouldContinue).Debug("session: handle: request processed")
+
if !shouldContinue {
s.channel.Close()
break
}
}
+
+ ctxlog.Debug("session: handle: exiting request loop")
}
-func (s *session) handleEnv(req *ssh.Request) bool {
+func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
+ log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
return false
}
@@ -84,6 +103,10 @@ func (s *session) handleEnv(req *ssh.Request) bool {
req.Reply(accepted, []byte{})
}
+ log.WithContextFields(
+ ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
+ ).Debug("session: handleEnv: processed")
+
return true
}
@@ -95,7 +118,8 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
s.execCmd = execRequest.Command
- s.exit(s.handleShell(ctx, req))
+ s.exit(ctx, s.handleShell(ctx, req))
+
return false
}
@@ -104,19 +128,11 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
req.Reply(true, []byte{})
}
- args := &commandargs.Shell{
- GitlabKeyId: s.gitlabKeyId,
- Env: sshenv.Env{
- IsSSHConnection: true,
- OriginalCommand: s.execCmd,
- GitProtocolVersion: s.gitProtocolVersion,
- RemoteAddr: s.remoteAddr,
- },
- }
-
- if err := args.ParseCommand(s.execCmd); err != nil {
- s.toStderr("Failed to parse command: %v\n", err.Error())
- return 128
+ env := sshenv.Env{
+ IsSSHConnection: true,
+ OriginalCommand: s.execCmd,
+ GitProtocolVersion: s.gitProtocolVersion,
+ RemoteAddr: s.remoteAddr,
}
rw := &readwriter.ReadWriter{
@@ -125,25 +141,37 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
ErrOut: s.channel.Stderr(),
}
- cmd := command.BuildShellCommand(args, s.cfg, rw)
- if cmd == nil {
- s.toStderr("Unknown command: %v\n", args.CommandType)
+ cmd, err := shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
+ if err != nil {
+ if !errors.Is(err, disallowedcommand.Error) {
+ s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
+ }
+ s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
return 128
}
+ cmdName := reflect.TypeOf(cmd).String()
+ ctxlog := log.ContextLogger(ctx)
+ ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("session: handleShell: executing command")
+
if err := cmd.Execute(ctx); err != nil {
- s.toStderr("remote: ERROR: %v\n", err.Error())
+ s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
return 1
}
+ ctxlog.Info("session: handleShell: command executed successfully")
+
return 0
}
-func (s *session) toStderr(format string, args ...interface{}) {
- fmt.Fprintf(s.channel.Stderr(), format, args...)
+func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
+ out := fmt.Sprintf(format, args...)
+ log.WithContextFields(ctx, log.Fields{"stderr": out}).Debug("session: toStderr: output")
+ fmt.Fprint(s.channel.Stderr(), out)
}
-func (s *session) exit(status uint32) {
+func (s *session) exit(ctx context.Context, status uint32) {
+ log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
s.channel.CloseWrite()
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go
new file mode 100644
index 0000000..f135825
--- /dev/null
+++ b/internal/sshd/session_test.go
@@ -0,0 +1,189 @@
+package sshd
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+type fakeChannel struct {
+ stdErr io.ReadWriter
+ sentRequestName string
+ sentRequestPayload []byte
+}
+
+func (f *fakeChannel) Read(data []byte) (int, error) {
+ return 0, nil
+}
+
+func (f *fakeChannel) Write(data []byte) (int, error) {
+ return 0, nil
+}
+
+func (f *fakeChannel) Close() error {
+ return nil
+}
+
+func (f *fakeChannel) CloseWrite() error {
+ return nil
+}
+
+func (f *fakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+ f.sentRequestName = name
+ f.sentRequestPayload = payload
+
+ return true, nil
+}
+
+func (f *fakeChannel) Stderr() io.ReadWriter {
+ return f.stdErr
+}
+
+var requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(`{"id": 1000, "name": "Test User", "username": "test-user"}`))
+ },
+ },
+}
+
+func TestHandleEnv(t *testing.T) {
+ testCases := []struct {
+ desc string
+ payload []byte
+ expectedProtocolVersion string
+ expectedResult bool
+ }{
+ {
+ desc: "invalid payload",
+ payload: []byte("invalid"),
+ expectedProtocolVersion: "1",
+ expectedResult: false,
+ }, {
+ desc: "valid payload",
+ payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
+ expectedProtocolVersion: "2",
+ expectedResult: true,
+ }, {
+ desc: "valid payload with forbidden env var",
+ payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
+ expectedProtocolVersion: "1",
+ expectedResult: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ s := &session{gitProtocolVersion: "1"}
+ r := &ssh.Request{Payload: tc.payload}
+
+ require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
+ require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
+ })
+ }
+}
+
+func TestHandleExec(t *testing.T) {
+ testCases := []struct {
+ desc string
+ payload []byte
+ expectedExecCmd string
+ sentRequestName string
+ sentRequestPayload []byte
+ }{
+ {
+ desc: "invalid payload",
+ payload: []byte("invalid"),
+ expectedExecCmd: "",
+ sentRequestName: "",
+ }, {
+ desc: "valid payload",
+ payload: ssh.Marshal(execRequest{Command: "discover"}),
+ expectedExecCmd: "discover",
+ sentRequestName: "exit-status",
+ sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
+ },
+ }
+
+ url := testserver.StartHttpServer(t, requests)
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ out := &bytes.Buffer{}
+ f := &fakeChannel{stdErr: out}
+ s := &session{
+ gitlabKeyId: "root",
+ channel: f,
+ cfg: &config.Config{GitlabUrl: url},
+ }
+ r := &ssh.Request{Payload: tc.payload}
+
+ require.Equal(t, false, s.handleExec(context.Background(), r))
+ require.Equal(t, tc.sentRequestName, f.sentRequestName)
+ require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
+ })
+ }
+}
+
+func TestHandleShell(t *testing.T) {
+ testCases := []struct {
+ desc string
+ cmd string
+ errMsg string
+ gitlabKeyId string
+ expectedExitCode uint32
+ }{
+ {
+ desc: "fails to parse command",
+ cmd: `\`,
+ errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
+ gitlabKeyId: "root",
+ expectedExitCode: 128,
+ }, {
+ desc: "specified command is unknown",
+ cmd: "unknown-command",
+ errMsg: "Unknown command: unknown-command\n",
+ gitlabKeyId: "root",
+ expectedExitCode: 128,
+ }, {
+ desc: "fails to parse command",
+ cmd: "discover",
+ gitlabKeyId: "",
+ errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
+ expectedExitCode: 1,
+ }, {
+ desc: "fails to parse command",
+ cmd: "discover",
+ errMsg: "",
+ gitlabKeyId: "root",
+ expectedExitCode: 0,
+ },
+ }
+
+ url := testserver.StartHttpServer(t, requests)
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ out := &bytes.Buffer{}
+ s := &session{
+ gitlabKeyId: tc.gitlabKeyId,
+ execCmd: tc.cmd,
+ channel: &fakeChannel{stdErr: out},
+ cfg: &config.Config{GitlabUrl: url},
+ }
+ r := &ssh.Request{}
+
+ require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
+ require.Equal(t, tc.errMsg, out.String())
+ })
+ }
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index de5fbd4..19fa661 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -2,13 +2,9 @@ package sshd
import (
"context"
- "encoding/base64"
- "errors"
"fmt"
"net"
"net/http"
- "os"
- "strconv"
"sync"
"time"
@@ -16,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
@@ -35,44 +30,24 @@ const (
type Server struct {
Config *config.Config
- status status
- statusMu sync.Mutex
- wg sync.WaitGroup
- listener net.Listener
- hostKeys []ssh.Signer
- authorizedKeysClient *authorizedkeys.Client
+ status status
+ statusMu sync.Mutex
+ wg sync.WaitGroup
+ listener net.Listener
+ serverConfig *serverConfig
}
func NewServer(cfg *config.Config) (*Server, error) {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ serverConfig, err := newServerConfig(cfg)
if err != nil {
- return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ return nil, err
}
- var hostKeys []ssh.Signer
- for _, filename := range cfg.Server.HostKeyFiles {
- keyRaw, err := os.ReadFile(filename)
- if err != nil {
- log.WithError(err).Warnf("Failed to read host key %v", filename)
- continue
- }
- key, err := ssh.ParsePrivateKey(keyRaw)
- if err != nil {
- log.WithError(err).Warnf("Failed to parse host key %v", filename)
- continue
- }
-
- hostKeys = append(hostKeys, key)
- }
- if len(hostKeys) == 0 {
- return nil, fmt.Errorf("No host keys could be loaded, aborting")
- }
-
- return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+ return &Server{Config: cfg, serverConfig: serverConfig}, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
- if err := s.listen(); err != nil {
+ if err := s.listen(ctx); err != nil {
return err
}
defer s.listener.Close()
@@ -110,7 +85,7 @@ func (s *Server) MonitoringServeMux() *http.ServeMux {
return mux
}
-func (s *Server) listen() error {
+func (s *Server) listen(ctx context.Context) error {
sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
if err != nil {
return fmt.Errorf("failed to listen for connection: %w", err)
@@ -122,10 +97,10 @@ func (s *Server) listen() error {
ReadHeaderTimeout: ProxyHeaderTimeout,
}
- log.Info("Proxy protocol is enabled")
+ log.ContextLogger(ctx).Info("Proxy protocol is enabled")
}
- log.WithFields(log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections")
+ log.WithContextFields(ctx, log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections")
s.listener = sshListener
@@ -142,7 +117,7 @@ func (s *Server) serve(ctx context.Context) {
break
}
- log.WithError(err).Warn("Failed to accept connection")
+ log.ContextLogger(ctx).WithError(err).Warn("Failed to accept connection")
continue
}
@@ -168,57 +143,29 @@ func (s *Server) getStatus() status {
return s.status
}
-func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig {
- sshCfg := &ssh.ServerConfig{
- PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != s.Config.User {
- return nil, errors.New("unknown user")
- }
- if key.Type() == ssh.KeyAlgoDSA {
- return nil, errors.New("DSA is prohibited")
- }
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
- res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
- if err != nil {
- return nil, err
- }
-
- return &ssh.Permissions{
- // Record the public key used for authentication.
- Extensions: map[string]string{
- "key-id": strconv.FormatInt(res.Id, 10),
- },
- }, nil
- },
- }
-
- for _, key := range s.hostKeys {
- sshCfg.AddHostKey(key)
- }
-
- return sshCfg
-}
-
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
defer s.wg.Done()
defer nconn.Close()
+ ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
+ defer cancel()
+
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
+
// Prevent a panic in a single connection from taking out the whole server
defer func() {
if err := recover(); err != nil {
- log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", remoteAddr)
+ ctxlog.Warn("panic handling session")
}
}()
- ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
- defer cancel()
+ ctxlog.Info("server: handleConn: start")
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx))
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
if err != nil {
- log.WithError(err).Info("Failed to initialize SSH connection")
+ ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection")
return
}
@@ -235,4 +182,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
session.handle(ctx, requests)
})
+
+ ctxlog.Info("server: handleConn: done")
}
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index 32946af..71f7733 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -104,6 +104,24 @@ func TestLivenessProbe(t *testing.T) {
require.Equal(t, 200, r.Result().StatusCode)
}
+func TestInvalidClientConfig(t *testing.T) {
+ setupServer(t)
+
+ cfg := clientConfig(t)
+ cfg.User = "unknown"
+ _, err := ssh.Dial("tcp", serverUrl, cfg)
+ require.Error(t, err)
+}
+
+func TestInvalidServerConfig(t *testing.T) {
+ s := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}}
+ err := s.ListenAndServe(context.Background())
+
+ require.Error(t, err)
+ require.Equal(t, "failed to listen for connection: listen tcp: address invalid: missing port in address", err.Error())
+ require.Nil(t, s.Shutdown())
+}
+
func setupServer(t *testing.T) *Server {
t.Helper()