diff options
53 files changed, 303 insertions, 157 deletions
diff --git a/client/client_test.go b/client/client_test.go index e92093a..e0650b2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -78,7 +79,7 @@ func TestClients(t *testing.T) { func testSuccessfulGet(t *testing.T, client *GitlabNetClient) { t.Run("Successful get", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Get("/hello") + response, err := client.Get(context.Background(), "/hello") require.NoError(t, err) require.NotNil(t, response) @@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) { hook := testhelper.SetupLogger() data := map[string]string{"key": "value"} - response, err := client.Post("/post_endpoint", data) + response, err := client.Post(context.Background(), "/post_endpoint", data) require.NoError(t, err) require.NotNil(t, response) @@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) { func testMissing(t *testing.T, client *GitlabNetClient) { t.Run("Missing error for GET", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Get("/missing") + response, err := client.Get(context.Background(), "/missing") assert.EqualError(t, err, "Internal API error (404)") assert.Nil(t, response) @@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) { t.Run("Missing error for POST", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Post("/missing", map[string]string{}) + response, err := client.Post(context.Background(), "/missing", map[string]string{}) assert.EqualError(t, err, "Internal API error (404)") assert.Nil(t, response) @@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) { func testErrorMessage(t *testing.T, client *GitlabNetClient) { t.Run("Error with message for GET", func(t *testing.T) { - response, err := client.Get("/error") + response, err := client.Get(context.Background(), "/error") assert.EqualError(t, err, "Don't do that") assert.Nil(t, response) }) t.Run("Error with message for POST", func(t *testing.T) { - response, err := client.Post("/error", map[string]string{}) + response, err := client.Post(context.Background(), "/error", map[string]string{}) assert.EqualError(t, err, "Don't do that") assert.Nil(t, response) }) @@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) { t.Run("Broken request for GET", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Get("/broken") + response, err := client.Get(context.Background(), "/broken") assert.EqualError(t, err, "Internal API unreachable") assert.Nil(t, response) @@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) { t.Run("Broken request for POST", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Post("/broken", map[string]string{}) + response, err := client.Post(context.Background(), "/broken", map[string]string{}) assert.EqualError(t, err, "Internal API unreachable") assert.Nil(t, response) @@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) { func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) { t.Run("Authentication headers for GET", func(t *testing.T) { - response, err := client.Get("/auth") + response, err := client.Get(context.Background(), "/auth") require.NoError(t, err) require.NotNil(t, response) @@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) { }) t.Run("Authentication headers for POST", func(t *testing.T) { - response, err := client.Post("/auth", map[string]string{}) + response, err := client.Post(context.Background(), "/auth", map[string]string{}) require.NoError(t, err) require.NotNil(t, response) diff --git a/client/gitlabnet.go b/client/gitlabnet.go index 0657ca0..b908d04 100644 --- a/client/gitlabnet.go +++ b/client/gitlabnet.go @@ -11,8 +11,9 @@ import ( "strings" "time" - log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/correlation" + + log "github.com/sirupsen/logrus" ) const ( @@ -59,7 +60,7 @@ func normalizePath(path string) string { return path } -func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) { +func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) { var jsonReader io.Reader if data != nil { jsonData, err := json.Marshal(data) @@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str jsonReader = bytes.NewReader(jsonData) } - correlationID, err := correlation.RandomID() - ctx := context.Background() - - if err != nil { - log.WithError(err).Warn("unable to generate correlation ID") - } else { - ctx = correlation.ContextWithCorrelation(ctx, correlationID) - } - request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader) if err != nil { return nil, "", err } + correlationID := correlation.ExtractFromContext(ctx) + return request, correlationID, nil } @@ -102,16 +96,16 @@ func parseError(resp *http.Response) error { } -func (c *GitlabNetClient) Get(path string) (*http.Response, error) { - return c.DoRequest(http.MethodGet, normalizePath(path), nil) +func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) { + return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil) } -func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) { - return c.DoRequest(http.MethodPost, normalizePath(path), data) +func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) { + return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data) } -func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) { - request, correlationID, err := newRequest(method, c.httpClient.Host, path, data) +func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) { + request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data) if err != nil { return nil, err } diff --git a/client/httpclient_test.go b/client/httpclient_test.go index fce0cd5..97e1384 100644 --- a/client/httpclient_test.go +++ b/client/httpclient_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/base64" "fmt" "io/ioutil" @@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) { client, cleanup := setup(t, username, password, requests) defer cleanup() - response, err := client.Get("/get_endpoint") + response, err := client.Get(context.Background(), "/get_endpoint") require.NoError(t, err) testBasicAuthHeaders(t, response) - response, err = client.Post("/post_endpoint", nil) + response, err = client.Post(context.Background(), "/post_endpoint", nil) require.NoError(t, err) testBasicAuthHeaders(t, response) } @@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) { client, cleanup := setup(t, "", "", requests) defer cleanup() - _, err := client.Get("/empty_basic_auth") + _, err := client.Get(context.Background(), "/empty_basic_auth") require.NoError(t, err) } diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go index 1c7435f..0cf77e3 100644 --- a/client/httpsclient_test.go +++ b/client/httpsclient_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "io/ioutil" "net/http" @@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) { client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned) defer cleanup() - response, err := client.Get("/hello") + response, err := client.Get(context.Background(), "/hello") require.NoError(t, err) require.NotNil(t, response) @@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) { client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false) defer cleanup() - _, err := client.Get("/hello") + _, err := client.Get(context.Background(), "/hello") require.Error(t, err) assert.Equal(t, err.Error(), "Internal API unreachable") diff --git a/cmd/check/main.go b/cmd/check/main.go index e88b9fe..28634f4 100644 --- a/cmd/check/main.go +++ b/cmd/check/main.go @@ -38,7 +38,10 @@ func main() { os.Exit(1) } - if err = cmd.Execute(); err != nil { + ctx, finished := command.ContextWithCorrelationID() + defer finished() + + if err = cmd.Execute(ctx); err != nil { fmt.Fprintf(readWriter.ErrOut, "%v\n", err) os.Exit(1) } diff --git a/cmd/gitlab-shell-authorized-keys-check/main.go b/cmd/gitlab-shell-authorized-keys-check/main.go index 4b3949c..3a7dcbb 100644 --- a/cmd/gitlab-shell-authorized-keys-check/main.go +++ b/cmd/gitlab-shell-authorized-keys-check/main.go @@ -41,7 +41,10 @@ func main() { os.Exit(1) } - if err = cmd.Execute(); err != nil { + ctx, finished := command.ContextWithCorrelationID() + defer finished() + + if err = cmd.Execute(ctx); err != nil { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } diff --git a/cmd/gitlab-shell-authorized-principals-check/main.go b/cmd/gitlab-shell-authorized-principals-check/main.go index fc46180..ea8d140 100644 --- a/cmd/gitlab-shell-authorized-principals-check/main.go +++ b/cmd/gitlab-shell-authorized-principals-check/main.go @@ -41,7 +41,10 @@ func main() { os.Exit(1) } - if err = cmd.Execute(); err != nil { + ctx, finished := command.ContextWithCorrelationID() + defer finished() + + if err = cmd.Execute(ctx); err != nil { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } diff --git a/cmd/gitlab-shell/main.go b/cmd/gitlab-shell/main.go index 8df781c..763aa5e 100644 --- a/cmd/gitlab-shell/main.go +++ b/cmd/gitlab-shell/main.go @@ -41,7 +41,10 @@ func main() { os.Exit(1) } - if err = cmd.Execute(); err != nil { + ctx, finished := command.ContextWithCorrelationID() + defer finished() + + if err = cmd.Execute(ctx); err != nil { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } @@ -248,6 +248,7 @@ github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb6 github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go index 7554761..736aeed 100644 --- a/internal/command/authorizedkeys/authorized_keys.go +++ b/internal/command/authorizedkeys/authorized_keys.go @@ -1,6 +1,7 @@ package authorizedkeys import ( + "context" "fmt" "strconv" @@ -17,7 +18,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { // Do and return nothing when the expected and actual user don't match. // This can happen when the user in sshd_config doesn't match the user // trying to login. When nothing is printed, the user will be denied access. @@ -27,15 +28,15 @@ func (c *Command) Execute() error { return nil } - if err := c.printKeyLine(); err != nil { + if err := c.printKeyLine(ctx); err != nil { return err } return nil } -func (c *Command) printKeyLine() error { - response, err := c.getAuthorizedKey() +func (c *Command) printKeyLine(ctx context.Context) error { + response, err := c.getAuthorizedKey(ctx) if err != nil { fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key)) return nil @@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error { return nil } -func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) { +func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) { client, err := authorizedkeys.NewClient(c.Config) if err != nil { return nil, err } - return client.GetByKey(c.Args.Key) + return client.GetByKey(ctx, c.Args.Key) } diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go index e12f4fa..f15c34d 100644 --- a/internal/command/authorizedkeys/authorized_keys_test.go +++ b/internal/command/authorizedkeys/authorized_keys_test.go @@ -2,6 +2,7 @@ package authorizedkeys import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -97,7 +98,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go index ab5f2f8..44f6c47 100644 --- a/internal/command/authorizedprincipals/authorized_principals.go +++ b/internal/command/authorizedprincipals/authorized_principals.go @@ -1,6 +1,7 @@ package authorizedprincipals import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -15,7 +16,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { if err := c.printPrincipalLines(); err != nil { return err } diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go index f11dd0f..ec97b65 100644 --- a/internal/command/authorizedprincipals/authorized_principals_test.go +++ b/internal/command/authorizedprincipals/authorized_principals_test.go @@ -2,6 +2,7 @@ package authorizedprincipals import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -54,7 +55,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/command.go b/internal/command/command.go index 283b4a1..7e0617e 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -1,6 +1,8 @@ package command import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys" "gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -16,10 +18,13 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/tracing" ) type Command interface { - Execute() error + Execute(ctx context.Context) error } func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) { @@ -35,6 +40,27 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re return nil, disallowedcommand.Error } +// ContextWithCorrelationID() will always return a background Context +// with a correlation ID. It will first attempt to extract the ID from +// an environment variable. If is not available, a random one will be +// generated. +func ContextWithCorrelationID() (context.Context, func()) { + ctx, finished := tracing.ExtractFromEnv(context.Background()) + defer finished() + + correlationID := correlation.ExtractFromContext(ctx) + if correlationID == "" { + correlationID, err := correlation.RandomID() + if err != nil { + log.WithError(err).Warn("unable to generate correlation ID") + } else { + ctx = correlation.ContextWithCorrelation(ctx, correlationID) + } + } + + return ctx, finished +} + func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command { switch e.Name { case executable.GitlabShell: diff --git a/internal/command/command_test.go b/internal/command/command_test.go index db55e7d..cc2364f 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -2,6 +2,7 @@ package command import ( "errors" + "os" "testing" "github.com/stretchr/testify/require" @@ -20,6 +21,7 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" + "gitlab.com/gitlab-org/labkit/correlation" ) var ( @@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) { }) } } + +func TestContextWithCorrelationID(t *testing.T) { + testCases := []struct { + name string + additionalEnv map[string]string + expectedCorrelationID string + }{ + { + name: "no CORRELATION_ID in environment", + }, + { + name: "CORRELATION_ID in envioonment", + additionalEnv: map[string]string{ + "CORRELATION_ID": "abc123", + }, + expectedCorrelationID: "abc123", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resetEnvironment := addAdditionalEnv(tc.additionalEnv) + defer resetEnvironment() + + ctx, finished := ContextWithCorrelationID() + require.NotNil(t, ctx, "ctx is nil") + require.NotNil(t, finished, "finished is nil") + correlationID := correlation.ExtractFromContext(ctx) + require.NotEmpty(t, correlationID) + + if tc.expectedCorrelationID != "" { + require.Equal(t, tc.expectedCorrelationID, correlationID) + } + defer finished() + }) + } +} + +// addAdditionalEnv will configure additional environment values +// and return a deferrable function to reset the environment to +// it's original state after the test +func addAdditionalEnv(envMap map[string]string) func() { + prevValues := map[string]string{} + unsetValues := []string{} + for k, v := range envMap { + value, exists := os.LookupEnv(k) + if exists { + prevValues[k] = value + } else { + unsetValues = append(unsetValues, k) + } + os.Setenv(k, v) + } + + return func() { + for k, v := range prevValues { + os.Setenv(k, v) + } + + for _, k := range unsetValues { + os.Unsetenv(k) + } + + } +} diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index 3aa7456..822be32 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -1,6 +1,7 @@ package discover import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -15,8 +16,8 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { - response, err := c.getUserInfo() +func (c *Command) Execute(ctx context.Context) error { + response, err := c.getUserInfo(ctx) if err != nil { return fmt.Errorf("Failed to get username: %v", err) } @@ -30,11 +31,11 @@ func (c *Command) Execute() error { return nil } -func (c *Command) getUserInfo() (*discover.Response, error) { +func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) { client, err := discover.NewClient(c.Config) if err != nil { return nil, err } - return client.GetByCommandArgs(c.Args) + return client.GetByCommandArgs(ctx, c.Args) } diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index 8edbcb9..5431410 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -2,6 +2,7 @@ package discover import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -83,7 +84,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) @@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, tc.expectedError) diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go index bbc73dc..b04eb0d 100644 --- a/internal/command/healthcheck/healthcheck.go +++ b/internal/command/healthcheck/healthcheck.go @@ -1,6 +1,7 @@ package healthcheck import ( + "context" "fmt" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" @@ -18,8 +19,8 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { - response, err := c.runCheck() +func (c *Command) Execute(ctx context.Context) error { + response, err := c.runCheck(ctx) if err != nil { return fmt.Errorf("%v: FAILED - %v", apiMessage, err) } @@ -34,13 +35,13 @@ func (c *Command) Execute() error { return nil } -func (c *Command) runCheck() (*healthcheck.Response, error) { +func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) { client, err := healthcheck.NewClient(c.Config) if err != nil { return nil, err } - response, err := client.Check() + response, err := client.Check(ctx) if err != nil { return nil, err } diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go index 7479bcb..d05e563 100644 --- a/internal/command/healthcheck/healthcheck_test.go +++ b/internal/command/healthcheck/healthcheck_test.go @@ -2,6 +2,7 @@ package healthcheck import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -53,7 +54,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String()) @@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Error(t, err, "Redis available via internal API: FAILED") require.Equal(t, "Internal API available: OK\n", buffer.String()) } @@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)") } diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index 2aaac2a..dab69ab 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -1,6 +1,7 @@ package lfsauthenticate import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -34,7 +35,7 @@ type Payload struct { ExpiresIn int `json:"expires_in,omitempty"` } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) < 3 { return disallowedcommand.Error @@ -49,12 +50,12 @@ func (c *Command) Execute() error { return err } - accessResponse, err := c.verifyAccess(action, repo) + accessResponse, err := c.verifyAccess(ctx, action, repo) if err != nil { return err } - payload, err := c.authenticate(operation, repo, accessResponse.UserId) + payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) if err != nil { // return nothing just like Ruby's GitlabShell#lfs_authenticate does return nil @@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) { return action, nil } -func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(action, repo) + return cmd.Verify(ctx, action, repo) } -func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) { +func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) { client, err := lfsauthenticate.NewClient(c.Config, c.Args) if err != nil { return nil, err } - response, err := client.Authenticate(operation, repo, userId) + response, err := client.Authenticate(ctx, operation, repo, userId) if err != nil { return nil, err } diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go index a1c7aec..55998ab 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate_test.go +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -2,6 +2,7 @@ package lfsauthenticate import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Error(t, err) require.Equal(t, tc.expectedOutput, err.Error()) @@ -146,7 +147,7 @@ func TestLfsAuthenticateRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go index b283890..6f3d03e 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken.go +++ b/internal/command/personalaccesstoken/personalaccesstoken.go @@ -1,6 +1,7 @@ package personalaccesstoken import ( + "context" "errors" "fmt" "strconv" @@ -31,13 +32,13 @@ type tokenArgs struct { ExpiresDate string // Calculated, a TTL is passed from command-line. } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { err := c.parseTokenArgs() if err != nil { return err } - response, err := c.getPersonalAccessToken() + response, err := c.getPersonalAccessToken(ctx) if err != nil { return err } @@ -76,11 +77,11 @@ func (c *Command) parseTokenArgs() error { return nil } -func (c *Command) getPersonalAccessToken() (*personalaccesstoken.Response, error) { +func (c *Command) getPersonalAccessToken(ctx context.Context) (*personalaccesstoken.Response, error) { client, err := personalaccesstoken.NewClient(c.Config) if err != nil { return nil, err } - return client.GetPersonalAccessToken(c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate) + return client.GetPersonalAccessToken(ctx, c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate) } diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go index bc748ab..5970142 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken_test.go +++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go @@ -2,6 +2,7 @@ package personalaccesstoken import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -170,7 +171,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) if tc.expectedError == "" { assert.NoError(t, err) diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go index 8bee484..2a0c146 100644 --- a/internal/command/receivepack/gitalycall_test.go +++ b/internal/command/receivepack/gitalycall_test.go @@ -2,6 +2,7 @@ package receivepack import ( "bytes" + "context" "testing" "github.com/sirupsen/logrus" @@ -42,7 +43,7 @@ func TestReceivePack(t *testing.T) { hook := testhelper.SetupLogger() - err = cmd.Execute() + err = cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String()) diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index 7271264..4d5c686 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -1,6 +1,8 @@ package receivepack import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -15,14 +17,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -33,14 +35,14 @@ func (c *Command) Execute() error { ReadWriter: c.ReadWriter, EOFSent: true, } - return customAction.Execute(response) + return customAction.Execute(ctx, response) } return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index a4632b4..44cb680 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -2,6 +2,7 @@ package receivepack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -18,7 +19,7 @@ func TestForbiddenAccess(t *testing.T) { cmd, _, cleanup := setup(t, "disallowed", requests) defer cleanup() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } @@ -26,7 +27,7 @@ func TestCustomReceivePack(t *testing.T) { cmd, output, cleanup := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t)) defer cleanup() - require.NoError(t, cmd.Execute()) + require.NoError(t, cmd.Execute(context.Background())) require.Equal(t, "customoutput", output.String()) } diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go index 5d2d709..9fcdde4 100644 --- a/internal/command/shared/accessverifier/accessverifier.go +++ b/internal/command/shared/accessverifier/accessverifier.go @@ -1,6 +1,7 @@ package accessverifier import ( + "context" "errors" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" @@ -18,13 +19,13 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) { +func (c *Command) Verify(ctx context.Context, action commandargs.CommandType, repo string) (*Response, error) { client, err := accessverifier.NewClient(c.Config) if err != nil { return nil, err } - response, err := client.Verify(c.Args, action, repo) + response, err := client.Verify(ctx, c.Args, action, repo) if err != nil { return nil, err } diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go index 998e622..8ad87b8 100644 --- a/internal/command/shared/accessverifier/accessverifier_test.go +++ b/internal/command/shared/accessverifier/accessverifier_test.go @@ -2,6 +2,7 @@ package accessverifier import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -65,7 +66,7 @@ func TestMissingUser(t *testing.T) { defer cleanup() cmd.Args = &commandargs.Shell{GitlabKeyId: "2"} - _, err := cmd.Verify(action, repo) + _, err := cmd.Verify(context.Background(), action, repo) require.Equal(t, "missing user", err.Error()) } @@ -75,7 +76,7 @@ func TestConsoleMessages(t *testing.T) { defer cleanup() cmd.Args = &commandargs.Shell{GitlabKeyId: "1"} - cmd.Verify(action, repo) + cmd.Verify(context.Background(), action, repo) require.Equal(t, "remote: \nremote: console\nremote: message\nremote: \n", errBuf.String()) require.Empty(t, outBuf.String()) diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go index 2ba1091..0675d36 100644 --- a/internal/command/shared/customaction/customaction.go +++ b/internal/command/shared/customaction/customaction.go @@ -2,6 +2,7 @@ package customaction import ( "bytes" + "context" "errors" "gitlab.com/gitlab-org/gitlab-shell/client" @@ -34,7 +35,7 @@ type Command struct { EOFSent bool } -func (c *Command) Execute(response *accessverifier.Response) error { +func (c *Command) Execute(ctx context.Context, response *accessverifier.Response) error { data := response.Payload.Data apiEndpoints := data.ApiEndpoints @@ -42,10 +43,10 @@ func (c *Command) Execute(response *accessverifier.Response) error { return errors.New("Custom action error: Empty API endpoints") } - return c.processApiEndpoints(response) + return c.processApiEndpoints(ctx, response) } -func (c *Command) processApiEndpoints(response *accessverifier.Response) error { +func (c *Command) processApiEndpoints(ctx context.Context, response *accessverifier.Response) error { client, err := gitlabnet.GetClient(c.Config) if err != nil { @@ -64,7 +65,7 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error { log.WithFields(fields).Info("Performing custom action") - response, err := c.performRequest(client, endpoint, request) + response, err := c.performRequest(ctx, client, endpoint, request) if err != nil { return err } @@ -95,8 +96,8 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error { return nil } -func (c *Command) performRequest(client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) { - response, err := client.DoRequest(http.MethodPost, endpoint, request) +func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) { + response, err := client.DoRequest(ctx, http.MethodPost, endpoint, request) if err != nil { return nil, err } diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go index 46c5f32..119da5b 100644 --- a/internal/command/shared/customaction/customaction_test.go +++ b/internal/command/shared/customaction/customaction_test.go @@ -2,6 +2,7 @@ package customaction import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -78,7 +79,7 @@ func TestExecuteEOFSent(t *testing.T) { EOFSent: true, } - require.NoError(t, cmd.Execute(response)) + require.NoError(t, cmd.Execute(context.Background(), response)) // expect printing of info message, "custom" string from the first request // and "output" string from the second request @@ -148,7 +149,7 @@ func TestExecuteNoEOFSent(t *testing.T) { EOFSent: false, } - require.NoError(t, cmd.Execute(response)) + require.NoError(t, cmd.Execute(context.Background(), response)) // expect printing of info message, "custom" string from the first request // and "output" string from the second request diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go index 2f13cc5..f0a9e7b 100644 --- a/internal/command/twofactorrecover/twofactorrecover.go +++ b/internal/command/twofactorrecover/twofactorrecover.go @@ -1,6 +1,7 @@ package twofactorrecover import ( + "context" "fmt" "strings" @@ -16,9 +17,9 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { if c.canContinue() { - c.displayRecoveryCodes() + c.displayRecoveryCodes(ctx) } else { fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") } @@ -38,8 +39,8 @@ func (c *Command) canContinue() bool { return answer == "yes" } -func (c *Command) displayRecoveryCodes() { - codes, err := c.getRecoveryCodes() +func (c *Command) displayRecoveryCodes(ctx context.Context) { + codes, err := c.getRecoveryCodes(ctx) if err == nil { messageWithCodes := @@ -54,12 +55,12 @@ func (c *Command) displayRecoveryCodes() { } } -func (c *Command) getRecoveryCodes() ([]string, error) { +func (c *Command) getRecoveryCodes(ctx context.Context) ([]string, error) { client, err := twofactorrecover.NewClient(c.Config) if err != nil { return nil, err } - return client.GetRecoveryCodes(c.Args) + return client.GetRecoveryCodes(ctx, c.Args) } diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go index d2f931b..ea6abd6 100644 --- a/internal/command/twofactorrecover/twofactorrecover_test.go +++ b/internal/command/twofactorrecover/twofactorrecover_test.go @@ -2,6 +2,7 @@ package twofactorrecover import ( "bytes" + "context" "encoding/json" "io/ioutil" "net/http" @@ -127,7 +128,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go index eaeb2b7..f74093a 100644 --- a/internal/command/uploadarchive/gitalycall_test.go +++ b/internal/command/uploadarchive/gitalycall_test.go @@ -2,6 +2,7 @@ package uploadarchive import ( "bytes" + "context" "testing" "github.com/sirupsen/logrus" @@ -38,7 +39,7 @@ func TestUploadPack(t *testing.T) { hook := testhelper.SetupLogger() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "UploadArchive: "+repo, output.String()) diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index 9d4fbe0..178b42b 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -1,6 +1,8 @@ package uploadarchive import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -14,14 +16,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -29,8 +31,8 @@ func (c *Command) Execute() error { return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index 7b03009..5426569 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -2,6 +2,7 @@ package uploadarchive import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go index d6762a2..22189b8 100644 --- a/internal/command/uploadpack/gitalycall_test.go +++ b/internal/command/uploadpack/gitalycall_test.go @@ -2,6 +2,7 @@ package uploadpack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/assert" @@ -37,7 +38,7 @@ func TestUploadPack(t *testing.T) { hook := testhelper.SetupLogger() - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "UploadPack: "+repo, output.String()) diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 56814d7..fca3823 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -1,6 +1,8 @@ package uploadpack import ( + "context" + "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier" @@ -15,14 +17,14 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute() error { +func (c *Command) Execute(ctx context.Context) error { args := c.Args.SshArgs if len(args) != 2 { return disallowedcommand.Error } repo := args[1] - response, err := c.verifyAccess(repo) + response, err := c.verifyAccess(ctx, repo) if err != nil { return err } @@ -33,14 +35,14 @@ func (c *Command) Execute() error { ReadWriter: c.ReadWriter, EOFSent: false, } - return customAction.Execute(response) + return customAction.Execute(ctx, response) } return c.performGitalyCall(response) } -func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) { +func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter} - return cmd.Verify(c.Args.CommandType, repo) + return cmd.Verify(ctx, c.Args.CommandType, repo) } diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index 7ea8e5d..20edb57 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -2,6 +2,7 @@ package uploadpack import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute() + err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } diff --git a/internal/config/config.go b/internal/config/config.go index e7abd59..79c2a36 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -37,7 +37,7 @@ type Config struct { Secret string `yaml:"secret"` SslCertDir string `yaml:"ssl_cert_dir"` HttpSettings HttpSettingsConfig `yaml:"http_settings"` - HttpClient *client.HttpClient + HttpClient *client.HttpClient `-` } func (c *Config) GetHttpClient() *client.HttpClient { diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go index 00b9d76..7e120e0 100644 --- a/internal/gitlabnet/accessverifier/client.go +++ b/internal/gitlabnet/accessverifier/client.go @@ -1,6 +1,7 @@ package accessverifier import ( + "context" "fmt" "net/http" @@ -77,7 +78,7 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{client: client}, nil } -func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) { +func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) { request := &Request{Action: action, Repo: repo, Protocol: protocol, Changes: anyChanges} if args.GitlabUsername != "" { @@ -88,7 +89,7 @@ func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType, request.CheckIp = sshenv.LocalAddr() - response, err := c.client.Post("/allowed", request) + response, err := c.client.Post(ctx, "/allowed", request) if err != nil { return nil, err } diff --git a/internal/gitlabnet/accessverifier/client_test.go b/internal/gitlabnet/accessverifier/client_test.go index 7ddbb5e..3681968 100644 --- a/internal/gitlabnet/accessverifier/client_test.go +++ b/internal/gitlabnet/accessverifier/client_test.go @@ -1,6 +1,7 @@ package accessverifier import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -73,7 +74,7 @@ func TestSuccessfulResponses(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - result, err := client.Verify(tc.args, receivePackAction, repo) + result, err := client.Verify(context.Background(), tc.args, receivePackAction, repo) require.NoError(t, err) response := buildExpectedResponse(tc.who) @@ -87,7 +88,7 @@ func TestGeoPushGetCustomAction(t *testing.T) { defer cleanup() args := &commandargs.Shell{GitlabUsername: "custom"} - result, err := client.Verify(args, receivePackAction, repo) + result, err := client.Verify(context.Background(), args, receivePackAction, repo) require.NoError(t, err) response := buildExpectedResponse("user-1") @@ -110,7 +111,7 @@ func TestGeoPullGetCustomAction(t *testing.T) { defer cleanup() args := &commandargs.Shell{GitlabUsername: "custom"} - result, err := client.Verify(args, uploadPackAction, repo) + result, err := client.Verify(context.Background(), args, uploadPackAction, repo) require.NoError(t, err) response := buildExpectedResponse("user-1") @@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { args := &commandargs.Shell{GitlabKeyId: tc.fakeId} - resp, err := client.Verify(args, receivePackAction, repo) + resp, err := client.Verify(context.Background(), args, receivePackAction, repo) require.EqualError(t, err, tc.expectedError) require.Nil(t, resp) diff --git a/internal/gitlabnet/authorizedkeys/client.go b/internal/gitlabnet/authorizedkeys/client.go index e4fec28..0a00034 100644 --- a/internal/gitlabnet/authorizedkeys/client.go +++ b/internal/gitlabnet/authorizedkeys/client.go @@ -1,6 +1,7 @@ package authorizedkeys import ( + "context" "fmt" "net/url" @@ -32,13 +33,13 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{config: config, client: client}, nil } -func (c *Client) GetByKey(key string) (*Response, error) { +func (c *Client) GetByKey(ctx context.Context, key string) (*Response, error) { path, err := pathWithKey(key) if err != nil { return nil, err } - response, err := c.client.Get(path) + response, err := c.client.Get(ctx, path) if err != nil { return nil, err } diff --git a/internal/gitlabnet/authorizedkeys/client_test.go b/internal/gitlabnet/authorizedkeys/client_test.go index c9c76a1..e72840c 100644 --- a/internal/gitlabnet/authorizedkeys/client_test.go +++ b/internal/gitlabnet/authorizedkeys/client_test.go @@ -1,6 +1,7 @@ package authorizedkeys import ( + "context" "encoding/json" "net/http" "testing" @@ -48,7 +49,7 @@ func TestGetByKey(t *testing.T) { client, cleanup := setup(t) defer cleanup() - result, err := client.GetByKey("key") + result, err := client.GetByKey(context.Background(), "key") require.NoError(t, err) require.Equal(t, &Response{Id: 1, Key: "public-key"}, result) } @@ -86,7 +87,7 @@ func TestGetByKeyErrorResponses(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - resp, err := client.GetByKey(tc.key) + resp, err := client.GetByKey(context.Background(), tc.key) require.EqualError(t, err, tc.expectedError) require.Nil(t, resp) diff --git a/internal/gitlabnet/discover/client.go b/internal/gitlabnet/discover/client.go index d1e1906..cc7f516 100644 --- a/internal/gitlabnet/discover/client.go +++ b/internal/gitlabnet/discover/client.go @@ -1,6 +1,7 @@ package discover import ( + "context" "fmt" "net/http" "net/url" @@ -31,7 +32,7 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{config: config, client: client}, nil } -func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) { +func (c *Client) GetByCommandArgs(ctx context.Context, args *commandargs.Shell) (*Response, error) { params := url.Values{} if args.GitlabUsername != "" { params.Add("username", args.GitlabUsername) @@ -43,13 +44,13 @@ func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) { return nil, fmt.Errorf("who='' is invalid") } - return c.getResponse(params) + return c.getResponse(ctx, params) } -func (c *Client) getResponse(params url.Values) (*Response, error) { +func (c *Client) getResponse(ctx context.Context, params url.Values) (*Response, error) { path := "/discover?" + params.Encode() - response, err := c.client.Get(path) + response, err := c.client.Get(ctx, path) if err != nil { return nil, err } diff --git a/internal/gitlabnet/discover/client_test.go b/internal/gitlabnet/discover/client_test.go index 96b3162..cb46dd7 100644 --- a/internal/gitlabnet/discover/client_test.go +++ b/internal/gitlabnet/discover/client_test.go @@ -1,6 +1,7 @@ package discover import ( + "context" "encoding/json" "fmt" "net/http" @@ -62,7 +63,7 @@ func TestGetByKeyId(t *testing.T) { params := url.Values{} params.Add("key_id", "1") - result, err := client.getResponse(params) + result, err := client.getResponse(context.Background(), params) assert.NoError(t, err) assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result) } @@ -73,7 +74,7 @@ func TestGetByUsername(t *testing.T) { params := url.Values{} params.Add("username", "jane-doe") - result, err := client.getResponse(params) + result, err := client.getResponse(context.Background(), params) assert.NoError(t, err) assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result) } @@ -84,7 +85,7 @@ func TestMissingUser(t *testing.T) { params := url.Values{} params.Add("username", "missing") - result, err := client.getResponse(params) + result, err := client.getResponse(context.Background(), params) assert.NoError(t, err) assert.True(t, result.IsAnonymous()) } @@ -119,7 +120,7 @@ func TestErrorResponses(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { params := url.Values{} params.Add("username", tc.fakeUsername) - resp, err := client.getResponse(params) + resp, err := client.getResponse(context.Background(), params) assert.EqualError(t, err, tc.expectedError) assert.Nil(t, resp) diff --git a/internal/gitlabnet/healthcheck/client.go b/internal/gitlabnet/healthcheck/client.go index 09b45af..f148504 100644 --- a/internal/gitlabnet/healthcheck/client.go +++ b/internal/gitlabnet/healthcheck/client.go @@ -1,6 +1,7 @@ package healthcheck import ( + "context" "fmt" "net/http" @@ -34,8 +35,8 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{config: config, client: client}, nil } -func (c *Client) Check() (*Response, error) { - resp, err := c.client.Get(checkPath) +func (c *Client) Check(ctx context.Context) (*Response, error) { + resp, err := c.client.Get(ctx, checkPath) if err != nil { return nil, err } diff --git a/internal/gitlabnet/healthcheck/client_test.go b/internal/gitlabnet/healthcheck/client_test.go index c66ddbd..81ae209 100644 --- a/internal/gitlabnet/healthcheck/client_test.go +++ b/internal/gitlabnet/healthcheck/client_test.go @@ -1,6 +1,7 @@ package healthcheck import ( + "context" "encoding/json" "net/http" "testing" @@ -33,7 +34,7 @@ func TestCheck(t *testing.T) { client, cleanup := setup(t) defer cleanup() - result, err := client.Check() + result, err := client.Check(context.Background()) require.NoError(t, err) require.Equal(t, testResponse, result) } diff --git a/internal/gitlabnet/lfsauthenticate/client.go b/internal/gitlabnet/lfsauthenticate/client.go index fffc225..834cbe1 100644 --- a/internal/gitlabnet/lfsauthenticate/client.go +++ b/internal/gitlabnet/lfsauthenticate/client.go @@ -1,6 +1,7 @@ package lfsauthenticate import ( + "context" "fmt" "net/http" "strings" @@ -40,7 +41,7 @@ func NewClient(config *config.Config, args *commandargs.Shell) (*Client, error) return &Client{config: config, client: client, args: args}, nil } -func (c *Client) Authenticate(operation, repo, userId string) (*Response, error) { +func (c *Client) Authenticate(ctx context.Context, operation, repo, userId string) (*Response, error) { request := &Request{Operation: operation, Repo: repo} if c.args.GitlabKeyId != "" { request.KeyId = c.args.GitlabKeyId @@ -48,7 +49,7 @@ func (c *Client) Authenticate(operation, repo, userId string) (*Response, error) request.UserId = strings.TrimPrefix(userId, "user-") } - response, err := c.client.Post("/lfs_authenticate", request) + response, err := c.client.Post(ctx, "/lfs_authenticate", request) if err != nil { return nil, err } diff --git a/internal/gitlabnet/lfsauthenticate/client_test.go b/internal/gitlabnet/lfsauthenticate/client_test.go index 82e364b..2bd0451 100644 --- a/internal/gitlabnet/lfsauthenticate/client_test.go +++ b/internal/gitlabnet/lfsauthenticate/client_test.go @@ -1,6 +1,7 @@ package lfsauthenticate import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -85,7 +86,7 @@ func TestFailedRequests(t *testing.T) { operation := tc.args.SshArgs[2] - _, err = client.Authenticate(operation, repo, "") + _, err = client.Authenticate(context.Background(), operation, repo, "") require.Error(t, err) require.Equal(t, tc.expectedOutput, err.Error()) @@ -119,7 +120,7 @@ func TestSuccessfulRequests(t *testing.T) { client, err := NewClient(&config.Config{GitlabUrl: url}, args) require.NoError(t, err) - response, err := client.Authenticate(operation, repo, "") + response, err := client.Authenticate(context.Background(), operation, repo, "") require.NoError(t, err) expectedResponse := &Response{ diff --git a/internal/gitlabnet/personalaccesstoken/client.go b/internal/gitlabnet/personalaccesstoken/client.go index 588bead..abbd395 100644 --- a/internal/gitlabnet/personalaccesstoken/client.go +++ b/internal/gitlabnet/personalaccesstoken/client.go @@ -1,6 +1,7 @@ package personalaccesstoken import ( + "context" "errors" "fmt" "net/http" @@ -42,13 +43,13 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{config: config, client: client}, nil } -func (c *Client) GetPersonalAccessToken(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) { - requestBody, err := c.getRequestBody(args, name, scopes, expiresAt) +func (c *Client) GetPersonalAccessToken(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) { + requestBody, err := c.getRequestBody(ctx, args, name, scopes, expiresAt) if err != nil { return nil, err } - response, err := c.client.Post("/personal_access_token", requestBody) + response, err := c.client.Post(ctx, "/personal_access_token", requestBody) if err != nil { return nil, err } @@ -70,7 +71,7 @@ func parse(hr *http.Response) (*Response, error) { return response, nil } -func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) { +func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) { client, err := discover.NewClient(c.config) if err != nil { return nil, err @@ -83,7 +84,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[] return requestBody, nil } - userInfo, err := client.GetByCommandArgs(args) + userInfo, err := client.GetByCommandArgs(ctx, args) if err != nil { return nil, err } diff --git a/internal/gitlabnet/personalaccesstoken/client_test.go b/internal/gitlabnet/personalaccesstoken/client_test.go index de45975..140a7b2 100644 --- a/internal/gitlabnet/personalaccesstoken/client_test.go +++ b/internal/gitlabnet/personalaccesstoken/client_test.go @@ -1,6 +1,7 @@ package personalaccesstoken import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -90,7 +91,7 @@ func TestGetPersonalAccessTokenByKeyId(t *testing.T) { args := &commandargs.Shell{GitlabKeyId: "0"} result, err := client.GetPersonalAccessToken( - args, "newtoken", &[]string{"read_api", "read_repository"}, "", + context.Background(), args, "newtoken", &[]string{"read_api", "read_repository"}, "", ) assert.NoError(t, err) response := &Response{ @@ -109,7 +110,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) { args := &commandargs.Shell{GitlabUsername: "jane-doe"} result, err := client.GetPersonalAccessToken( - args, "newtoken", &[]string{"api"}, "", + context.Background(), args, "newtoken", &[]string{"api"}, "", ) assert.NoError(t, err) response := &Response{true, "YXuxvUgCEmeePY3G1YAa", []string{"api"}, "", ""} @@ -122,7 +123,7 @@ func TestMissingUser(t *testing.T) { args := &commandargs.Shell{GitlabKeyId: "1"} _, err := client.GetPersonalAccessToken( - args, "newtoken", &[]string{"api"}, "", + context.Background(), args, "newtoken", &[]string{"api"}, "", ) assert.Equal(t, "missing user", err.Error()) } @@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { args := &commandargs.Shell{GitlabKeyId: tc.fakeId} resp, err := client.GetPersonalAccessToken( - args, "newtoken", &[]string{"api"}, "", + context.Background(), args, "newtoken", &[]string{"api"}, "", ) assert.EqualError(t, err, tc.expectedError) diff --git a/internal/gitlabnet/twofactorrecover/client.go b/internal/gitlabnet/twofactorrecover/client.go index d22daca..456f892 100644 --- a/internal/gitlabnet/twofactorrecover/client.go +++ b/internal/gitlabnet/twofactorrecover/client.go @@ -1,6 +1,7 @@ package twofactorrecover import ( + "context" "errors" "fmt" "net/http" @@ -37,14 +38,14 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{config: config, client: client}, nil } -func (c *Client) GetRecoveryCodes(args *commandargs.Shell) ([]string, error) { - requestBody, err := c.getRequestBody(args) +func (c *Client) GetRecoveryCodes(ctx context.Context, args *commandargs.Shell) ([]string, error) { + requestBody, err := c.getRequestBody(ctx, args) if err != nil { return nil, err } - response, err := c.client.Post("/two_factor_recovery_codes", requestBody) + response, err := c.client.Post(ctx, "/two_factor_recovery_codes", requestBody) if err != nil { return nil, err } @@ -66,7 +67,7 @@ func parse(hr *http.Response) ([]string, error) { return response.RecoveryCodes, nil } -func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) { +func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell) (*RequestBody, error) { client, err := discover.NewClient(c.config) if err != nil { @@ -77,7 +78,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) { if args.GitlabKeyId != "" { requestBody = &RequestBody{KeyId: args.GitlabKeyId} } else { - userInfo, err := client.GetByCommandArgs(args) + userInfo, err := client.GetByCommandArgs(ctx, args) if err != nil { return nil, err diff --git a/internal/gitlabnet/twofactorrecover/client_test.go b/internal/gitlabnet/twofactorrecover/client_test.go index 372afec..46291aa 100644 --- a/internal/gitlabnet/twofactorrecover/client_test.go +++ b/internal/gitlabnet/twofactorrecover/client_test.go @@ -1,6 +1,7 @@ package twofactorrecover import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -85,7 +86,7 @@ func TestGetRecoveryCodesByKeyId(t *testing.T) { defer cleanup() args := &commandargs.Shell{GitlabKeyId: "0"} - result, err := client.GetRecoveryCodes(args) + result, err := client.GetRecoveryCodes(context.Background(), args) assert.NoError(t, err) assert.Equal(t, []string{"recovery 1", "codes 1"}, result) } @@ -95,7 +96,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) { defer cleanup() args := &commandargs.Shell{GitlabUsername: "jane-doe"} - result, err := client.GetRecoveryCodes(args) + result, err := client.GetRecoveryCodes(context.Background(), args) assert.NoError(t, err) assert.Equal(t, []string{"recovery 2", "codes 2"}, result) } @@ -105,7 +106,7 @@ func TestMissingUser(t *testing.T) { defer cleanup() args := &commandargs.Shell{GitlabKeyId: "1"} - _, err := client.GetRecoveryCodes(args) + _, err := client.GetRecoveryCodes(context.Background(), args) assert.Equal(t, "missing user", err.Error()) } @@ -138,7 +139,7 @@ func TestErrorResponses(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { args := &commandargs.Shell{GitlabKeyId: tc.fakeId} - resp, err := client.GetRecoveryCodes(args) + resp, err := client.GetRecoveryCodes(context.Background(), args) assert.EqualError(t, err, tc.expectedError) assert.Nil(t, resp) |