summaryrefslogtreecommitdiff
path: root/go/internal/command/commandargs
diff options
context:
space:
mode:
Diffstat (limited to 'go/internal/command/commandargs')
-rw-r--r--go/internal/command/commandargs/command_args.go43
-rw-r--r--go/internal/command/commandargs/command_args_test.go69
2 files changed, 87 insertions, 25 deletions
diff --git a/go/internal/command/commandargs/command_args.go b/go/internal/command/commandargs/command_args.go
index e801889..7e241ea 100644
--- a/go/internal/command/commandargs/command_args.go
+++ b/go/internal/command/commandargs/command_args.go
@@ -4,6 +4,8 @@ import (
"errors"
"os"
"regexp"
+
+ "github.com/mattn/go-shellwords"
)
type CommandType string
@@ -11,6 +13,7 @@ type CommandType string
const (
Discover CommandType = "discover"
TwoFactorRecover CommandType = "2fa_recovery_codes"
+ ReceivePack CommandType = "git-receive-pack"
)
var (
@@ -21,7 +24,7 @@ var (
type CommandArgs struct {
GitlabUsername string
GitlabKeyId string
- SshCommand string
+ SshArgs []string
CommandType CommandType
}
@@ -30,12 +33,15 @@ func Parse(arguments []string) (*CommandArgs, error) {
return nil, errors.New("Only ssh allowed")
}
- info := &CommandArgs{}
+ args := &CommandArgs{}
+ args.parseWho(arguments)
- info.parseWho(arguments)
- info.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND"))
+ if err := args.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND")); err != nil {
+ return nil, errors.New("Invalid ssh command")
+ }
+ args.defineCommandType()
- return info, nil
+ return args, nil
}
func (c *CommandArgs) parseWho(arguments []string) {
@@ -74,14 +80,29 @@ func tryParseUsername(argument string) string {
return ""
}
-func (c *CommandArgs) parseCommand(commandString string) {
- c.SshCommand = commandString
+func (c *CommandArgs) parseCommand(commandString string) error {
+ args, err := shellwords.Parse(commandString)
+ if err != nil {
+ return err
+ }
- if commandString == "" {
- c.CommandType = Discover
+ // Handle Git for Windows 2.14 using "git upload-pack" instead of git-upload-pack
+ if len(args) > 1 && args[0] == "git" {
+ command := args[0] + "-" + args[1]
+ commandArgs := args[2:]
+
+ args = append([]string{command}, commandArgs...)
}
- if CommandType(commandString) == TwoFactorRecover {
- c.CommandType = TwoFactorRecover
+ c.SshArgs = args
+
+ return nil
+}
+
+func (c *CommandArgs) defineCommandType() {
+ if len(c.SshArgs) == 0 {
+ c.CommandType = Discover
+ } else {
+ c.CommandType = CommandType(c.SshArgs[0])
}
}
diff --git a/go/internal/command/commandargs/command_args_test.go b/go/internal/command/commandargs/command_args_test.go
index 10c46fe..01202c0 100644
--- a/go/internal/command/commandargs/command_args_test.go
+++ b/go/internal/command/commandargs/command_args_test.go
@@ -3,7 +3,8 @@ package commandargs
import (
"testing"
- "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
"gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper"
)
@@ -22,23 +23,16 @@ func TestParseSuccess(t *testing.T) {
"SSH_CONNECTION": "1",
"SSH_ORIGINAL_COMMAND": "",
},
- expectedArgs: &CommandArgs{CommandType: Discover},
+ expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover},
},
{
- desc: "It passes on the original ssh command from the environment",
- environment: map[string]string{
- "SSH_CONNECTION": "1",
- "SSH_ORIGINAL_COMMAND": "hello world",
- },
- expectedArgs: &CommandArgs{SshCommand: "hello world"},
- }, {
desc: "It finds the key id in any passed arguments",
environment: map[string]string{
"SSH_CONNECTION": "1",
"SSH_ORIGINAL_COMMAND": "",
},
arguments: []string{"hello", "key-123"},
- expectedArgs: &CommandArgs{CommandType: Discover, GitlabKeyId: "123"},
+ expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123"},
}, {
desc: "It finds the username in any passed arguments",
environment: map[string]string{
@@ -46,7 +40,42 @@ func TestParseSuccess(t *testing.T) {
"SSH_ORIGINAL_COMMAND": "",
},
arguments: []string{"hello", "username-jane-doe"},
- expectedArgs: &CommandArgs{CommandType: Discover, GitlabUsername: "jane-doe"},
+ expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe"},
+ }, {
+ desc: "It parses 2fa_recovery_codes command",
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes",
+ },
+ expectedArgs: &CommandArgs{SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover},
+ }, {
+ desc: "It parses git-receive-pack command",
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "git-receive-pack group/repo",
+ },
+ expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
+ }, {
+ desc: "It parses git-receive-pack command and a project with single quotes",
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "git receive-pack 'group/repo'",
+ },
+ expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
+ }, {
+ desc: `It parses "git receive-pack" command`,
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": `git receive-pack "group/repo"`,
+ },
+ expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
+ }, {
+ desc: `It parses a command followed by control characters`,
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": `git-receive-pack group/repo; any command`,
+ },
+ expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack},
},
}
@@ -57,8 +86,8 @@ func TestParseSuccess(t *testing.T) {
result, err := Parse(tc.arguments)
- assert.NoError(t, err)
- assert.Equal(t, tc.expectedArgs, result)
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedArgs, result)
})
}
}
@@ -67,7 +96,19 @@ func TestParseFailure(t *testing.T) {
t.Run("It fails if SSH connection is not set", func(t *testing.T) {
_, err := Parse([]string{})
- assert.Error(t, err, "Only ssh allowed")
+ require.Error(t, err, "Only ssh allowed")
})
+ t.Run("It fails if SSH command is invalid", func(t *testing.T) {
+ environment := map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": `git receive-pack "`,
+ }
+ restoreEnv := testhelper.TempEnv(environment)
+ defer restoreEnv()
+
+ _, err := Parse([]string{})
+
+ require.Error(t, err, "Invalid ssh command")
+ })
}