summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/command/command.go83
-rw-r--r--internal/command/command_test.go158
-rw-r--r--internal/command/commandargs/command_args.go29
-rw-r--r--internal/command/commandargs/command_args_test.go197
-rw-r--r--internal/command/commandargs/shell.go13
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate.go7
-rw-r--r--internal/command/shared/customaction/customaction.go2
-rw-r--r--internal/config/config.go2
-rw-r--r--internal/config/config_test.go2
-rw-r--r--internal/executable/executable.go12
-rw-r--r--internal/executable/executable_test.go4
-rw-r--r--internal/handler/exec.go10
-rw-r--r--internal/handler/exec_test.go14
-rw-r--r--internal/logger/logger.go1
-rw-r--r--internal/logger/logger_test.go28
-rw-r--r--internal/sshd/connection.go10
-rw-r--r--internal/sshd/server_config.go94
-rw-r--r--internal/sshd/server_config_test.go105
-rw-r--r--internal/sshd/session.go80
-rw-r--r--internal/sshd/session_test.go189
-rw-r--r--internal/sshd/sshd.go99
-rw-r--r--internal/sshd/sshd_test.go18
22 files changed, 564 insertions, 593 deletions
diff --git a/internal/command/command.go b/internal/command/command.go
index dadf41a..4ee568e 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -3,23 +3,7 @@ package command
import (
"context"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/tracing"
)
@@ -28,23 +12,6 @@ type Command interface {
Execute(ctx context.Context) error
}
-func New(e *executable.Executable, arguments []string, env sshenv.Env, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
- var args commandargs.CommandArgs
- if e.AcceptArgs {
- var err error
- args, err = commandargs.Parse(e, arguments, env)
- if err != nil {
- return nil, err
- }
- }
-
- if cmd := buildCommand(e, args, config, readWriter); cmd != nil {
- return cmd, nil
- }
-
- return nil, disallowedcommand.Error
-}
-
// Setup() initializes tracing from the configuration file and generates a
// background context from which all other contexts in the process should derive
// from, as it has a service name and initial correlation ID set.
@@ -80,53 +47,3 @@ func Setup(serviceName string, config *config.Config) (context.Context, func())
closer.Close()
}
}
-
-func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- switch e.Name {
- case executable.GitlabShell:
- return BuildShellCommand(args.(*commandargs.Shell), config, readWriter)
- case executable.AuthorizedKeysCheck:
- return buildAuthorizedKeysCommand(args.(*commandargs.AuthorizedKeys), config, readWriter)
- case executable.AuthorizedPrincipalsCheck:
- return buildAuthorizedPrincipalsCommand(args.(*commandargs.AuthorizedPrincipals), config, readWriter)
- case executable.Healthcheck:
- return buildHealthcheckCommand(config, readWriter)
- }
-
- return nil
-}
-
-func BuildShellCommand(args *commandargs.Shell, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- switch args.CommandType {
- case commandargs.Discover:
- return &discover.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.TwoFactorRecover:
- return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.TwoFactorVerify:
- return &twofactorverify.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.LfsAuthenticate:
- return &lfsauthenticate.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.ReceivePack:
- return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.UploadPack:
- return &uploadpack.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.UploadArchive:
- return &uploadarchive.Command{Config: config, Args: args, ReadWriter: readWriter}
- case commandargs.PersonalAccessToken:
- return &personalaccesstoken.Command{Config: config, Args: args, ReadWriter: readWriter}
- }
-
- return nil
-}
-
-func buildAuthorizedKeysCommand(args *commandargs.AuthorizedKeys, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- return &authorizedkeys.Command{Config: config, Args: args, ReadWriter: readWriter}
-}
-
-func buildAuthorizedPrincipalsCommand(args *commandargs.AuthorizedPrincipals, config *config.Config, readWriter *readwriter.ReadWriter) Command {
- return &authorizedprincipals.Command{Config: config, Args: args, ReadWriter: readWriter}
-}
-
-func buildHealthcheckCommand(config *config.Config, readWriter *readwriter.ReadWriter) Command {
- return &healthcheck.Command{Config: config, ReadWriter: readWriter}
-}
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index a538745..2fc6655 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -1,173 +1,15 @@
package command
import (
- "errors"
"os"
"testing"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/discover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/healthcheck"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/lfsauthenticate"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/personalaccesstoken"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/receivepack"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorrecover"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/twofactorverify"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadarchive"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
"gitlab.com/gitlab-org/labkit/correlation"
)
-var (
- authorizedKeysExec = &executable.Executable{Name: executable.AuthorizedKeysCheck, AcceptArgs: true}
- authorizedPrincipalsExec = &executable.Executable{Name: executable.AuthorizedPrincipalsCheck, AcceptArgs: true}
- checkExec = &executable.Executable{Name: executable.Healthcheck, AcceptArgs: false}
- gitlabShellExec = &executable.Executable{Name: executable.GitlabShell, AcceptArgs: true}
-
- basicConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket"}
- advancedConfig = &config.Config{GitlabUrl: "http+unix://gitlab.socket", SslCertDir: "/tmp/certs"}
-)
-
-func buildEnv(command string) sshenv.Env {
- return sshenv.Env{
- IsSSHConnection: true,
- OriginalCommand: command,
- }
-}
-
-func TestNew(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- arguments []string
- config *config.Config
- expectedType interface{}
- }{
- {
- desc: "it returns a Discover command",
- executable: gitlabShellExec,
- env: buildEnv(""),
- config: basicConfig,
- expectedType: &discover.Command{},
- },
- {
- desc: "it returns a TwoFactorRecover command",
- executable: gitlabShellExec,
- env: buildEnv("2fa_recovery_codes"),
- config: basicConfig,
- expectedType: &twofactorrecover.Command{},
- },
- {
- desc: "it returns a TwoFactorVerify command",
- executable: gitlabShellExec,
- env: buildEnv("2fa_verify"),
- config: basicConfig,
- expectedType: &twofactorverify.Command{},
- },
- {
- desc: "it returns an LfsAuthenticate command",
- executable: gitlabShellExec,
- env: buildEnv("git-lfs-authenticate"),
- config: basicConfig,
- expectedType: &lfsauthenticate.Command{},
- },
- {
- desc: "it returns a ReceivePack command",
- executable: gitlabShellExec,
- env: buildEnv("git-receive-pack"),
- config: basicConfig,
- expectedType: &receivepack.Command{},
- },
- {
- desc: "it returns an UploadPack command",
- executable: gitlabShellExec,
- env: buildEnv("git-upload-pack"),
- config: basicConfig,
- expectedType: &uploadpack.Command{},
- },
- {
- desc: "it returns an UploadArchive command",
- executable: gitlabShellExec,
- env: buildEnv("git-upload-archive"),
- config: basicConfig,
- expectedType: &uploadarchive.Command{},
- },
- {
- desc: "it returns a Healthcheck command",
- executable: checkExec,
- config: basicConfig,
- expectedType: &healthcheck.Command{},
- },
- {
- desc: "it returns a AuthorizedKeys command",
- executable: authorizedKeysExec,
- arguments: []string{"git", "git", "key"},
- config: basicConfig,
- expectedType: &authorizedkeys.Command{},
- },
- {
- desc: "it returns a AuthorizedPrincipals command",
- executable: authorizedPrincipalsExec,
- arguments: []string{"key", "principal"},
- config: basicConfig,
- expectedType: &authorizedprincipals.Command{},
- },
- {
- desc: "it returns a PersonalAccessToken command",
- executable: gitlabShellExec,
- env: buildEnv("personal_access_token"),
- config: basicConfig,
- expectedType: &personalaccesstoken.Command{},
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- command, err := New(tc.executable, tc.arguments, tc.env, tc.config, nil)
-
- require.NoError(t, err)
- require.IsType(t, tc.expectedType, command)
- })
- }
-}
-
-func TestFailingNew(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- expectedError error
- }{
- {
- desc: "Parsing environment failed",
- executable: gitlabShellExec,
- expectedError: errors.New("Only SSH allowed"),
- },
- {
- desc: "Unknown command given",
- executable: gitlabShellExec,
- env: buildEnv("unknown"),
- expectedError: disallowedcommand.Error,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- command, err := New(tc.executable, []string{}, tc.env, basicConfig, nil)
- require.Nil(t, command)
- require.Equal(t, tc.expectedError, err)
- })
- }
-}
-
func TestSetup(t *testing.T) {
testCases := []struct {
name string
diff --git a/internal/command/commandargs/command_args.go b/internal/command/commandargs/command_args.go
index a01b8b2..f23ba18 100644
--- a/internal/command/commandargs/command_args.go
+++ b/internal/command/commandargs/command_args.go
@@ -1,37 +1,8 @@
package commandargs
-import (
- "errors"
- "fmt"
-
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
-)
-
type CommandType string
type CommandArgs interface {
Parse() error
GetArguments() []string
}
-
-func Parse(e *executable.Executable, arguments []string, env sshenv.Env) (CommandArgs, error) {
- var args CommandArgs
-
- switch e.Name {
- case executable.GitlabShell:
- args = &Shell{Arguments: arguments, Env: env}
- case executable.AuthorizedKeysCheck:
- args = &AuthorizedKeys{Arguments: arguments}
- case executable.AuthorizedPrincipalsCheck:
- args = &AuthorizedPrincipals{Arguments: arguments}
- default:
- return nil, errors.New(fmt.Sprintf("unknown executable: %s", e.Name))
- }
-
- if err := args.Parse(); err != nil {
- return nil, err
- }
-
- return args, nil
-}
diff --git a/internal/command/commandargs/command_args_test.go b/internal/command/commandargs/command_args_test.go
deleted file mode 100644
index 119ecd4..0000000
--- a/internal/command/commandargs/command_args_test.go
+++ /dev/null
@@ -1,197 +0,0 @@
-package commandargs
-
-import (
- "testing"
-
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
- "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
-
- "github.com/stretchr/testify/require"
-)
-
-func TestParseSuccess(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- arguments []string
- expectedArgs CommandArgs
- expectError bool
- }{
- {
- desc: "It sets discover as the command when the command string was empty",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{}, CommandType: Discover, Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It finds the key id in any passed arguments",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{"hello", "key-123"},
- expectedArgs: &Shell{Arguments: []string{"hello", "key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It finds the key id only if the argument is of <key-id> format",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{"hello", "username-key-123"},
- expectedArgs: &Shell{Arguments: []string{"hello", "username-key-123"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "key-123", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It finds the username in any passed arguments",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"},
- arguments: []string{"hello", "username-jane-doe"},
- expectedArgs: &Shell{Arguments: []string{"hello", "username-jane-doe"}, SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe", Env: sshenv.Env{IsSSHConnection: true, RemoteAddr: "1"}},
- }, {
- desc: "It parses 2fa_recovery_codes command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "2fa_recovery_codes"}},
- }, {
- desc: "It parses git-receive-pack command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack group/repo"}},
- }, {
- desc: "It parses git-receive-pack command and a project with single quotes",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-receive-pack 'group/repo'"}},
- }, {
- desc: `It parses "git receive-pack" command`,
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack "group/repo"`}},
- }, {
- desc: `It parses a command followed by control characters`,
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git-receive-pack group/repo; any command`}},
- }, {
- desc: "It parses git-upload-pack command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-pack", "group/repo"}, CommandType: UploadPack, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git upload-pack "group/repo"`}},
- }, {
- desc: "It parses git-upload-archive command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-upload-archive", "group/repo"}, CommandType: UploadArchive, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-upload-archive 'group/repo'"}},
- }, {
- desc: "It parses git-lfs-authenticate command",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"},
- arguments: []string{},
- expectedArgs: &Shell{Arguments: []string{}, SshArgs: []string{"git-lfs-authenticate", "group/repo", "download"}, CommandType: LfsAuthenticate, Env: sshenv.Env{IsSSHConnection: true, OriginalCommand: "git-lfs-authenticate 'group/repo' download"}},
- }, {
- desc: "It parses authorized-keys command",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"git", "git", "key"},
- expectedArgs: &AuthorizedKeys{Arguments: []string{"git", "git", "key"}, ExpectedUser: "git", ActualUser: "git", Key: "key"},
- }, {
- desc: "It parses authorized-principals command",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"key", "principal-1", "principal-2"},
- expectedArgs: &AuthorizedPrincipals{Arguments: []string{"key", "principal-1", "principal-2"}, KeyId: "key", Principals: []string{"principal-1", "principal-2"}},
- }, {
- desc: "Unknown executable",
- executable: &executable.Executable{Name: "unknown"},
- arguments: []string{},
- expectError: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- result, err := Parse(tc.executable, tc.arguments, tc.env)
-
- if !tc.expectError {
- require.NoError(t, err)
- require.Equal(t, tc.expectedArgs, result)
- } else {
- require.Error(t, err)
- }
- })
- }
-}
-
-func TestParseFailure(t *testing.T) {
- testCases := []struct {
- desc string
- executable *executable.Executable
- env sshenv.Env
- arguments []string
- expectedError string
- }{
- {
- desc: "It fails if SSH connection is not set",
- executable: &executable.Executable{Name: executable.GitlabShell},
- arguments: []string{},
- expectedError: "Only SSH allowed",
- },
- {
- desc: "It fails if SSH command is invalid",
- executable: &executable.Executable{Name: executable.GitlabShell},
- env: sshenv.Env{IsSSHConnection: true, OriginalCommand: `git receive-pack "`},
- arguments: []string{},
- expectedError: "Invalid SSH command",
- },
- {
- desc: "With not enough arguments for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user"},
- expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
- },
- {
- desc: "With too many arguments for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user", "user", "key", "something-else"},
- expectedError: "# Insufficient arguments. 4. Usage\n#\tgitlab-shell-authorized-keys-check <expected-username> <actual-username> <key>",
- },
- {
- desc: "With missing username for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user", "", "key"},
- expectedError: "# No username provided",
- },
- {
- desc: "With missing key for the AuthorizedKeysCheck",
- executable: &executable.Executable{Name: executable.AuthorizedKeysCheck},
- arguments: []string{"user", "user", ""},
- expectedError: "# No key provided",
- },
- {
- desc: "With not enough arguments for the AuthorizedPrincipalsCheck",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"key"},
- expectedError: "# Insufficient arguments. 1. Usage\n#\tgitlab-shell-authorized-principals-check <key-id> <principal1> [<principal2>...]",
- },
- {
- desc: "With missing key_id for the AuthorizedPrincipalsCheck",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"", "principal"},
- expectedError: "# No key_id provided",
- },
- {
- desc: "With blank principal for the AuthorizedPrincipalsCheck",
- executable: &executable.Executable{Name: executable.AuthorizedPrincipalsCheck},
- arguments: []string{"key", "principal", ""},
- expectedError: "# An invalid principal was provided",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- _, err := Parse(tc.executable, tc.arguments, tc.env)
-
- require.EqualError(t, err, tc.expectedError)
- })
- }
-}
diff --git a/internal/command/commandargs/shell.go b/internal/command/commandargs/shell.go
index 589f58d..7a76be5 100644
--- a/internal/command/commandargs/shell.go
+++ b/internal/command/commandargs/shell.go
@@ -1,7 +1,7 @@
package commandargs
import (
- "errors"
+ "fmt"
"regexp"
"github.com/mattn/go-shellwords"
@@ -49,21 +49,16 @@ func (s *Shell) GetArguments() []string {
func (s *Shell) validate() error {
if !s.Env.IsSSHConnection {
- return errors.New("Only SSH allowed")
+ return fmt.Errorf("Only SSH allowed")
}
- if !s.isValidSSHCommand() {
- return errors.New("Invalid SSH command")
+ if err := s.ParseCommand(s.Env.OriginalCommand); err != nil {
+ return fmt.Errorf("Invalid SSH command: %w", err)
}
return nil
}
-func (s *Shell) isValidSSHCommand() bool {
- err := s.ParseCommand(s.Env.OriginalCommand)
- return err == nil
-}
-
func (s *Shell) parseWho() {
for _, argument := range s.Arguments {
if keyId := tryParseKeyId(argument); keyId != "" {
diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go
index dab69ab..ac3aafc 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate.go
@@ -6,6 +6,8 @@ import (
"encoding/json"
"fmt"
+ "gitlab.com/gitlab-org/labkit/log"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -58,6 +60,11 @@ func (c *Command) Execute(ctx context.Context) error {
payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
+ log.WithContextFields(
+ ctx,
+ log.Fields{"operation": operation, "repo": repo, "user_id": accessResponse.UserId},
+ ).WithError(err).Debug("lfsauthenticate: execute: LFS authentication failed")
+
return nil
}
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go
index 34086fb..73d2ce4 100644
--- a/internal/command/shared/customaction/customaction.go
+++ b/internal/command/shared/customaction/customaction.go
@@ -64,7 +64,7 @@ func (c *Command) processApiEndpoints(ctx context.Context, response *accessverif
"endpoint": endpoint,
}
- log.WithFields(fields).Info("Performing custom action")
+ log.WithContextFields(ctx, fields).Info("Performing custom action")
response, err := c.performRequest(ctx, client, endpoint, request)
if err != nil {
diff --git a/internal/config/config.go b/internal/config/config.go
index c67f60a..5185736 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -46,6 +46,7 @@ type Config struct {
RootDir string
LogFile string `yaml:"log_file,omitempty"`
LogFormat string `yaml:"log_format,omitempty"`
+ LogLevel string `yaml:"log_level,omitempty"`
GitlabUrl string `yaml:"gitlab_url"`
GitlabRelativeURLRoot string `yaml:"gitlab_relative_url_root"`
GitlabTracing string `yaml:"gitlab_tracing"`
@@ -66,6 +67,7 @@ var (
DefaultConfig = Config{
LogFile: "gitlab-shell.log",
LogFormat: "json",
+ LogLevel: "info",
Server: DefaultServerConfig,
User: "git",
}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index 699a261..78b2ed4 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -32,7 +32,7 @@ func TestHttpClient(t *testing.T) {
client, err := config.HttpClient()
require.NoError(t, err)
- _, err = client.Get("http://host.com/path")
+ _, err = client.Get(url)
require.NoError(t, err)
ms, err := prometheus.DefaultGatherer.Gather()
diff --git a/internal/executable/executable.go b/internal/executable/executable.go
index 8b6b586..c6355b9 100644
--- a/internal/executable/executable.go
+++ b/internal/executable/executable.go
@@ -14,9 +14,8 @@ const (
)
type Executable struct {
- Name string
- RootDir string
- AcceptArgs bool
+ Name string
+ RootDir string
}
var (
@@ -24,7 +23,7 @@ var (
osExecutable = os.Executable
)
-func New(name string, acceptArgs bool) (*Executable, error) {
+func New(name string) (*Executable, error) {
path, err := osExecutable()
if err != nil {
return nil, err
@@ -36,9 +35,8 @@ func New(name string, acceptArgs bool) (*Executable, error) {
}
executable := &Executable{
- Name: name,
- RootDir: rootDir,
- AcceptArgs: acceptArgs,
+ Name: name,
+ RootDir: rootDir,
}
return executable, nil
diff --git a/internal/executable/executable_test.go b/internal/executable/executable_test.go
index 71984c3..3915f1a 100644
--- a/internal/executable/executable_test.go
+++ b/internal/executable/executable_test.go
@@ -59,7 +59,7 @@ func TestNewSuccess(t *testing.T) {
fake.Setup()
defer fake.Cleanup()
- result, err := New("gitlab-shell", true)
+ result, err := New("gitlab-shell")
require.NoError(t, err)
require.Equal(t, result.Name, "gitlab-shell")
@@ -96,7 +96,7 @@ func TestNewFailure(t *testing.T) {
fake.Setup()
defer fake.Cleanup()
- _, err := New("gitlab-shell", true)
+ _, err := New("gitlab-shell")
require.Error(t, err)
})
diff --git a/internal/handler/exec.go b/internal/handler/exec.go
index 27031b1..172736d 100644
--- a/internal/handler/exec.go
+++ b/internal/handler/exec.go
@@ -9,7 +9,9 @@ import (
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"google.golang.org/grpc"
+ grpccodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ grpcstatus "google.golang.org/grpc/status"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
@@ -50,6 +52,12 @@ func (gc *GitalyCommand) RunGitalyCommand(ctx context.Context, handler GitalyHan
childCtx := withOutgoingMetadata(ctx, gc.Features)
_, err = handler(childCtx, conn)
+ if err != nil && grpcstatus.Convert(err).Code() == grpccodes.Unavailable {
+ log.WithError(err).Error("Gitaly is unavailable")
+
+ return fmt.Errorf("Git service is temporarily unavailable")
+ }
+
return err
}
@@ -110,7 +118,7 @@ func getConn(ctx context.Context, gc *GitalyCommand) (*grpc.ClientConn, error) {
if serviceName == "" {
serviceName = "gitlab-shell-unknown"
- log.WithFields(log.Fields{"service_name": serviceName}).Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown")
+ log.WithContextFields(ctx, log.Fields{"service_name": serviceName}).Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown")
}
serviceName = fmt.Sprintf("%s-%s", serviceName, gc.ServiceName)
diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go
index 5ad0675..1d714ef 100644
--- a/internal/handler/exec_test.go
+++ b/internal/handler/exec_test.go
@@ -7,7 +7,9 @@ import (
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
+ grpccodes "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ grpcstatus "google.golang.org/grpc/status"
pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
@@ -45,6 +47,18 @@ func TestMissingGitalyAddress(t *testing.T) {
require.EqualError(t, err, "no gitaly_address given")
}
+func TestUnavailableGitalyErr(t *testing.T) {
+ cmd := GitalyCommand{
+ Config: &config.Config{},
+ Address: "tcp://localhost:9999",
+ }
+
+ expectedErr := grpcstatus.Error(grpccodes.Unavailable, "error")
+ err := cmd.RunGitalyCommand(context.Background(), makeHandler(t, expectedErr))
+
+ require.EqualError(t, err, "Git service is temporarily unavailable")
+}
+
func TestRunGitalyCommandMetadata(t *testing.T) {
tests := []struct {
name string
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
index 1165680..748fce0 100644
--- a/internal/logger/logger.go
+++ b/internal/logger/logger.go
@@ -35,6 +35,7 @@ func buildOpts(cfg *config.Config) []log.LoggerOption {
log.WithFormatter(logFmt(cfg.LogFormat)),
log.WithOutputName(logFile(cfg.LogFile)),
log.WithTimezone(time.UTC),
+ log.WithLogLevel(cfg.LogLevel),
}
}
diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go
index bda36d9..4ea8c1f 100644
--- a/internal/logger/logger_test.go
+++ b/internal/logger/logger_test.go
@@ -3,7 +3,6 @@ package logger
import (
"os"
"regexp"
- "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -26,12 +25,37 @@ func TestConfigure(t *testing.T) {
defer closer.Close()
log.Info("this is a test")
+ log.WithFields(log.Fields{}).Debug("debug log message")
tmpFile.Close()
data, err := os.ReadFile(tmpFile.Name())
require.NoError(t, err)
- require.True(t, strings.Contains(string(data), `msg":"this is a test"`))
+ require.Contains(t, string(data), `msg":"this is a test"`)
+ require.NotContains(t, string(data), `msg:":"debug log message"`)
+}
+
+func TestConfigureWithDebugLogLevel(t *testing.T) {
+ tmpFile, err := os.CreateTemp(os.TempDir(), "logtest-")
+ require.NoError(t, err)
+ defer tmpFile.Close()
+
+ config := config.Config{
+ LogFile: tmpFile.Name(),
+ LogFormat: "json",
+ LogLevel: "debug",
+ }
+
+ closer := Configure(&config)
+ defer closer.Close()
+
+ log.WithFields(log.Fields{}).Debug("debug log message")
+
+ tmpFile.Close()
+
+ data, err := os.ReadFile(tmpFile.Name())
+ require.NoError(t, err)
+ require.Contains(t, string(data), `msg":"debug log message"`)
}
func TestConfigureWithPermissionError(t *testing.T) {
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go
index 0e0da93..f6d8fb5 100644
--- a/internal/sshd/connection.go
+++ b/internal/sshd/connection.go
@@ -29,21 +29,26 @@ func newConnection(maxSessions int64, remoteAddr string) *connection {
}
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
+
defer metrics.SshdConnectionDuration.Observe(time.Since(c.begin).Seconds())
for newChannel := range chans {
+ ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
+ ctxlog.Info("connection: handle: unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !c.concurrentSessions.TryAcquire(1) {
+ ctxlog.Info("connection: handle: too many concurrent sessions")
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
metrics.SshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
- log.WithError(err).Info("could not accept channel")
+ ctxlog.WithError(err).Error("connection: handle: accepting channel failed")
c.concurrentSessions.Release(1)
continue
}
@@ -54,11 +59,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
// Prevent a panic in a single session from taking out the whole server
defer func() {
if err := recover(); err != nil {
- log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", c.remoteAddr)
+ ctxlog.WithField("recovered_error", err).Warn("panic handling session")
}
}()
handler(ctx, channel, requests)
+ ctxlog.Info("connection: handle: done")
}()
}
}
diff --git a/internal/sshd/server_config.go b/internal/sshd/server_config.go
new file mode 100644
index 0000000..68210f8
--- /dev/null
+++ b/internal/sshd/server_config.go
@@ -0,0 +1,94 @@
+package sshd
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "os"
+ "strconv"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+type serverConfig struct {
+ cfg *config.Config
+ hostKeys []ssh.Signer
+ authorizedKeysClient *authorizedkeys.Client
+}
+
+func newServerConfig(cfg *config.Config) (*serverConfig, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
+ var hostKeys []ssh.Signer
+ for _, filename := range cfg.Server.HostKeyFiles {
+ keyRaw, err := os.ReadFile(filename)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to read host key")
+ continue
+ }
+ key, err := ssh.ParsePrivateKey(keyRaw)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{"filename": filename}).Warn("Failed to parse host key")
+ continue
+ }
+
+ hostKeys = append(hostKeys, key)
+ }
+ if len(hostKeys) == 0 {
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
+ }
+
+ return &serverConfig{cfg: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+}
+
+func (s *serverConfig) getAuthKey(ctx context.Context, user string, key ssh.PublicKey) (*authorizedkeys.Response, error) {
+ if user != s.cfg.User {
+ return nil, fmt.Errorf("unknown user")
+ }
+ if key.Type() == ssh.KeyAlgoDSA {
+ return nil, fmt.Errorf("DSA is prohibited")
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func (s *serverConfig) get(ctx context.Context) *ssh.ServerConfig {
+ sshCfg := &ssh.ServerConfig{
+ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
+ res, err := s.getAuthKey(ctx, conn.User(), key)
+ if err != nil {
+ return nil, err
+ }
+
+ return &ssh.Permissions{
+ // Record the public key used for authentication.
+ Extensions: map[string]string{
+ "key-id": strconv.FormatInt(res.Id, 10),
+ },
+ }, nil
+ },
+ }
+
+ for _, key := range s.hostKeys {
+ sshCfg.AddHostKey(key)
+ }
+
+ return sshCfg
+}
diff --git a/internal/sshd/server_config_test.go b/internal/sshd/server_config_test.go
new file mode 100644
index 0000000..58bd3e1
--- /dev/null
+++ b/internal/sshd/server_config_test.go
@@ -0,0 +1,105 @@
+package sshd
+
+import (
+ "context"
+ "crypto/dsa"
+ "crypto/rand"
+ "crypto/rsa"
+ "path"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+func TestNewServerConfigWithoutHosts(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "http://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "No host keys could be loaded, aborting", err.Error())
+}
+
+func TestFailedAuthorizedKeysClient(t *testing.T) {
+ _, err := newServerConfig(&config.Config{GitlabUrl: "ftp://localhost"})
+
+ require.Error(t, err)
+ require.Equal(t, "failed to initialize GitLab client: Error creating http client: unknown GitLab URL prefix", err.Error())
+}
+
+func TestFailedGetAuthKey(t *testing.T) {
+ testhelper.PrepareTestRootDir(t)
+
+ srvCfg := config.ServerConfig{
+ Listen: "127.0.0.1",
+ ConcurrentSessionsLimit: 1,
+ HostKeyFiles: []string{
+ path.Join(testhelper.TestRoot, "certs/valid/server.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid-path.key"),
+ path.Join(testhelper.TestRoot, "certs/invalid/server.crt"),
+ },
+ }
+
+ cfg, err := newServerConfig(
+ &config.Config{GitlabUrl: "http://localhost", User: "user", Server: srvCfg},
+ )
+ require.NoError(t, err)
+
+ testCases := []struct {
+ desc string
+ user string
+ key ssh.PublicKey
+ expectedError string
+ }{
+ {
+ desc: "wrong user",
+ user: "wrong-user",
+ key: rsaPublicKey(t),
+ expectedError: "unknown user",
+ }, {
+ desc: "prohibited dsa key",
+ user: "user",
+ key: dsaPublicKey(t),
+ expectedError: "DSA is prohibited",
+ }, {
+ desc: "API error",
+ user: "user",
+ key: rsaPublicKey(t),
+ expectedError: "Internal API unreachable",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ _, err = cfg.getAuthKey(context.Background(), tc.user, tc.key)
+ require.Error(t, err)
+ require.Equal(t, tc.expectedError, err.Error())
+ })
+ }
+}
+
+func rsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
+
+func dsaPublicKey(t *testing.T) ssh.PublicKey {
+ privateKey := new(dsa.PrivateKey)
+ params := new(dsa.Parameters)
+ require.NoError(t, dsa.GenerateParameters(params, rand.Reader, dsa.L1024N160))
+
+ privateKey.PublicKey.Parameters = *params
+ require.NoError(t, dsa.GenerateKey(privateKey, rand.Reader))
+
+ publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ require.NoError(t, err)
+
+ return publicKey
+}
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index 22cb715..b8e8625 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -2,13 +2,16 @@ package sshd
import (
"context"
+ "errors"
"fmt"
+ "reflect"
+ "gitlab.com/gitlab-org/labkit/log"
"golang.org/x/crypto/ssh"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command"
- "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
+ shellCmd "gitlab.com/gitlab-org/gitlab-shell/cmd/gitlab-shell/command"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/disallowedcommand"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
)
@@ -39,16 +42,27 @@ type exitStatusReq struct {
}
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
+ ctxlog := log.ContextLogger(ctx)
+
+ ctxlog.Debug("session: handle: entering request loop")
+
for req := range requests {
+ sessionLog := ctxlog.WithFields(log.Fields{
+ "bytesize": len(req.Payload),
+ "type": req.Type,
+ "want_reply": req.WantReply,
+ })
+ sessionLog.Debug("session: handle: request received")
+
var shouldContinue bool
switch req.Type {
case "env":
- shouldContinue = s.handleEnv(req)
+ shouldContinue = s.handleEnv(ctx, req)
case "exec":
shouldContinue = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
- s.exit(s.handleShell(ctx, req))
+ s.exit(ctx, s.handleShell(ctx, req))
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
@@ -57,18 +71,23 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
}
+ sessionLog.WithField("should_continue", shouldContinue).Debug("session: handle: request processed")
+
if !shouldContinue {
s.channel.Close()
break
}
}
+
+ ctxlog.Debug("session: handle: exiting request loop")
}
-func (s *session) handleEnv(req *ssh.Request) bool {
+func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
+ log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
return false
}
@@ -84,6 +103,10 @@ func (s *session) handleEnv(req *ssh.Request) bool {
req.Reply(accepted, []byte{})
}
+ log.WithContextFields(
+ ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
+ ).Debug("session: handleEnv: processed")
+
return true
}
@@ -95,7 +118,8 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
s.execCmd = execRequest.Command
- s.exit(s.handleShell(ctx, req))
+ s.exit(ctx, s.handleShell(ctx, req))
+
return false
}
@@ -104,19 +128,11 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
req.Reply(true, []byte{})
}
- args := &commandargs.Shell{
- GitlabKeyId: s.gitlabKeyId,
- Env: sshenv.Env{
- IsSSHConnection: true,
- OriginalCommand: s.execCmd,
- GitProtocolVersion: s.gitProtocolVersion,
- RemoteAddr: s.remoteAddr,
- },
- }
-
- if err := args.ParseCommand(s.execCmd); err != nil {
- s.toStderr("Failed to parse command: %v\n", err.Error())
- return 128
+ env := sshenv.Env{
+ IsSSHConnection: true,
+ OriginalCommand: s.execCmd,
+ GitProtocolVersion: s.gitProtocolVersion,
+ RemoteAddr: s.remoteAddr,
}
rw := &readwriter.ReadWriter{
@@ -125,25 +141,37 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
ErrOut: s.channel.Stderr(),
}
- cmd := command.BuildShellCommand(args, s.cfg, rw)
- if cmd == nil {
- s.toStderr("Unknown command: %v\n", args.CommandType)
+ cmd, err := shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
+ if err != nil {
+ if !errors.Is(err, disallowedcommand.Error) {
+ s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
+ }
+ s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
return 128
}
+ cmdName := reflect.TypeOf(cmd).String()
+ ctxlog := log.ContextLogger(ctx)
+ ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("session: handleShell: executing command")
+
if err := cmd.Execute(ctx); err != nil {
- s.toStderr("remote: ERROR: %v\n", err.Error())
+ s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
return 1
}
+ ctxlog.Info("session: handleShell: command executed successfully")
+
return 0
}
-func (s *session) toStderr(format string, args ...interface{}) {
- fmt.Fprintf(s.channel.Stderr(), format, args...)
+func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
+ out := fmt.Sprintf(format, args...)
+ log.WithContextFields(ctx, log.Fields{"stderr": out}).Debug("session: toStderr: output")
+ fmt.Fprint(s.channel.Stderr(), out)
}
-func (s *session) exit(status uint32) {
+func (s *session) exit(ctx context.Context, status uint32) {
+ log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
s.channel.CloseWrite()
diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go
new file mode 100644
index 0000000..f135825
--- /dev/null
+++ b/internal/sshd/session_test.go
@@ -0,0 +1,189 @@
+package sshd
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+
+ "gitlab.com/gitlab-org/gitlab-shell/client/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/config"
+)
+
+type fakeChannel struct {
+ stdErr io.ReadWriter
+ sentRequestName string
+ sentRequestPayload []byte
+}
+
+func (f *fakeChannel) Read(data []byte) (int, error) {
+ return 0, nil
+}
+
+func (f *fakeChannel) Write(data []byte) (int, error) {
+ return 0, nil
+}
+
+func (f *fakeChannel) Close() error {
+ return nil
+}
+
+func (f *fakeChannel) CloseWrite() error {
+ return nil
+}
+
+func (f *fakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+ f.sentRequestName = name
+ f.sentRequestPayload = payload
+
+ return true, nil
+}
+
+func (f *fakeChannel) Stderr() io.ReadWriter {
+ return f.stdErr
+}
+
+var requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(`{"id": 1000, "name": "Test User", "username": "test-user"}`))
+ },
+ },
+}
+
+func TestHandleEnv(t *testing.T) {
+ testCases := []struct {
+ desc string
+ payload []byte
+ expectedProtocolVersion string
+ expectedResult bool
+ }{
+ {
+ desc: "invalid payload",
+ payload: []byte("invalid"),
+ expectedProtocolVersion: "1",
+ expectedResult: false,
+ }, {
+ desc: "valid payload",
+ payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
+ expectedProtocolVersion: "2",
+ expectedResult: true,
+ }, {
+ desc: "valid payload with forbidden env var",
+ payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
+ expectedProtocolVersion: "1",
+ expectedResult: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ s := &session{gitProtocolVersion: "1"}
+ r := &ssh.Request{Payload: tc.payload}
+
+ require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
+ require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
+ })
+ }
+}
+
+func TestHandleExec(t *testing.T) {
+ testCases := []struct {
+ desc string
+ payload []byte
+ expectedExecCmd string
+ sentRequestName string
+ sentRequestPayload []byte
+ }{
+ {
+ desc: "invalid payload",
+ payload: []byte("invalid"),
+ expectedExecCmd: "",
+ sentRequestName: "",
+ }, {
+ desc: "valid payload",
+ payload: ssh.Marshal(execRequest{Command: "discover"}),
+ expectedExecCmd: "discover",
+ sentRequestName: "exit-status",
+ sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
+ },
+ }
+
+ url := testserver.StartHttpServer(t, requests)
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ out := &bytes.Buffer{}
+ f := &fakeChannel{stdErr: out}
+ s := &session{
+ gitlabKeyId: "root",
+ channel: f,
+ cfg: &config.Config{GitlabUrl: url},
+ }
+ r := &ssh.Request{Payload: tc.payload}
+
+ require.Equal(t, false, s.handleExec(context.Background(), r))
+ require.Equal(t, tc.sentRequestName, f.sentRequestName)
+ require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
+ })
+ }
+}
+
+func TestHandleShell(t *testing.T) {
+ testCases := []struct {
+ desc string
+ cmd string
+ errMsg string
+ gitlabKeyId string
+ expectedExitCode uint32
+ }{
+ {
+ desc: "fails to parse command",
+ cmd: `\`,
+ errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
+ gitlabKeyId: "root",
+ expectedExitCode: 128,
+ }, {
+ desc: "specified command is unknown",
+ cmd: "unknown-command",
+ errMsg: "Unknown command: unknown-command\n",
+ gitlabKeyId: "root",
+ expectedExitCode: 128,
+ }, {
+ desc: "fails to parse command",
+ cmd: "discover",
+ gitlabKeyId: "",
+ errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
+ expectedExitCode: 1,
+ }, {
+ desc: "fails to parse command",
+ cmd: "discover",
+ errMsg: "",
+ gitlabKeyId: "root",
+ expectedExitCode: 0,
+ },
+ }
+
+ url := testserver.StartHttpServer(t, requests)
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ out := &bytes.Buffer{}
+ s := &session{
+ gitlabKeyId: tc.gitlabKeyId,
+ execCmd: tc.cmd,
+ channel: &fakeChannel{stdErr: out},
+ cfg: &config.Config{GitlabUrl: url},
+ }
+ r := &ssh.Request{}
+
+ require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
+ require.Equal(t, tc.errMsg, out.String())
+ })
+ }
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index de5fbd4..19fa661 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -2,13 +2,9 @@ package sshd
import (
"context"
- "encoding/base64"
- "errors"
"fmt"
"net"
"net/http"
- "os"
- "strconv"
"sync"
"time"
@@ -16,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
@@ -35,44 +30,24 @@ const (
type Server struct {
Config *config.Config
- status status
- statusMu sync.Mutex
- wg sync.WaitGroup
- listener net.Listener
- hostKeys []ssh.Signer
- authorizedKeysClient *authorizedkeys.Client
+ status status
+ statusMu sync.Mutex
+ wg sync.WaitGroup
+ listener net.Listener
+ serverConfig *serverConfig
}
func NewServer(cfg *config.Config) (*Server, error) {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ serverConfig, err := newServerConfig(cfg)
if err != nil {
- return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ return nil, err
}
- var hostKeys []ssh.Signer
- for _, filename := range cfg.Server.HostKeyFiles {
- keyRaw, err := os.ReadFile(filename)
- if err != nil {
- log.WithError(err).Warnf("Failed to read host key %v", filename)
- continue
- }
- key, err := ssh.ParsePrivateKey(keyRaw)
- if err != nil {
- log.WithError(err).Warnf("Failed to parse host key %v", filename)
- continue
- }
-
- hostKeys = append(hostKeys, key)
- }
- if len(hostKeys) == 0 {
- return nil, fmt.Errorf("No host keys could be loaded, aborting")
- }
-
- return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+ return &Server{Config: cfg, serverConfig: serverConfig}, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
- if err := s.listen(); err != nil {
+ if err := s.listen(ctx); err != nil {
return err
}
defer s.listener.Close()
@@ -110,7 +85,7 @@ func (s *Server) MonitoringServeMux() *http.ServeMux {
return mux
}
-func (s *Server) listen() error {
+func (s *Server) listen(ctx context.Context) error {
sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
if err != nil {
return fmt.Errorf("failed to listen for connection: %w", err)
@@ -122,10 +97,10 @@ func (s *Server) listen() error {
ReadHeaderTimeout: ProxyHeaderTimeout,
}
- log.Info("Proxy protocol is enabled")
+ log.ContextLogger(ctx).Info("Proxy protocol is enabled")
}
- log.WithFields(log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections")
+ log.WithContextFields(ctx, log.Fields{"tcp_address": sshListener.Addr().String()}).Info("Listening for SSH connections")
s.listener = sshListener
@@ -142,7 +117,7 @@ func (s *Server) serve(ctx context.Context) {
break
}
- log.WithError(err).Warn("Failed to accept connection")
+ log.ContextLogger(ctx).WithError(err).Warn("Failed to accept connection")
continue
}
@@ -168,57 +143,29 @@ func (s *Server) getStatus() status {
return s.status
}
-func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig {
- sshCfg := &ssh.ServerConfig{
- PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != s.Config.User {
- return nil, errors.New("unknown user")
- }
- if key.Type() == ssh.KeyAlgoDSA {
- return nil, errors.New("DSA is prohibited")
- }
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
- res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
- if err != nil {
- return nil, err
- }
-
- return &ssh.Permissions{
- // Record the public key used for authentication.
- Extensions: map[string]string{
- "key-id": strconv.FormatInt(res.Id, 10),
- },
- }, nil
- },
- }
-
- for _, key := range s.hostKeys {
- sshCfg.AddHostKey(key)
- }
-
- return sshCfg
-}
-
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
defer s.wg.Done()
defer nconn.Close()
+ ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
+ defer cancel()
+
+ ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
+
// Prevent a panic in a single connection from taking out the whole server
defer func() {
if err := recover(); err != nil {
- log.WithFields(log.Fields{"recovered_error": err}).Warnf("panic handling session from %s", remoteAddr)
+ ctxlog.Warn("panic handling session")
}
}()
- ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
- defer cancel()
+ ctxlog.Info("server: handleConn: start")
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx))
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
if err != nil {
- log.WithError(err).Info("Failed to initialize SSH connection")
+ ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection")
return
}
@@ -235,4 +182,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
session.handle(ctx, requests)
})
+
+ ctxlog.Info("server: handleConn: done")
}
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index 32946af..71f7733 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -104,6 +104,24 @@ func TestLivenessProbe(t *testing.T) {
require.Equal(t, 200, r.Result().StatusCode)
}
+func TestInvalidClientConfig(t *testing.T) {
+ setupServer(t)
+
+ cfg := clientConfig(t)
+ cfg.User = "unknown"
+ _, err := ssh.Dial("tcp", serverUrl, cfg)
+ require.Error(t, err)
+}
+
+func TestInvalidServerConfig(t *testing.T) {
+ s := &Server{Config: &config.Config{Server: config.ServerConfig{Listen: "invalid"}}}
+ err := s.ListenAndServe(context.Background())
+
+ require.Error(t, err)
+ require.Equal(t, "failed to listen for connection: listen tcp: address invalid: missing port in address", err.Error())
+ require.Nil(t, s.Shutdown())
+}
+
func setupServer(t *testing.T) *Server {
t.Helper()