diff options
author | Stan Hu <stanhu@gmail.com> | 2020-09-19 03:34:49 -0700 |
---|---|---|
committer | Stan Hu <stanhu@gmail.com> | 2020-09-19 14:00:45 -0700 |
commit | 0590d9198f653ff2170e0f26790056bef4f056fe (patch) | |
tree | dc0d68866ea16ba4f74d441c3aa2048b12fb9e95 /internal/command | |
parent | f100e7e83943b3bb5db232f5bf79a616fdba88f1 (diff) | |
download | gitlab-shell-sh-extract-context-from-env.tar.gz |
Make it possible to propagate correlation ID across processessh-extract-context-from-env
Previously, gitlab-shell did not pass a context through the application.
Correlation IDs were generated down the call stack, since we don't pass
the context around from the start execution.
This has several potential downsides:
1. It's easier for programming mistakes to be made in future which lead
to multiple correlation IDs being generated for a single request.
2. Correlation IDs cannot be passed in from upstream requests
3. Other advantages of context passing, such as distributed tracing is
not possible.
This commit changes the behavior:
1. Extract the correlation ID from the environment at the start of
the application.
2. If no correlation ID exists, generate a random one.
3. Pass the correlation ID to the GitLabNet API requests.
This change also enables other clients of GitLabNet (e.g. Gitaly) to
pass along the correlation ID in the internal API requests
(https://gitlab.com/gitlab-org/gitaly/-/issues/2725).
Fixes https://gitlab.com/gitlab-org/gitlab-shell/-/issues/474
Diffstat (limited to 'internal/command')
29 files changed, 199 insertions, 77 deletions
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go index 7554761..736aeed 100644 --- a/internal/command/authorizedkeys/authorized_keys.go +++ b/internal/command/authorizedkeys/authorized_keys.go @@ -1,6 +1,7 @@ package authorizedkeys import ( + "context" "fmt" "strconv" @@ -17,7 +18,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { // Do and return nothing when the expected and actual user don't match. // This can happen when the user in sshd_config doesn't match the user // trying to login. When nothing is printed, the user will be denied access. @@ -27,15 +28,15 @@ func (c *Command) Execute() error { return nil } - if err := c.printKeyLine(); err != nil { + if err := c.printKeyLine(ctx); err != nil { return err } return nil } -func (c *Command) printKeyLine() error { - response, err := c.getAuthorizedKey() +func (c *Command) printKeyLine(ctx context.Context) error { + response, err := c.getAuthorizedKey(ctx) if err != nil { fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key)) return nil @@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error { return nil } -func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) { +func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) { client, err := authorizedkeys.NewClient(c.Config) if err != nil { return nil, err } - return client.GetByKey(c.Args.Key) + return client.GetByKey(ctx, c.Args.Key) } diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go index e12f4fa..f15c34d 100644 --- a/internal/command/authorizedkeys/authorized_keys_test.go +++ b/internal/command/authorizedkeys/authorized_keys_test.go @@ -2,6 +2,7 @@ package authorizedkeys import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -97,7 +98,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go index ab5f2f8..44f6c47 100644 --- a/internal/command/authorizedprincipals/authorized_principals.go +++ b/internal/command/authorizedprincipals/authorized_principals.go @@ -1,6 +1,7 @@ package authorizedprincipals import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -15,7 +16,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { if err := c.printPrincipalLines(); err != nil { return err } diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go index f11dd0f..ec97b65 100644 --- a/internal/command/authorizedprincipals/authorized_principals_test.go +++ b/internal/command/authorizedprincipals/authorized_principals_test.go @@ -2,6 +2,7 @@ package authorizedprincipals import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -54,7 +55,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/command.go b/internal/command/command.go index 283b4a1..7e0617e 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -1,6 +1,8 @@ package command import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -16,10 +18,13 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/tracing" ) type Command interface { - Execute() error + Execute(ctx context.Context) error } func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) { @@ -35,6 +40,27 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re return nil, disallowedcommand.Error } +// ContextWithCorrelationID() will always return a background Context +// with a correlation ID. It will first attempt to extract the ID from +// an environment variable. If is not available, a random one will be +// generated. +func ContextWithCorrelationID() (context.Context, func()) { + ctx, finished := tracing.ExtractFromEnv(context.Background()) + defer finished() + + correlationID := correlation.ExtractFromContext(ctx) + if correlationID == "" { + correlationID, err := correlation.RandomID() + if err != nil { + log.WithError(err).Warn("unable to generate correlation ID") + } else { + ctx = correlation.ContextWithCorrelation(ctx, correlationID) + } + } + + return ctx, finished +} + func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command { switch e.Name { case executable.GitlabShell: diff --git a/internal/command/command_test.go b/internal/command/command_test.go index db55e7d..cc2364f 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -2,6 +2,7 @@ package command import ( "errors" + "os" "testing" "github.com/stretchr/testify/require" @@ -20,6 +21,7 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" + "gitlab.com/gitlab-org/labkit/correlation" ) var ( @@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) { }) } } + +func TestContextWithCorrelationID(t *testing.T) { + testCases := []struct { + name string + additionalEnv map[string]string + expectedCorrelationID string + }{ + { + name: "no CORRELATION_ID in environment", + }, + { + name: "CORRELATION_ID in envioonment", + additionalEnv: map[string]string{ + "CORRELATION_ID": "abc123", + }, + expectedCorrelationID: "abc123", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resetEnvironment := addAdditionalEnv(tc.additionalEnv) + defer resetEnvironment() + + ctx, finished := ContextWithCorrelationID() + require.NotNil(t, ctx, "ctx is nil") + require.NotNil(t, finished, "finished is nil") + correlationID := correlation.ExtractFromContext(ctx) + require.NotEmpty(t, correlationID) + + if tc.expectedCorrelationID != "" { + require.Equal(t, tc.expectedCorrelationID, correlationID) + } + defer finished() + }) + } +} + +// addAdditionalEnv will configure additional environment values +// and return a deferrable function to reset the environment to +// it's original state after the test +func addAdditionalEnv(envMap map[string]string) func() { + prevValues := map[string]string{} + unsetValues := []string{} + for k, v := range envMap { + value, exists := os.LookupEnv(k) + if exists { + prevValues[k] = value + } else { + unsetValues = append(unsetValues, k) + } + os.Setenv(k, v) + } + + return func() { + for k, v := range prevValues { + os.Setenv(k, v) + } + + for _, k := range unsetValues { + os.Unsetenv(k) + } + + } +} diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index 3aa7456..822be32 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -1,6 +1,7 @@ package discover import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -15,8 +16,8 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { - response, err := c.getUserInfo() +func (c *Command) Execute(ctx context.Context) error { + response, err := c.getUserInfo(ctx) if err != nil { return fmt.Errorf("Failed to get username: %v", err) } @@ -30,11 +31,11 @@ func (c *Command) Execute() error { return nil } -func (c *Command) getUserInfo() (*discover.Response, error) { +func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) { client, err := discover.NewClient(c.Config) if err != nil { return nil, err } - return client.GetByCommandArgs(c.Args) + return client.GetByCommandArgs(ctx, c.Args) } diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index 8edbcb9..5431410 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -2,6 +2,7 @@ package discover import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -83,7 +84,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) @@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, tc.expectedError) diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go index bbc73dc..b04eb0d 100644 --- a/internal/command/healthcheck/healthcheck.go +++ b/internal/command/healthcheck/healthcheck.go @@ -1,6 +1,7 @@ package healthcheck import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" @@ -18,8 +19,8 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { - response, err := c.runCheck() +func (c *Command) Execute(ctx context.Context) error { + response, err := c.runCheck(ctx) if err != nil { return fmt.Errorf("%v: FAILED - %v", apiMessage, err) } @@ -34,13 +35,13 @@ func (c *Command) Execute() error { return nil } -func (c *Command) runCheck() (*healthcheck.Response, error) { +func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) { client, err := healthcheck.NewClient(c.Config) if err != nil { return nil, err } - response, err := client.Check() + response, err := client.Check(ctx) if err != nil { return nil, err } diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go index 7479bcb..d05e563 100644 --- a/internal/command/healthcheck/healthcheck_test.go +++ b/internal/command/healthcheck/healthcheck_test.go @@ -2,6 +2,7 @@ package healthcheck import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -53,7 +54,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String()) @@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Error(t, err, "Redis available via internal API: FAILED") require.Equal(t, "Internal API available: OK\n", buffer.String()) } @@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)") } diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index 2aaac2a..dab69ab 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -1,6 +1,7 @@ package lfsauthenticate import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -34,7 +35,7 @@ type Payload struct { ExpiresIn int `json:"expires_in,omitempty"` } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) < 3 { return disallowedcommand.Error @@ -49,12 +50,12 @@ func (c *Command) Execute() error { return err } - accessResponse, err := c.verifyAccess(action, repo) + accessResponse, err := c.verifyAccess(ctx, action, repo) if err != nil { return err } - payload, err := c.authenticate(operation, repo, accessResponse.UserId) + payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) if err != nil { // return nothing just like Ruby's GitlabShell#lfs_authenticate does return nil @@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) { return action, nil } -func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(action, repo) + return cmd.Verify(ctx, action, repo) } -func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) { +func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) { client, err := lfsauthenticate.NewClient(c.Config, c.Args) if err != nil { return nil, err } - response, err := client.Authenticate(operation, repo, userId) + response, err := client.Authenticate(ctx, operation, repo, userId) if err != nil { return nil, err } diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go index a1c7aec..55998ab 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate_test.go +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -2,6 +2,7 @@ package lfsauthenticate import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Error(t, err) require.Equal(t, tc.expectedOutput, err.Error()) @@ -146,7 +147,7 @@ func TestLfsAuthenticateRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go index b283890..6f3d03e 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken.go +++ b/internal/command/personalaccesstoken/personalaccesstoken.go @@ -1,6 +1,7 @@ package personalaccesstoken import ( + "context" "errors" "fmt" "strconv" @@ -31,13 +32,13 @@ type tokenArgs struct { ExpiresDate string // Calculated, a TTL is passed from command-line. } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { err := c.parseTokenArgs() if err != nil { return err } - response, err := c.getPersonalAccessToken() + response, err := c.getPersonalAccessToken(ctx) if err != nil { return err } @@ -76,11 +77,11 @@ func (c *Command) parseTokenArgs() error { return nil } -func (c *Command) getPersonalAccessToken() (*personalaccesstoken.Response, error) { +func (c *Command) getPersonalAccessToken(ctx context.Context) (*personalaccesstoken.Response, error) { client, err := personalaccesstoken.NewClient(c.Config) if err != nil { return nil, err } - return client.GetPersonalAccessToken(c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate) + return client.GetPersonalAccessToken(ctx, c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate) } diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go index bc748ab..5970142 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken_test.go +++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go @@ -2,6 +2,7 @@ package personalaccesstoken import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -170,7 +171,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) if tc.expectedError == "" { assert.NoError(t, err) diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go index 8bee484..2a0c146 100644 --- a/internal/command/receivepack/gitalycall_test.go +++ b/internal/command/receivepack/gitalycall_test.go @@ -2,6 +2,7 @@ package receivepack import ( "bytes" + "context" "testing" "github.com/sirupsen/logrus" @@ -42,7 +43,7 @@ func TestReceivePack(t *testing.T) { hook := testhelper.SetupLogger() - err = cmd.Execute() + err = cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String()) diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index 7271264..4d5c686 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -1,6 +1,8 @@ package receivepack import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -15,14 +17,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -33,14 +35,14 @@ func (c *Command) Execute() error { ReadWriter: c.ReadWriter, EOFSent: true, } - return customAction.Execute(response) + return customAction.Execute(ctx, response) } return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index a4632b4..44cb680 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -2,6 +2,7 @@ package receivepack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -18,7 +19,7 @@ func TestForbiddenAccess(t *testing.T) { cmd, _, cleanup := setup(t, "disallowed", requests) defer cleanup() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } @@ -26,7 +27,7 @@ func TestCustomReceivePack(t *testing.T) { cmd, output, cleanup := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t)) defer cleanup() - require.NoError(t, cmd.Execute()) + require.NoError(t, cmd.Execute(context.Background())) require.Equal(t, "customoutput", output.String()) } diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go index 5d2d709..9fcdde4 100644 --- a/internal/command/shared/accessverifier/accessverifier.go +++ b/internal/command/shared/accessverifier/accessverifier.go @@ -1,6 +1,7 @@ package accessverifier import ( + "context" "errors" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -18,13 +19,13 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) { +func (c *Command) Verify(ctx context.Context, action commandargs.CommandType, repo string) (*Response, error) { client, err := accessverifier.NewClient(c.Config) if err != nil { return nil, err } - response, err := client.Verify(c.Args, action, repo) + response, err := client.Verify(ctx, c.Args, action, repo) if err != nil { return nil, err } diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go index 998e622..8ad87b8 100644 --- a/internal/command/shared/accessverifier/accessverifier_test.go +++ b/internal/command/shared/accessverifier/accessverifier_test.go @@ -2,6 +2,7 @@ package accessverifier import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -65,7 +66,7 @@ func TestMissingUser(t *testing.T) { defer cleanup() cmd.Args = &commandargs.Shell{GitlabKeyId: "2"} - _, err := cmd.Verify(action, repo) + _, err := cmd.Verify(context.Background(), action, repo) require.Equal(t, "missing user", err.Error()) } @@ -75,7 +76,7 @@ func TestConsoleMessages(t *testing.T) { defer cleanup() cmd.Args = &commandargs.Shell{GitlabKeyId: "1"} - cmd.Verify(action, repo) + cmd.Verify(context.Background(), action, repo) require.Equal(t, "remote: \nremote: console\nremote: message\nremote: \n", errBuf.String()) require.Empty(t, outBuf.String()) diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go index 2ba1091..0675d36 100644 --- a/internal/command/shared/customaction/customaction.go +++ b/internal/command/shared/customaction/customaction.go @@ -2,6 +2,7 @@ package customaction import ( "bytes" + "context" "errors" "gitlab.com/gitlab-org/gitlab-shell/client" @@ -34,7 +35,7 @@ type Command struct { EOFSent bool } -func (c *Command) Execute(response *accessverifier.Response) error { +func (c *Command) Execute(ctx context.Context, response *accessverifier.Response) error { data := response.Payload.Data apiEndpoints := data.ApiEndpoints @@ -42,10 +43,10 @@ func (c *Command) Execute(response *accessverifier.Response) error { return errors.New("Custom action error: Empty API endpoints") } - return c.processApiEndpoints(response) + return c.processApiEndpoints(ctx, response) } -func (c *Command) processApiEndpoints(response *accessverifier.Response) error { +func (c *Command) processApiEndpoints(ctx context.Context, response *accessverifier.Response) error { client, err := gitlabnet.GetClient(c.Config) if err != nil { @@ -64,7 +65,7 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error { log.WithFields(fields).Info("Performing custom action") - response, err := c.performRequest(client, endpoint, request) + response, err := c.performRequest(ctx, client, endpoint, request) if err != nil { return err } @@ -95,8 +96,8 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error { return nil } -func (c *Command) performRequest(client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) { - response, err := client.DoRequest(http.MethodPost, endpoint, request) +func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) { + response, err := client.DoRequest(ctx, http.MethodPost, endpoint, request) if err != nil { return nil, err } diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go index 46c5f32..119da5b 100644 --- a/internal/command/shared/customaction/customaction_test.go +++ b/internal/command/shared/customaction/customaction_test.go @@ -2,6 +2,7 @@ package customaction import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -78,7 +79,7 @@ func TestExecuteEOFSent(t *testing.T) { EOFSent: true, } - require.NoError(t, cmd.Execute(response)) + require.NoError(t, cmd.Execute(context.Background(), response)) // expect printing of info message, "custom" string from the first request // and "output" string from the second request @@ -148,7 +149,7 @@ func TestExecuteNoEOFSent(t *testing.T) { EOFSent: false, } - require.NoError(t, cmd.Execute(response)) + require.NoError(t, cmd.Execute(context.Background(), response)) // expect printing of info message, "custom" string from the first request // and "output" string from the second request diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go index 2f13cc5..f0a9e7b 100644 --- a/internal/command/twofactorrecover/twofactorrecover.go +++ b/internal/command/twofactorrecover/twofactorrecover.go @@ -1,6 +1,7 @@ package twofactorrecover import ( + "context" "fmt" "strings" @@ -16,9 +17,9 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { if c.canContinue() { - c.displayRecoveryCodes() + c.displayRecoveryCodes(ctx) } else { fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") } @@ -38,8 +39,8 @@ func (c *Command) canContinue() bool { return answer == "yes" } -func (c *Command) displayRecoveryCodes() { - codes, err := c.getRecoveryCodes() +func (c *Command) displayRecoveryCodes(ctx context.Context) { + codes, err := c.getRecoveryCodes(ctx) if err == nil { messageWithCodes := @@ -54,12 +55,12 @@ func (c *Command) displayRecoveryCodes() { } } -func (c *Command) getRecoveryCodes() ([]string, error) { +func (c *Command) getRecoveryCodes(ctx context.Context) ([]string, error) { client, err := twofactorrecover.NewClient(c.Config) if err != nil { return nil, err } - return client.GetRecoveryCodes(c.Args) + return client.GetRecoveryCodes(ctx, c.Args) } diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go index d2f931b..ea6abd6 100644 --- a/internal/command/twofactorrecover/twofactorrecover_test.go +++ b/internal/command/twofactorrecover/twofactorrecover_test.go @@ -2,6 +2,7 @@ package twofactorrecover import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -127,7 +128,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go index eaeb2b7..f74093a 100644 --- a/internal/command/uploadarchive/gitalycall_test.go +++ b/internal/command/uploadarchive/gitalycall_test.go @@ -2,6 +2,7 @@ package uploadarchive import ( "bytes" + "context" "testing" "github.com/sirupsen/logrus" @@ -38,7 +39,7 @@ func TestUploadPack(t *testing.T) { hook := testhelper.SetupLogger() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "UploadArchive: "+repo, output.String()) diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index 9d4fbe0..178b42b 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -1,6 +1,8 @@ package uploadarchive import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -14,14 +16,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -29,8 +31,8 @@ func (c *Command) Execute() error { return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index 7b03009..5426569 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -2,6 +2,7 @@ package uploadarchive import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go index d6762a2..22189b8 100644 --- a/internal/command/uploadpack/gitalycall_test.go +++ b/internal/command/uploadpack/gitalycall_test.go @@ -2,6 +2,7 @@ package uploadpack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/assert" @@ -37,7 +38,7 @@ func TestUploadPack(t *testing.T) { hook := testhelper.SetupLogger() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "UploadPack: "+repo, output.String()) diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 56814d7..fca3823 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -1,6 +1,8 @@ package uploadpack import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -15,14 +17,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -33,14 +35,14 @@ func (c *Command) Execute() error { ReadWriter: c.ReadWriter, EOFSent: false, } - return customAction.Execute(response) + return customAction.Execute(ctx, response) } return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index 7ea8e5d..20edb57 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -2,6 +2,7 @@ package uploadpack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } |