diff options
Diffstat (limited to 'go/internal/command/commandargs')
-rw-r--r-- | go/internal/command/commandargs/command_args.go | 43 | ||||
-rw-r--r-- | go/internal/command/commandargs/command_args_test.go | 69 |
2 files changed, 87 insertions, 25 deletions
diff --git a/go/internal/command/commandargs/command_args.go b/go/internal/command/commandargs/command_args.go index e801889..7e241ea 100644 --- a/go/internal/command/commandargs/command_args.go +++ b/go/internal/command/commandargs/command_args.go @@ -4,6 +4,8 @@ import ( "errors" "os" "regexp" + + "github.com/mattn/go-shellwords" ) type CommandType string @@ -11,6 +13,7 @@ type CommandType string const ( Discover CommandType = "discover" TwoFactorRecover CommandType = "2fa_recovery_codes" + ReceivePack CommandType = "git-receive-pack" ) var ( @@ -21,7 +24,7 @@ var ( type CommandArgs struct { GitlabUsername string GitlabKeyId string - SshCommand string + SshArgs []string CommandType CommandType } @@ -30,12 +33,15 @@ func Parse(arguments []string) (*CommandArgs, error) { return nil, errors.New("Only ssh allowed") } - info := &CommandArgs{} + args := &CommandArgs{} + args.parseWho(arguments) - info.parseWho(arguments) - info.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND")) + if err := args.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND")); err != nil { + return nil, errors.New("Invalid ssh command") + } + args.defineCommandType() - return info, nil + return args, nil } func (c *CommandArgs) parseWho(arguments []string) { @@ -74,14 +80,29 @@ func tryParseUsername(argument string) string { return "" } -func (c *CommandArgs) parseCommand(commandString string) { - c.SshCommand = commandString +func (c *CommandArgs) parseCommand(commandString string) error { + args, err := shellwords.Parse(commandString) + if err != nil { + return err + } - if commandString == "" { - c.CommandType = Discover + // Handle Git for Windows 2.14 using "git upload-pack" instead of git-upload-pack + if len(args) > 1 && args[0] == "git" { + command := args[0] + "-" + args[1] + commandArgs := args[2:] + + args = append([]string{command}, commandArgs...) } - if CommandType(commandString) == TwoFactorRecover { - c.CommandType = TwoFactorRecover + c.SshArgs = args + + return nil +} + +func (c *CommandArgs) defineCommandType() { + if len(c.SshArgs) == 0 { + c.CommandType = Discover + } else { + c.CommandType = CommandType(c.SshArgs[0]) } } diff --git a/go/internal/command/commandargs/command_args_test.go b/go/internal/command/commandargs/command_args_test.go index 10c46fe..01202c0 100644 --- a/go/internal/command/commandargs/command_args_test.go +++ b/go/internal/command/commandargs/command_args_test.go @@ -3,7 +3,8 @@ package commandargs import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" ) @@ -22,23 +23,16 @@ func TestParseSuccess(t *testing.T) { "SSH_CONNECTION": "1", "SSH_ORIGINAL_COMMAND": "", }, - expectedArgs: &CommandArgs{CommandType: Discover}, + expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover}, }, { - desc: "It passes on the original ssh command from the environment", - environment: map[string]string{ - "SSH_CONNECTION": "1", - "SSH_ORIGINAL_COMMAND": "hello world", - }, - expectedArgs: &CommandArgs{SshCommand: "hello world"}, - }, { desc: "It finds the key id in any passed arguments", environment: map[string]string{ "SSH_CONNECTION": "1", "SSH_ORIGINAL_COMMAND": "", }, arguments: []string{"hello", "key-123"}, - expectedArgs: &CommandArgs{CommandType: Discover, GitlabKeyId: "123"}, + expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123"}, }, { desc: "It finds the username in any passed arguments", environment: map[string]string{ @@ -46,7 +40,42 @@ func TestParseSuccess(t *testing.T) { "SSH_ORIGINAL_COMMAND": "", }, arguments: []string{"hello", "username-jane-doe"}, - expectedArgs: &CommandArgs{CommandType: Discover, GitlabUsername: "jane-doe"}, + expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe"}, + }, { + desc: "It parses 2fa_recovery_codes command", + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes", + }, + expectedArgs: &CommandArgs{SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover}, + }, { + desc: "It parses git-receive-pack command", + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git-receive-pack group/repo", + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: "It parses git-receive-pack command and a project with single quotes", + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git receive-pack 'group/repo'", + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: `It parses "git receive-pack" command`, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git receive-pack "group/repo"`, + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: `It parses a command followed by control characters`, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git-receive-pack group/repo; any command`, + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, }, } @@ -57,8 +86,8 @@ func TestParseSuccess(t *testing.T) { result, err := Parse(tc.arguments) - assert.NoError(t, err) - assert.Equal(t, tc.expectedArgs, result) + require.NoError(t, err) + require.Equal(t, tc.expectedArgs, result) }) } } @@ -67,7 +96,19 @@ func TestParseFailure(t *testing.T) { t.Run("It fails if SSH connection is not set", func(t *testing.T) { _, err := Parse([]string{}) - assert.Error(t, err, "Only ssh allowed") + require.Error(t, err, "Only ssh allowed") }) + t.Run("It fails if SSH command is invalid", func(t *testing.T) { + environment := map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git receive-pack "`, + } + restoreEnv := testhelper.TempEnv(environment) + defer restoreEnv() + + _, err := Parse([]string{}) + + require.Error(t, err, "Invalid ssh command") + }) } |