diff options
Diffstat (limited to 'go')
28 files changed, 1212 insertions, 102 deletions
diff --git a/go/internal/command/command.go b/go/internal/command/command.go index 0ceb7fc..9589f2f 100644 --- a/go/internal/command/command.go +++ b/go/internal/command/command.go @@ -5,6 +5,7 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/receivepack" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" ) @@ -33,6 +34,8 @@ func buildCommand(args *commandargs.CommandArgs, config *config.Config, readWrit return &discover.Command{Config: config, Args: args, ReadWriter: readWriter} case commandargs.TwoFactorRecover: return &twofactorrecover.Command{Config: config, Args: args, ReadWriter: readWriter} + case commandargs.ReceivePack: + return &receivepack.Command{Config: config, Args: args, ReadWriter: readWriter} } return nil diff --git a/go/internal/command/command_test.go b/go/internal/command/command_test.go index 228dc7a..99069c7 100644 --- a/go/internal/command/command_test.go +++ b/go/internal/command/command_test.go @@ -3,9 +3,11 @@ package command import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/receivepack" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" @@ -58,6 +60,19 @@ func TestNew(t *testing.T) { }, expectedType: &twofactorrecover.Command{}, }, + { + desc: "it returns a ReceivePack command if the feature is enabled", + arguments: []string{}, + config: &config.Config{ + GitlabUrl: "http+unix://gitlab.socket", + Migration: config.MigrationConfig{Enabled: true, Features: []string{"git-receive-pack"}}, + }, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git-receive-pack", + }, + expectedType: &receivepack.Command{}, + }, } for _, tc := range testCases { @@ -67,8 +82,8 @@ func TestNew(t *testing.T) { command, err := New(tc.arguments, tc.config, nil) - assert.NoError(t, err) - assert.IsType(t, tc.expectedType, command) + require.NoError(t, err) + require.IsType(t, tc.expectedType, command) }) } } @@ -80,6 +95,6 @@ func TestFailingNew(t *testing.T) { _, err := New([]string{}, &config.Config{}, nil) - assert.Error(t, err, "Only ssh allowed") + require.Error(t, err, "Only ssh allowed") }) } diff --git a/go/internal/command/commandargs/command_args.go b/go/internal/command/commandargs/command_args.go index e801889..7e241ea 100644 --- a/go/internal/command/commandargs/command_args.go +++ b/go/internal/command/commandargs/command_args.go @@ -4,6 +4,8 @@ import ( "errors" "os" "regexp" + + "github.com/mattn/go-shellwords" ) type CommandType string @@ -11,6 +13,7 @@ type CommandType string const ( Discover CommandType = "discover" TwoFactorRecover CommandType = "2fa_recovery_codes" + ReceivePack CommandType = "git-receive-pack" ) var ( @@ -21,7 +24,7 @@ var ( type CommandArgs struct { GitlabUsername string GitlabKeyId string - SshCommand string + SshArgs []string CommandType CommandType } @@ -30,12 +33,15 @@ func Parse(arguments []string) (*CommandArgs, error) { return nil, errors.New("Only ssh allowed") } - info := &CommandArgs{} + args := &CommandArgs{} + args.parseWho(arguments) - info.parseWho(arguments) - info.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND")) + if err := args.parseCommand(os.Getenv("SSH_ORIGINAL_COMMAND")); err != nil { + return nil, errors.New("Invalid ssh command") + } + args.defineCommandType() - return info, nil + return args, nil } func (c *CommandArgs) parseWho(arguments []string) { @@ -74,14 +80,29 @@ func tryParseUsername(argument string) string { return "" } -func (c *CommandArgs) parseCommand(commandString string) { - c.SshCommand = commandString +func (c *CommandArgs) parseCommand(commandString string) error { + args, err := shellwords.Parse(commandString) + if err != nil { + return err + } - if commandString == "" { - c.CommandType = Discover + // Handle Git for Windows 2.14 using "git upload-pack" instead of git-upload-pack + if len(args) > 1 && args[0] == "git" { + command := args[0] + "-" + args[1] + commandArgs := args[2:] + + args = append([]string{command}, commandArgs...) } - if CommandType(commandString) == TwoFactorRecover { - c.CommandType = TwoFactorRecover + c.SshArgs = args + + return nil +} + +func (c *CommandArgs) defineCommandType() { + if len(c.SshArgs) == 0 { + c.CommandType = Discover + } else { + c.CommandType = CommandType(c.SshArgs[0]) } } diff --git a/go/internal/command/commandargs/command_args_test.go b/go/internal/command/commandargs/command_args_test.go index 10c46fe..01202c0 100644 --- a/go/internal/command/commandargs/command_args_test.go +++ b/go/internal/command/commandargs/command_args_test.go @@ -3,7 +3,8 @@ package commandargs import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" ) @@ -22,23 +23,16 @@ func TestParseSuccess(t *testing.T) { "SSH_CONNECTION": "1", "SSH_ORIGINAL_COMMAND": "", }, - expectedArgs: &CommandArgs{CommandType: Discover}, + expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover}, }, { - desc: "It passes on the original ssh command from the environment", - environment: map[string]string{ - "SSH_CONNECTION": "1", - "SSH_ORIGINAL_COMMAND": "hello world", - }, - expectedArgs: &CommandArgs{SshCommand: "hello world"}, - }, { desc: "It finds the key id in any passed arguments", environment: map[string]string{ "SSH_CONNECTION": "1", "SSH_ORIGINAL_COMMAND": "", }, arguments: []string{"hello", "key-123"}, - expectedArgs: &CommandArgs{CommandType: Discover, GitlabKeyId: "123"}, + expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover, GitlabKeyId: "123"}, }, { desc: "It finds the username in any passed arguments", environment: map[string]string{ @@ -46,7 +40,42 @@ func TestParseSuccess(t *testing.T) { "SSH_ORIGINAL_COMMAND": "", }, arguments: []string{"hello", "username-jane-doe"}, - expectedArgs: &CommandArgs{CommandType: Discover, GitlabUsername: "jane-doe"}, + expectedArgs: &CommandArgs{SshArgs: []string{}, CommandType: Discover, GitlabUsername: "jane-doe"}, + }, { + desc: "It parses 2fa_recovery_codes command", + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes", + }, + expectedArgs: &CommandArgs{SshArgs: []string{"2fa_recovery_codes"}, CommandType: TwoFactorRecover}, + }, { + desc: "It parses git-receive-pack command", + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git-receive-pack group/repo", + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: "It parses git-receive-pack command and a project with single quotes", + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": "git receive-pack 'group/repo'", + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: `It parses "git receive-pack" command`, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git receive-pack "group/repo"`, + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, + }, { + desc: `It parses a command followed by control characters`, + environment: map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git-receive-pack group/repo; any command`, + }, + expectedArgs: &CommandArgs{SshArgs: []string{"git-receive-pack", "group/repo"}, CommandType: ReceivePack}, }, } @@ -57,8 +86,8 @@ func TestParseSuccess(t *testing.T) { result, err := Parse(tc.arguments) - assert.NoError(t, err) - assert.Equal(t, tc.expectedArgs, result) + require.NoError(t, err) + require.Equal(t, tc.expectedArgs, result) }) } } @@ -67,7 +96,19 @@ func TestParseFailure(t *testing.T) { t.Run("It fails if SSH connection is not set", func(t *testing.T) { _, err := Parse([]string{}) - assert.Error(t, err, "Only ssh allowed") + require.Error(t, err, "Only ssh allowed") }) + t.Run("It fails if SSH command is invalid", func(t *testing.T) { + environment := map[string]string{ + "SSH_CONNECTION": "1", + "SSH_ORIGINAL_COMMAND": `git receive-pack "`, + } + restoreEnv := testhelper.TempEnv(environment) + defer restoreEnv() + + _, err := Parse([]string{}) + + require.Error(t, err, "Invalid ssh command") + }) } diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go index 40c0d4e..284610a 100644 --- a/go/internal/command/discover/discover_test.go +++ b/go/internal/command/discover/discover_test.go @@ -7,7 +7,6 @@ import ( "net/http" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" @@ -45,8 +44,7 @@ var ( ) func TestExecute(t *testing.T) { - cleanup, url, err := testserver.StartSocketHttpServer(requests) - require.NoError(t, err) + url, cleanup := testserver.StartSocketHttpServer(t, requests) defer cleanup() testCases := []struct { @@ -87,15 +85,14 @@ func TestExecute(t *testing.T) { err := cmd.Execute() - assert.NoError(t, err) - assert.Equal(t, tc.expectedOutput, buffer.String()) + require.NoError(t, err) + require.Equal(t, tc.expectedOutput, buffer.String()) }) } } func TestFailingExecute(t *testing.T) { - cleanup, url, err := testserver.StartSocketHttpServer(requests) - require.NoError(t, err) + url, cleanup := testserver.StartSocketHttpServer(t, requests) defer cleanup() testCases := []struct { @@ -131,8 +128,8 @@ func TestFailingExecute(t *testing.T) { err := cmd.Execute() - assert.Empty(t, buffer.String()) - assert.EqualError(t, err, tc.expectedError) + require.Empty(t, buffer.String()) + require.EqualError(t, err, tc.expectedError) }) } } diff --git a/go/internal/command/receivepack/customaction.go b/go/internal/command/receivepack/customaction.go new file mode 100644 index 0000000..8623437 --- /dev/null +++ b/go/internal/command/receivepack/customaction.go @@ -0,0 +1,99 @@ +package receivepack + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/accessverifier" +) + +type Request struct { + SecretToken []byte `json:"secret_token"` + Data accessverifier.CustomPayloadData `json:"data"` + Output []byte `json:"output"` +} + +type Response struct { + Result []byte `json:"result"` + Message string `json:"message"` +} + +func (c *Command) processCustomAction(response *accessverifier.Response) error { + data := response.Payload.Data + apiEndpoints := data.ApiEndpoints + + if len(apiEndpoints) == 0 { + return errors.New("Custom action error: Empty API endpoints") + } + + c.displayInfoMessage(data.InfoMessage) + + return c.processApiEndpoints(response) +} + +func (c *Command) displayInfoMessage(infoMessage string) { + messages := strings.Split(infoMessage, "\n") + + for _, msg := range messages { + fmt.Fprintf(c.ReadWriter.ErrOut, "> GitLab: %v\n", msg) + } +} + +func (c *Command) processApiEndpoints(response *accessverifier.Response) error { + client, err := gitlabnet.GetClient(c.Config) + + if err != nil { + return err + } + + data := response.Payload.Data + request := &Request{Data: data} + request.Data.UserId = response.Who + + for _, endpoint := range data.ApiEndpoints { + response, err := c.performRequest(client, endpoint, request) + if err != nil { + return err + } + + if err = c.displayResult(response.Result); err != nil { + return err + } + + // In the context of the git push sequence of events, it's necessary to read + // stdin in order to capture output to pass onto subsequent commands + output, err := ioutil.ReadAll(c.ReadWriter.In) + if err != nil { + return err + } + request.Output = output + } + + return nil +} + +func (c *Command) performRequest(client *gitlabnet.GitlabClient, endpoint string, request *Request) (*Response, error) { + response, err := client.DoRequest(http.MethodPost, endpoint, request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + cr := &Response{} + if err := gitlabnet.ParseJSON(response, cr); err != nil { + return nil, err + } + + return cr, nil +} + +func (c *Command) displayResult(result []byte) error { + _, err := io.Copy(c.ReadWriter.Out, bytes.NewReader(result)) + return err +} diff --git a/go/internal/command/receivepack/customaction_test.go b/go/internal/command/receivepack/customaction_test.go new file mode 100644 index 0000000..80e849c --- /dev/null +++ b/go/internal/command/receivepack/customaction_test.go @@ -0,0 +1,105 @@ +package receivepack + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +func TestCustomReceivePack(t *testing.T) { + repo := "group/repo" + keyId := "1" + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var request *accessverifier.Request + require.NoError(t, json.Unmarshal(b, &request)) + + require.Equal(t, "1", request.KeyId) + + body := map[string]interface{}{ + "status": true, + "gl_id": "1", + "payload": map[string]interface{}{ + "action": "geo_proxy_to_primary", + "data": map[string]interface{}{ + "api_endpoints": []string{"/geo/proxy_git_push_ssh/info_refs", "/geo/proxy_git_push_ssh/push"}, + "gl_username": "custom", + "primary_repo": "https://repo/path", + "info_message": "info_message\none more message", + }, + }, + } + w.WriteHeader(http.StatusMultipleChoices) + require.NoError(t, json.NewEncoder(w).Encode(body)) + }, + }, + { + Path: "/geo/proxy_git_push_ssh/info_refs", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var request *Request + require.NoError(t, json.Unmarshal(b, &request)) + + require.Equal(t, request.Data.UserId, "key-"+keyId) + require.Empty(t, request.Output) + + err = json.NewEncoder(w).Encode(Response{Result: []byte("custom")}) + require.NoError(t, err) + }, + }, + { + Path: "/geo/proxy_git_push_ssh/push", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var request *Request + require.NoError(t, json.Unmarshal(b, &request)) + + require.Equal(t, request.Data.UserId, "key-"+keyId) + require.Equal(t, "input", string(request.Output)) + + err = json.NewEncoder(w).Encode(Response{Result: []byte("output")}) + require.NoError(t, err) + }, + }, + } + + url, cleanup := testserver.StartSocketHttpServer(t, requests) + defer cleanup() + + outBuf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + input := bytes.NewBufferString("input") + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.CommandArgs{GitlabKeyId: keyId, CommandType: commandargs.ReceivePack, SshArgs: []string{"git-receive-pack", repo}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: errBuf, Out: outBuf, In: input}, + } + + require.NoError(t, cmd.Execute()) + + // expect printing of info message, "custom" string from the first request + // and "output" string from the second request + require.Equal(t, "> GitLab: info_message\n> GitLab: one more message\n", errBuf.String()) + require.Equal(t, "customoutput", outBuf.String()) +} diff --git a/go/internal/command/receivepack/gitalycall.go b/go/internal/command/receivepack/gitalycall.go new file mode 100644 index 0000000..22652d7 --- /dev/null +++ b/go/internal/command/receivepack/gitalycall.go @@ -0,0 +1,47 @@ +package receivepack + +import ( + "context" + + "google.golang.org/grpc" + + pb "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb" + "gitlab.com/gitlab-org/gitaly/client" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/shared/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/handler" +) + +func (c *Command) performGitalyCall(response *accessverifier.Response) error { + gc := &handler.GitalyCommand{ + Config: c.Config, + ServiceName: string(commandargs.ReceivePack), + Address: response.Gitaly.Address, + Token: response.Gitaly.Token, + } + + repo := response.Gitaly.Repo + request := &pb.SSHReceivePackRequest{ + Repository: &pb.Repository{ + StorageName: repo.StorageName, + RelativePath: repo.RelativePath, + GitObjectDirectory: repo.GitObjectDirectory, + GitAlternateObjectDirectories: repo.GitAlternateObjectDirectories, + GlRepository: repo.RepoName, + GlProjectPath: repo.ProjectPath, + }, + GlId: response.UserId, + GlRepository: response.Repo, + GlUsername: response.Username, + GitProtocol: response.GitProtocol, + GitConfigOptions: response.GitConfigOptions, + } + + return gc.RunGitalyCommand(func(ctx context.Context, conn *grpc.ClientConn) (int32, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + rw := c.ReadWriter + return client.ReceivePack(ctx, conn, rw.In, rw.Out, rw.ErrOut, request) + }) +} diff --git a/go/internal/command/receivepack/gitalycall_test.go b/go/internal/command/receivepack/gitalycall_test.go new file mode 100644 index 0000000..0914be6 --- /dev/null +++ b/go/internal/command/receivepack/gitalycall_test.go @@ -0,0 +1,40 @@ +package receivepack + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper/requesthandlers" +) + +func TestReceivePack(t *testing.T) { + gitalyAddress, cleanup := testserver.StartGitalyServer(t) + defer cleanup() + + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + input := &bytes.Buffer{} + + userId := "1" + repo := "group/repo" + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.CommandArgs{GitlabKeyId: userId, CommandType: commandargs.ReceivePack, SshArgs: []string{"git-receive-pack", repo}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, + } + + err := cmd.Execute() + require.NoError(t, err) + + require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String()) +} diff --git a/go/internal/command/receivepack/receivepack.go b/go/internal/command/receivepack/receivepack.go new file mode 100644 index 0000000..d1ff3f8 --- /dev/null +++ b/go/internal/command/receivepack/receivepack.go @@ -0,0 +1,45 @@ +package receivepack + +import ( + "errors" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/shared/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" +) + +var ( + disallowedCommandError = errors.New("> GitLab: Disallowed command") +) + +type Command struct { + Config *config.Config + Args *commandargs.CommandArgs + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Execute() error { + args := c.Args.SshArgs + if len(args) != 2 { + return disallowedCommandError + } + + repo := args[1] + response, err := c.verifyAccess(repo) + if err != nil { + return err + } + + if response.IsCustomAction() { + return c.processCustomAction(response) + } + + return c.performGitalyCall(response) +} + +func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { + cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} + + return cmd.Verify(c.Args.CommandType, repo) +} diff --git a/go/internal/command/receivepack/receivepack_test.go b/go/internal/command/receivepack/receivepack_test.go new file mode 100644 index 0000000..874bac3 --- /dev/null +++ b/go/internal/command/receivepack/receivepack_test.go @@ -0,0 +1,46 @@ +package receivepack + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +func TestForbiddenAccess(t *testing.T) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + body := map[string]interface{}{ + "status": false, + "message": "Disallowed by API call", + } + w.WriteHeader(http.StatusForbidden) + require.NoError(t, json.NewEncoder(w).Encode(body)) + }, + }, + } + + url, cleanup := testserver.StartHttpServer(t, requests) + defer cleanup() + + output := &bytes.Buffer{} + input := bytes.NewBufferString("input") + + cmd := &Command{ + Config: &config.Config{GitlabUrl: url}, + Args: &commandargs.CommandArgs{GitlabKeyId: "disallowed", SshArgs: []string{"git-receive-pack", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, + } + + err := cmd.Execute() + require.Equal(t, "Disallowed by API call", err.Error()) +} diff --git a/go/internal/command/shared/accessverifier/accessverifier.go b/go/internal/command/shared/accessverifier/accessverifier.go new file mode 100644 index 0000000..6d13789 --- /dev/null +++ b/go/internal/command/shared/accessverifier/accessverifier.go @@ -0,0 +1,45 @@ +package accessverifier + +import ( + "errors" + "fmt" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/accessverifier" +) + +type Response = accessverifier.Response + +type Command struct { + Config *config.Config + Args *commandargs.CommandArgs + ReadWriter *readwriter.ReadWriter +} + +func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) { + client, err := accessverifier.NewClient(c.Config) + if err != nil { + return nil, err + } + + response, err := client.Verify(c.Args, action, repo) + if err != nil { + return nil, err + } + + c.displayConsoleMessages(response.ConsoleMessages) + + if !response.Success { + return nil, errors.New(response.Message) + } + + return response, nil +} + +func (c *Command) displayConsoleMessages(messages []string) { + for _, msg := range messages { + fmt.Fprintf(c.ReadWriter.ErrOut, "> GitLab: %v\n", msg) + } +} diff --git a/go/internal/command/shared/accessverifier/accessverifier_test.go b/go/internal/command/shared/accessverifier/accessverifier_test.go new file mode 100644 index 0000000..dd95ded --- /dev/null +++ b/go/internal/command/shared/accessverifier/accessverifier_test.go @@ -0,0 +1,82 @@ +package accessverifier + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/accessverifier" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +var ( + repo = "group/repo" + action = commandargs.ReceivePack +) + +func setup(t *testing.T) (*Command, *bytes.Buffer, *bytes.Buffer, func()) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var requestBody *accessverifier.Request + err = json.Unmarshal(b, &requestBody) + require.NoError(t, err) + + if requestBody.KeyId == "1" { + body := map[string]interface{}{ + "gl_console_messages": []string{"console", "message"}, + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + } else { + body := map[string]interface{}{ + "status": false, + "message": "missing user", + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + } + }, + }, + } + + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + errBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + + readWriter := &readwriter.ReadWriter{Out: outBuf, ErrOut: errBuf} + cmd := &Command{Config: &config.Config{GitlabUrl: url}, ReadWriter: readWriter} + + return cmd, errBuf, outBuf, cleanup +} + +func TestMissingUser(t *testing.T) { + cmd, _, _, cleanup := setup(t) + defer cleanup() + + cmd.Args = &commandargs.CommandArgs{GitlabKeyId: "2"} + _, err := cmd.Verify(action, repo) + + require.Equal(t, "missing user", err.Error()) +} + +func TestConsoleMessages(t *testing.T) { + cmd, errBuf, outBuf, cleanup := setup(t) + defer cleanup() + + cmd.Args = &commandargs.CommandArgs{GitlabKeyId: "1"} + cmd.Verify(action, repo) + + require.Equal(t, "> GitLab: console\n> GitLab: message\n", errBuf.String()) + require.Empty(t, outBuf.String()) +} diff --git a/go/internal/command/twofactorrecover/twofactorrecover_test.go b/go/internal/command/twofactorrecover/twofactorrecover_test.go index bcca12a..6238e0d 100644 --- a/go/internal/command/twofactorrecover/twofactorrecover_test.go +++ b/go/internal/command/twofactorrecover/twofactorrecover_test.go @@ -64,8 +64,7 @@ const ( func TestExecute(t *testing.T) { setup(t) - cleanup, url, err := testserver.StartSocketHttpServer(requests) - require.NoError(t, err) + url, cleanup := testserver.StartSocketHttpServer(t, requests) defer cleanup() testCases := []struct { diff --git a/go/internal/gitlabnet/accessverifier/client.go b/go/internal/gitlabnet/accessverifier/client.go new file mode 100644 index 0000000..ebe8545 --- /dev/null +++ b/go/internal/gitlabnet/accessverifier/client.go @@ -0,0 +1,119 @@ +package accessverifier + +import ( + "fmt" + "net/http" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" +) + +const ( + protocol = "ssh" + anyChanges = "_any" +) + +type Client struct { + client *gitlabnet.GitlabClient +} + +type Request struct { + Action commandargs.CommandType `json:"action"` + Repo string `json:"project"` + Changes string `json:"changes"` + Protocol string `json:"protocol"` + KeyId string `json:"key_id,omitempty"` + Username string `json:"username,omitempty"` +} + +type GitalyRepo struct { + StorageName string `json:"storage_name"` + RelativePath string `json:"relative_path"` + GitObjectDirectory string `json:"git_object_directory"` + GitAlternateObjectDirectories []string `json:"git_alternate_object_directories"` + RepoName string `json:"gl_repository"` + ProjectPath string `json:"gl_project_path"` +} + +type Gitaly struct { + Repo GitalyRepo `json:"repository"` + Address string `json:"address"` + Token string `json:"token"` +} + +type CustomPayloadData struct { + ApiEndpoints []string `json:"api_endpoints"` + Username string `json:"gl_username"` + PrimaryRepo string `json:"primary_repo"` + InfoMessage string `json:"info_message"` + UserId string `json:"gl_id,omitempty"` +} + +type CustomPayload struct { + Action string `json:"action"` + Data CustomPayloadData `json:"data"` +} + +type Response struct { + Success bool `json:"status"` + Message string `json:"message"` + Repo string `json:"gl_repository"` + UserId string `json:"gl_id"` + Username string `json:"gl_username"` + GitConfigOptions []string `json:"git_config_options"` + Gitaly Gitaly `json:"gitaly"` + GitProtocol string `json:"git_protocol"` + Payload CustomPayload `json:"payload"` + ConsoleMessages []string `json:"gl_console_messages"` + Who string + StatusCode int +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{client: client}, nil +} + +func (c *Client) Verify(args *commandargs.CommandArgs, action commandargs.CommandType, repo string) (*Response, error) { + request := &Request{Action: action, Repo: repo, Protocol: protocol, Changes: anyChanges} + + if args.GitlabUsername != "" { + request.Username = args.GitlabUsername + } else { + request.KeyId = args.GitlabKeyId + } + + response, err := c.client.Post("/allowed", request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + return parse(response, args) +} + +func parse(hr *http.Response, args *commandargs.CommandArgs) (*Response, error) { + response := &Response{} + if err := gitlabnet.ParseJSON(hr, response); err != nil { + return nil, err + } + + if args.GitlabKeyId != "" { + response.Who = "key-" + args.GitlabKeyId + } else { + response.Who = response.UserId + } + + response.StatusCode = hr.StatusCode + + return response, nil +} + +func (r *Response) IsCustomAction() bool { + return r.StatusCode == http.StatusMultipleChoices +} diff --git a/go/internal/gitlabnet/accessverifier/client_test.go b/go/internal/gitlabnet/accessverifier/client_test.go new file mode 100644 index 0000000..a759919 --- /dev/null +++ b/go/internal/gitlabnet/accessverifier/client_test.go @@ -0,0 +1,208 @@ +package accessverifier + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "path" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" +) + +var ( + repo = "group/private" + action = commandargs.ReceivePack +) + +func buildExpectedResponse(who string) *Response { + response := &Response{ + Success: true, + UserId: "user-1", + Repo: "project-26", + Username: "root", + GitConfigOptions: []string{"option"}, + Gitaly: Gitaly{ + Repo: GitalyRepo{ + StorageName: "default", + RelativePath: "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git", + GitObjectDirectory: "path/to/git_object_directory", + GitAlternateObjectDirectories: []string{"path/to/git_alternate_object_directory"}, + RepoName: "project-26", + ProjectPath: repo, + }, + Address: "unix:gitaly.socket", + Token: "token", + }, + GitProtocol: "protocol", + Payload: CustomPayload{}, + ConsoleMessages: []string{"console", "message"}, + Who: who, + StatusCode: 200, + } + + return response +} + +func TestSuccessfulResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + args *commandargs.CommandArgs + who string + }{ + { + desc: "Provide key id within the request", + args: &commandargs.CommandArgs{GitlabKeyId: "1"}, + who: "key-1", + }, { + desc: "Provide username within the request", + args: &commandargs.CommandArgs{GitlabUsername: "first"}, + who: "user-1", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := client.Verify(tc.args, action, repo) + require.NoError(t, err) + + response := buildExpectedResponse(tc.who) + require.Equal(t, response, result) + }) + } +} + +func TestGetCustomAction(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.CommandArgs{GitlabUsername: "custom"} + result, err := client.Verify(args, action, repo) + require.NoError(t, err) + + response := buildExpectedResponse("user-1") + response.Payload = CustomPayload{ + Action: "geo_proxy_to_primary", + Data: CustomPayloadData{ + ApiEndpoints: []string{"geo/proxy_git_push_ssh/info_refs", "geo/proxy_git_push_ssh/push"}, + Username: "custom", + PrimaryRepo: "https://repo/path", + InfoMessage: "message", + }, + } + response.StatusCode = 300 + + require.True(t, response.IsCustomAction()) + require.Equal(t, response, result) +} + +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeId string + expectedError string + }{ + { + desc: "A response with an error message", + fakeId: "2", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeId: "3", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeId: "4", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + args := &commandargs.CommandArgs{GitlabKeyId: tc.fakeId} + resp, err := client.Verify(args, action, repo) + + require.EqualError(t, err, tc.expectedError) + require.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + testDirCleanup, err := testhelper.PrepareTestRootDir() + require.NoError(t, err) + defer testDirCleanup() + + body, err := ioutil.ReadFile(path.Join(testhelper.TestRoot, "responses/allowed.json")) + require.NoError(t, err) + + allowedWithPayloadPath := path.Join(testhelper.TestRoot, "responses/allowed_with_payload.json") + bodyWithPayload, err := ioutil.ReadFile(allowedWithPayloadPath) + require.NoError(t, err) + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var requestBody *Request + require.NoError(t, json.Unmarshal(b, &requestBody)) + + switch requestBody.Username { + case "first": + _, err = w.Write(body) + require.NoError(t, err) + case "second": + errBody := map[string]interface{}{ + "status": false, + "message": "missing user", + } + require.NoError(t, json.NewEncoder(w).Encode(errBody)) + case "custom": + w.WriteHeader(http.StatusMultipleChoices) + _, err = w.Write(bodyWithPayload) + require.NoError(t, err) + } + + switch requestBody.KeyId { + case "1": + _, err = w.Write(body) + require.NoError(t, err) + case "2": + w.WriteHeader(http.StatusForbidden) + errBody := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + require.NoError(t, json.NewEncoder(w).Encode(errBody)) + case "3": + w.Write([]byte("{ \"message\": \"broken json!\"")) + case "4": + w.WriteHeader(http.StatusForbidden) + } + }, + }, + } + + url, cleanup := testserver.StartSocketHttpServer(t, requests) + + client, err := NewClient(&config.Config{GitlabUrl: url}) + require.NoError(t, err) + + return client, cleanup +} diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go index 86add04..dacb1d6 100644 --- a/go/internal/gitlabnet/client.go +++ b/go/internal/gitlabnet/client.go @@ -53,8 +53,6 @@ func normalizePath(path string) string { } func newRequest(method, host, path string, data interface{}) (*http.Request, error) { - path = normalizePath(path) - var jsonReader io.Reader if data != nil { jsonData, err := json.Marshal(data) @@ -74,7 +72,7 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, err } func parseError(resp *http.Response) error { - if resp.StatusCode >= 200 && resp.StatusCode <= 299 { + if resp.StatusCode >= 200 && resp.StatusCode <= 399 { return nil } defer resp.Body.Close() @@ -89,14 +87,14 @@ func parseError(resp *http.Response) error { } func (c *GitlabClient) Get(path string) (*http.Response, error) { - return c.doRequest("GET", path, nil) + return c.DoRequest(http.MethodGet, normalizePath(path), nil) } func (c *GitlabClient) Post(path string, data interface{}) (*http.Response, error) { - return c.doRequest("POST", path, data) + return c.DoRequest(http.MethodPost, normalizePath(path), data) } -func (c *GitlabClient) doRequest(method, path string, data interface{}) (*http.Response, error) { +func (c *GitlabClient) DoRequest(method, path string, data interface{}) (*http.Response, error) { request, err := newRequest(method, c.host, path, data) if err != nil { return nil, err diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go index d817239..cf42195 100644 --- a/go/internal/gitlabnet/client_test.go +++ b/go/internal/gitlabnet/client_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" @@ -71,7 +72,7 @@ func TestClients(t *testing.T) { testCases := []struct { desc string config *config.Config - server func([]testserver.TestRequestHandler) (func(), string, error) + server func(*testing.T, []testserver.TestRequestHandler) (string, func()) }{ { desc: "Socket client", @@ -94,9 +95,8 @@ func TestClients(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - cleanup, url, err := tc.server(requests) + url, cleanup := tc.server(t, requests) defer cleanup() - require.NoError(t, err) tc.config.GitlabUrl = url tc.config.Secret = "sssh, it's a secret" diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go index 8eabdd7..b98a28e 100644 --- a/go/internal/gitlabnet/discover/client_test.go +++ b/go/internal/gitlabnet/discover/client_test.go @@ -128,8 +128,7 @@ func TestErrorResponses(t *testing.T) { } func setup(t *testing.T) (*Client, func()) { - cleanup, url, err := testserver.StartSocketHttpServer(requests) - require.NoError(t, err) + url, cleanup := testserver.StartSocketHttpServer(t, requests) client, err := NewClient(&config.Config{GitlabUrl: url}) require.NoError(t, err) diff --git a/go/internal/gitlabnet/httpclient_test.go b/go/internal/gitlabnet/httpclient_test.go index 885a6d1..9b635bd 100644 --- a/go/internal/gitlabnet/httpclient_test.go +++ b/go/internal/gitlabnet/httpclient_test.go @@ -86,8 +86,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) { } func setup(t *testing.T, config *config.Config, requests []testserver.TestRequestHandler) (*GitlabClient, func()) { - cleanup, url, err := testserver.StartHttpServer(requests) - require.NoError(t, err) + url, cleanup := testserver.StartHttpServer(t, requests) config.GitlabUrl = url client, err := GetClient(config) diff --git a/go/internal/gitlabnet/httpsclient_test.go b/go/internal/gitlabnet/httpsclient_test.go index b9baad8..04901df 100644 --- a/go/internal/gitlabnet/httpsclient_test.go +++ b/go/internal/gitlabnet/httpsclient_test.go @@ -115,8 +115,7 @@ func setupWithRequests(t *testing.T, config *config.Config) (*GitlabClient, func }, } - cleanup, url, err := testserver.StartHttpsServer(requests) - require.NoError(t, err) + url, cleanup := testserver.StartHttpsServer(t, requests) config.GitlabUrl = url client, err := GetClient(config) diff --git a/go/internal/gitlabnet/testserver/gitalyserver.go b/go/internal/gitlabnet/testserver/gitalyserver.go new file mode 100644 index 0000000..141a518 --- /dev/null +++ b/go/internal/gitlabnet/testserver/gitalyserver.go @@ -0,0 +1,63 @@ +package testserver + +import ( + "io/ioutil" + "net" + "os" + "path" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + pb "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb" +) + +type testGitalyServer struct{} + +func (s *testGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error { + req, err := stream.Recv() + + if err != nil { + return err + } + + response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository) + stream.Send(&pb.SSHReceivePackResponse{Stdout: response}) + + return nil +} + +func (s *testGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { + return nil +} + +func (s *testGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error { + return nil +} + +func StartGitalyServer(t *testing.T) (string, func()) { + tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api") + gitalySocketPath := path.Join(tempDir, "gitaly.sock") + + err := os.MkdirAll(filepath.Dir(gitalySocketPath), 0700) + require.NoError(t, err) + + server := grpc.NewServer() + + listener, err := net.Listen("unix", gitalySocketPath) + require.NoError(t, err) + + pb.RegisterSSHServiceServer(server, &testGitalyServer{}) + + go server.Serve(listener) + + gitalySocketUrl := "unix:" + gitalySocketPath + cleanup := func() { + server.Stop() + os.RemoveAll(tempDir) + } + + return gitalySocketUrl, cleanup +} diff --git a/go/internal/gitlabnet/testserver/testserver.go b/go/internal/gitlabnet/testserver/testserver.go index bf896e6..bf59ce4 100644 --- a/go/internal/gitlabnet/testserver/testserver.go +++ b/go/internal/gitlabnet/testserver/testserver.go @@ -10,6 +10,9 @@ import ( "os" "path" "path/filepath" + "testing" + + "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper" ) @@ -24,15 +27,12 @@ type TestRequestHandler struct { Handler func(w http.ResponseWriter, r *http.Request) } -func StartSocketHttpServer(handlers []TestRequestHandler) (func(), string, error) { - if err := os.MkdirAll(filepath.Dir(testSocket), 0700); err != nil { - return nil, "", err - } +func StartSocketHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { + err := os.MkdirAll(filepath.Dir(testSocket), 0700) + require.NoError(t, err) socketListener, err := net.Listen("unix", testSocket) - if err != nil { - return nil, "", err - } + require.NoError(t, err) server := http.Server{ Handler: buildHandler(handlers), @@ -44,30 +44,27 @@ func StartSocketHttpServer(handlers []TestRequestHandler) (func(), string, error url := "http+unix://" + testSocket - return cleanupSocket, url, nil + return url, cleanupSocket } -func StartHttpServer(handlers []TestRequestHandler) (func(), string, error) { +func StartHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { server := httptest.NewServer(buildHandler(handlers)) - return server.Close, server.URL, nil + return server.URL, server.Close } -func StartHttpsServer(handlers []TestRequestHandler) (func(), string, error) { +func StartHttpsServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt") key := path.Join(testhelper.TestRoot, "certs/valid/server.key") server := httptest.NewUnstartedServer(buildHandler(handlers)) cer, err := tls.LoadX509KeyPair(crt, key) - - if err != nil { - return nil, "", err - } + require.NoError(t, err) server.TLS = &tls.Config{Certificates: []tls.Certificate{cer}} server.StartTLS() - return server.Close, server.URL, nil + return server.URL, server.Close } func cleanupSocket() { diff --git a/go/internal/gitlabnet/twofactorrecover/client_test.go b/go/internal/gitlabnet/twofactorrecover/client_test.go index 56f7958..4b15ac5 100644 --- a/go/internal/gitlabnet/twofactorrecover/client_test.go +++ b/go/internal/gitlabnet/twofactorrecover/client_test.go @@ -149,8 +149,7 @@ func TestErrorResponses(t *testing.T) { func setup(t *testing.T) (*Client, func()) { initialize(t) - cleanup, url, err := testserver.StartSocketHttpServer(requests) - require.NoError(t, err) + url, cleanup := testserver.StartSocketHttpServer(t, requests) client, err := NewClient(&config.Config{GitlabUrl: url}) require.NoError(t, err) diff --git a/go/internal/handler/exec.go b/go/internal/handler/exec.go index ee7b4a8..671263c 100644 --- a/go/internal/handler/exec.go +++ b/go/internal/handler/exec.go @@ -14,11 +14,29 @@ import ( "google.golang.org/grpc" ) -// GitalyHandlerFunc implementations are responsible for deserializing +// GitalyHandlerFuncWithJSON implementations are responsible for deserializing // the request JSON into a GRPC request message, making an appropriate Gitaly // call with the request, using the provided client, and returning the exit code // or error from the Gitaly call. -type GitalyHandlerFunc func(ctx context.Context, client *grpc.ClientConn, requestJSON string) (int32, error) +type GitalyHandlerFuncWithJSON func(ctx context.Context, client *grpc.ClientConn, requestJSON string) (int32, error) + +// GitalyHandlerFunc implementations are responsible for making +// an appropriate Gitaly call using the provided client and context +// and returning an error from the Gitaly call. +type GitalyHandlerFunc func(ctx context.Context, client *grpc.ClientConn) (int32, error) + +type GitalyConn struct { + ctx context.Context + conn *grpc.ClientConn + close func() +} + +type GitalyCommand struct { + Config *config.Config + ServiceName string + Address string + Token string +} // RunGitalyCommand provides a bootstrap for Gitaly commands executed // through GitLab-Shell. It ensures that logging, tracing and other @@ -26,7 +44,7 @@ type GitalyHandlerFunc func(ctx context.Context, client *grpc.ClientConn, reques // RunGitalyCommand will handle errors internally and call // `os.Exit()` on completion. This method will never return to // the caller. -func RunGitalyCommand(handler GitalyHandlerFunc) { +func RunGitalyCommand(handler GitalyHandlerFuncWithJSON) { exitCode, err := internalRunGitalyCommand(os.Args, handler) if err != nil { @@ -36,10 +54,25 @@ func RunGitalyCommand(handler GitalyHandlerFunc) { os.Exit(exitCode) } -// internalRunGitalyCommand is like RunGitalyCommand, except that since it doesn't -// call os.Exit, we can rely on its deferred handlers executing correctly -func internalRunGitalyCommand(args []string, handler GitalyHandlerFunc) (int, error) { +// RunGitalyCommand provides a bootstrap for Gitaly commands executed +// through GitLab-Shell. It ensures that logging, tracing and other +// common concerns are configured before executing the `handler`. +func (gc *GitalyCommand) RunGitalyCommand(handler GitalyHandlerFunc) error { + gitalyConn, err := getConn(gc) + + if err != nil { + return err + } + _, err = handler(gitalyConn.ctx, gitalyConn.conn) + + gitalyConn.close() + + return err +} + +// internalRunGitalyCommand runs Gitaly's command by particular Gitaly address and token +func internalRunGitalyCommand(args []string, handler GitalyHandlerFuncWithJSON) (int, error) { if len(args) != 3 { return 1, fmt.Errorf("expected 2 arguments, got %v", args) } @@ -53,13 +86,44 @@ func internalRunGitalyCommand(args []string, handler GitalyHandlerFunc) (int, er return 1, err } + gc := &GitalyCommand{ + Config: cfg, + ServiceName: args[0], + Address: args[1], + Token: os.Getenv("GITALY_TOKEN"), + } + requestJSON := string(args[2]) + + gitalyConn, err := getConn(gc) + + if err != nil { + return 1, err + } + + exitCode, err := handler(gitalyConn.ctx, gitalyConn.conn, requestJSON) + + gitalyConn.close() + + return int(exitCode), err +} + +func getConn(gc *GitalyCommand) (*GitalyConn, error) { + if gc.Address == "" { + return nil, fmt.Errorf("no gitaly_address given") + } + + connOpts := client.DefaultDialOpts + if gc.Token != "" { + connOpts = append(client.DefaultDialOpts, grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(gc.Token))) + } + // Use a working directory that won't get removed or unmounted. if err := os.Chdir("/"); err != nil { - return 1, err + return nil, err } // Configure distributed tracing - serviceName := fmt.Sprintf("gitlab-shell-%v", args[0]) + serviceName := fmt.Sprintf("gitlab-shell-%v", gc.ServiceName) closer := tracing.Initialize( tracing.WithServiceName(serviceName), @@ -71,34 +135,21 @@ func internalRunGitalyCommand(args []string, handler GitalyHandlerFunc) (int, er // Processes are spawned as children of the SSH daemon, which tightly // controls environment variables; doing this means we don't have to // enable PermitUserEnvironment - tracing.WithConnectionString(cfg.GitlabTracing), + tracing.WithConnectionString(gc.Config.GitlabTracing), ) - defer closer.Close() ctx, finished := tracing.ExtractFromEnv(context.Background()) - defer finished() - gitalyAddress := args[1] - if gitalyAddress == "" { - return 1, fmt.Errorf("no gitaly_address given") - } - - conn, err := client.Dial(gitalyAddress, dialOpts()) + conn, err := client.Dial(gc.Address, connOpts) if err != nil { - return 1, err + return nil, err } - defer conn.Close() - requestJSON := string(args[2]) - exitCode, err := handler(ctx, conn, requestJSON) - return int(exitCode), err -} - -func dialOpts() []grpc.DialOption { - connOpts := client.DefaultDialOpts - if token := os.Getenv("GITALY_TOKEN"); token != "" { - connOpts = append(client.DefaultDialOpts, grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(token))) + finish := func() { + finished() + closer.Close() + conn.Close() } - return connOpts + return &GitalyConn{ctx: ctx, conn: conn, close: finish}, nil } diff --git a/go/internal/testhelper/requesthandlers/requesthandlers.go b/go/internal/testhelper/requesthandlers/requesthandlers.go new file mode 100644 index 0000000..d7e077b --- /dev/null +++ b/go/internal/testhelper/requesthandlers/requesthandlers.go @@ -0,0 +1,40 @@ +package requesthandlers + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +func BuildAllowedWithGitalyHandlers(t *testing.T, gitalyAddress string) []testserver.TestRequestHandler { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + body := map[string]interface{}{ + "status": true, + "gl_id": "1", + "gitaly": map[string]interface{}{ + "repository": map[string]interface{}{ + "storage_name": "storage_name", + "relative_path": "relative_path", + "git_object_directory": "path/to/git_object_directory", + "git_alternate_object_directories": []string{"path/to/git_alternate_object_directory"}, + "gl_repository": "group/repo", + "gl_project_path": "group/project-path", + }, + "address": gitalyAddress, + "token": "token", + }, + } + require.NoError(t, json.NewEncoder(w).Encode(body)) + }, + }, + } + + return requests +} diff --git a/go/internal/testhelper/testdata/testroot/responses/allowed.json b/go/internal/testhelper/testdata/testroot/responses/allowed.json new file mode 100644 index 0000000..d0403d9 --- /dev/null +++ b/go/internal/testhelper/testdata/testroot/responses/allowed.json @@ -0,0 +1,22 @@ +{ + "status": true, + "gl_repository": "project-26", + "gl_project_path": "group/private", + "gl_id": "user-1", + "gl_username": "root", + "git_config_options": ["option"], + "gitaly": { + "repository": { + "storage_name": "default", + "relative_path": "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git", + "git_object_directory": "path/to/git_object_directory", + "git_alternate_object_directories": ["path/to/git_alternate_object_directory"], + "gl_repository": "project-26", + "gl_project_path": "group/private" + }, + "address": "unix:gitaly.socket", + "token": "token" + }, + "git_protocol": "protocol", + "gl_console_messages": ["console", "message"] +} diff --git a/go/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json b/go/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json new file mode 100644 index 0000000..331c3a9 --- /dev/null +++ b/go/internal/testhelper/testdata/testroot/responses/allowed_with_payload.json @@ -0,0 +1,31 @@ +{ + "status": true, + "gl_repository": "project-26", + "gl_project_path": "group/private", + "gl_id": "user-1", + "gl_username": "root", + "git_config_options": ["option"], + "gitaly": { + "repository": { + "storage_name": "default", + "relative_path": "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git", + "git_object_directory": "path/to/git_object_directory", + "git_alternate_object_directories": ["path/to/git_alternate_object_directory"], + "gl_repository": "project-26", + "gl_project_path": "group/private" + }, + "address": "unix:gitaly.socket", + "token": "token" + }, + "payload" : { + "action": "geo_proxy_to_primary", + "data": { + "api_endpoints": ["geo/proxy_git_push_ssh/info_refs", "geo/proxy_git_push_ssh/push"], + "gl_username": "custom", + "primary_repo": "https://repo/path", + "info_message": "message" + } + }, + "git_protocol": "protocol", + "gl_console_messages": ["console", "message"] +} |