summaryrefslogtreecommitdiff
path: root/internal/command
diff options
context:
space:
mode:
authorStan Hu <stanhu@gmail.com>2020-09-19 03:34:49 -0700
committerStan Hu <stanhu@gmail.com>2020-09-19 14:00:45 -0700
commit0590d9198f653ff2170e0f26790056bef4f056fe (patch)
treedc0d68866ea16ba4f74d441c3aa2048b12fb9e95 /internal/command
parentf100e7e83943b3bb5db232f5bf79a616fdba88f1 (diff)
downloadgitlab-shell-sh-extract-context-from-env.tar.gz
Make it possible to propagate correlation ID across processessh-extract-context-from-env
Previously, gitlab-shell did not pass a context through the application. Correlation IDs were generated down the call stack, since we don't pass the context around from the start execution. This has several potential downsides: 1. It's easier for programming mistakes to be made in future which lead to multiple correlation IDs being generated for a single request. 2. Correlation IDs cannot be passed in from upstream requests 3. Other advantages of context passing, such as distributed tracing is not possible. This commit changes the behavior: 1. Extract the correlation ID from the environment at the start of the application. 2. If no correlation ID exists, generate a random one. 3. Pass the correlation ID to the GitLabNet API requests. This change also enables other clients of GitLabNet (e.g. Gitaly) to pass along the correlation ID in the internal API requests (https://gitlab.com/gitlab-org/gitaly/-/issues/2725). Fixes https://gitlab.com/gitlab-org/gitlab-shell/-/issues/474
Diffstat (limited to 'internal/command')
-rw-r--r--internal/command/authorizedkeys/authorized_keys.go13
-rw-r--r--internal/command/authorizedkeys/authorized_keys_test.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals.go3
-rw-r--r--internal/command/authorizedprincipals/authorized_principals_test.go3
-rw-r--r--internal/command/command.go28
-rw-r--r--internal/command/command_test.go66
-rw-r--r--internal/command/discover/discover.go9
-rw-r--r--internal/command/discover/discover_test.go5
-rw-r--r--internal/command/healthcheck/healthcheck.go9
-rw-r--r--internal/command/healthcheck/healthcheck_test.go7
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate.go15
-rw-r--r--internal/command/lfsauthenticate/lfsauthenticate_test.go5
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken.go9
-rw-r--r--internal/command/personalaccesstoken/personalaccesstoken_test.go3
-rw-r--r--internal/command/receivepack/gitalycall_test.go3
-rw-r--r--internal/command/receivepack/receivepack.go12
-rw-r--r--internal/command/receivepack/receivepack_test.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier.go5
-rw-r--r--internal/command/shared/accessverifier/accessverifier_test.go5
-rw-r--r--internal/command/shared/customaction/customaction.go13
-rw-r--r--internal/command/shared/customaction/customaction_test.go5
-rw-r--r--internal/command/twofactorrecover/twofactorrecover.go13
-rw-r--r--internal/command/twofactorrecover/twofactorrecover_test.go3
-rw-r--r--internal/command/uploadarchive/gitalycall_test.go3
-rw-r--r--internal/command/uploadarchive/uploadarchive.go10
-rw-r--r--internal/command/uploadarchive/uploadarchive_test.go3
-rw-r--r--internal/command/uploadpack/gitalycall_test.go3
-rw-r--r--internal/command/uploadpack/uploadpack.go12
-rw-r--r--internal/command/uploadpack/uploadpack_test.go3
29 files changed, 199 insertions, 77 deletions
diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go
index 7554761..736aeed 100644
--- a/internal/command/authorizedkeys/authorized_keys.go
+++ b/internal/command/authorizedkeys/authorized_keys.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"fmt"
"strconv"
@@ -17,7 +18,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
// Do and return nothing when the expected and actual user don't match.
// This can happen when the user in sshd_config doesn't match the user
// trying to login. When nothing is printed, the user will be denied access.
@@ -27,15 +28,15 @@ func (c *Command) Execute() error {
return nil
}
- if err := c.printKeyLine(); err != nil {
+ if err := c.printKeyLine(ctx); err != nil {
return err
}
return nil
}
-func (c *Command) printKeyLine() error {
- response, err := c.getAuthorizedKey()
+func (c *Command) printKeyLine(ctx context.Context) error {
+ response, err := c.getAuthorizedKey(ctx)
if err != nil {
fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key))
return nil
@@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error {
return nil
}
-func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) {
+func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) {
client, err := authorizedkeys.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByKey(c.Args.Key)
+ return client.GetByKey(ctx, c.Args.Key)
}
diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go
index e12f4fa..f15c34d 100644
--- a/internal/command/authorizedkeys/authorized_keys_test.go
+++ b/internal/command/authorizedkeys/authorized_keys_test.go
@@ -2,6 +2,7 @@ package authorizedkeys
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -97,7 +98,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go
index ab5f2f8..44f6c47 100644
--- a/internal/command/authorizedprincipals/authorized_principals.go
+++ b/internal/command/authorizedprincipals/authorized_principals.go
@@ -1,6 +1,7 @@
package authorizedprincipals
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,7 +16,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if err := c.printPrincipalLines(); err != nil {
return err
}
diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go
index f11dd0f..ec97b65 100644
--- a/internal/command/authorizedprincipals/authorized_principals_test.go
+++ b/internal/command/authorizedprincipals/authorized_principals_test.go
@@ -2,6 +2,7 @@ package authorizedprincipals
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -54,7 +55,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
diff --git a/internal/command/command.go b/internal/command/command.go
index 283b4a1..7e0617e 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -1,6 +1,8 @@
package command
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -16,10 +18,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/tracing"
)
type Command interface {
- Execute() error
+ Execute(ctx context.Context) error
}
func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
@@ -35,6 +40,27 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re
return nil, disallowedcommand.Error
}
+// ContextWithCorrelationID() will always return a background Context
+// with a correlation ID. It will first attempt to extract the ID from
+// an environment variable. If is not available, a random one will be
+// generated.
+func ContextWithCorrelationID() (context.Context, func()) {
+ ctx, finished := tracing.ExtractFromEnv(context.Background())
+ defer finished()
+
+ correlationID := correlation.ExtractFromContext(ctx)
+ if correlationID == "" {
+ correlationID, err := correlation.RandomID()
+ if err != nil {
+ log.WithError(err).Warn("unable to generate correlation ID")
+ } else {
+ ctx = correlation.ContextWithCorrelation(ctx, correlationID)
+ }
+ }
+
+ return ctx, finished
+}
+
func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
switch e.Name {
case executable.GitlabShell:
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index db55e7d..cc2364f 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -2,6 +2,7 @@ package command
import (
"errors"
+ "os"
"testing"
"github.com/stretchr/testify/require"
@@ -20,6 +21,7 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+ "gitlab.com/gitlab-org/labkit/correlation"
)
var (
@@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) {
})
}
}
+
+func TestContextWithCorrelationID(t *testing.T) {
+ testCases := []struct {
+ name string
+ additionalEnv map[string]string
+ expectedCorrelationID string
+ }{
+ {
+ name: "no CORRELATION_ID in environment",
+ },
+ {
+ name: "CORRELATION_ID in envioonment",
+ additionalEnv: map[string]string{
+ "CORRELATION_ID": "abc123",
+ },
+ expectedCorrelationID: "abc123",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ resetEnvironment := addAdditionalEnv(tc.additionalEnv)
+ defer resetEnvironment()
+
+ ctx, finished := ContextWithCorrelationID()
+ require.NotNil(t, ctx, "ctx is nil")
+ require.NotNil(t, finished, "finished is nil")
+ correlationID := correlation.ExtractFromContext(ctx)
+ require.NotEmpty(t, correlationID)
+
+ if tc.expectedCorrelationID != "" {
+ require.Equal(t, tc.expectedCorrelationID, correlationID)
+ }
+ defer finished()
+ })
+ }
+}
+
+// addAdditionalEnv will configure additional environment values
+// and return a deferrable function to reset the environment to
+// it's original state after the test
+func addAdditionalEnv(envMap map[string]string) func() {
+ prevValues := map[string]string{}
+ unsetValues := []string{}
+ for k, v := range envMap {
+ value, exists := os.LookupEnv(k)
+ if exists {
+ prevValues[k] = value
+ } else {
+ unsetValues = append(unsetValues, k)
+ }
+ os.Setenv(k, v)
+ }
+
+ return func() {
+ for k, v := range prevValues {
+ os.Setenv(k, v)
+ }
+
+ for _, k := range unsetValues {
+ os.Unsetenv(k)
+ }
+
+ }
+}
diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go
index 3aa7456..822be32 100644
--- a/internal/command/discover/discover.go
+++ b/internal/command/discover/discover.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -15,8 +16,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.getUserInfo()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.getUserInfo(ctx)
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
@@ -30,11 +31,11 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) getUserInfo() (*discover.Response, error) {
+func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetByCommandArgs(c.Args)
+ return client.GetByCommandArgs(ctx, c.Args)
}
diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go
index 8edbcb9..5431410 100644
--- a/internal/command/discover/discover_test.go
+++ b/internal/command/discover/discover_test.go
@@ -2,6 +2,7 @@ package discover
import (
"bytes"
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -83,7 +84,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
@@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, tc.expectedError)
diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go
index bbc73dc..b04eb0d 100644
--- a/internal/command/healthcheck/healthcheck.go
+++ b/internal/command/healthcheck/healthcheck.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
@@ -18,8 +19,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
- response, err := c.runCheck()
+func (c *Command) Execute(ctx context.Context) error {
+ response, err := c.runCheck(ctx)
if err != nil {
return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
}
@@ -34,13 +35,13 @@ func (c *Command) Execute() error {
return nil
}
-func (c *Command) runCheck() (*healthcheck.Response, error) {
+func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) {
client, err := healthcheck.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Check()
+ response, err := client.Check(ctx)
if err != nil {
return nil, err
}
diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go
index 7479bcb..d05e563 100644
--- a/internal/command/healthcheck/healthcheck_test.go
+++ b/internal/command/healthcheck/healthcheck_test.go
@@ -2,6 +2,7 @@ package healthcheck
import (
"bytes"
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -53,7 +54,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
@@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err, "Redis available via internal API: FAILED")
require.Equal(t, "Internal API available: OK\n", buffer.String())
}
@@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)")
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go
index 2aaac2a..dab69ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -34,7 +35,7 @@ type Payload struct {
ExpiresIn int `json:"expires_in,omitempty"`
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) < 3 {
return disallowedcommand.Error
@@ -49,12 +50,12 @@ func (c *Command) Execute() error {
return err
}
- accessResponse, err := c.verifyAccess(action, repo)
+ accessResponse, err := c.verifyAccess(ctx, action, repo)
if err != nil {
return err
}
- payload, err := c.authenticate(operation, repo, accessResponse.UserId)
+ payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
return nil
@@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) {
return action, nil
}
-func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(action, repo)
+ return cmd.Verify(ctx, action, repo)
}
-func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) {
+func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) {
client, err := lfsauthenticate.NewClient(c.Config, c.Args)
if err != nil {
return nil, err
}
- response, err := client.Authenticate(operation, repo, userId)
+ response, err := client.Authenticate(ctx, operation, repo, userId)
if err != nil {
return nil, err
}
diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go
index a1c7aec..55998ab 100644
--- a/internal/command/lfsauthenticate/lfsauthenticate_test.go
+++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go
@@ -2,6 +2,7 @@ package lfsauthenticate
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
@@ -146,7 +147,7 @@ func TestLfsAuthenticateRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go
index b283890..6f3d03e 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"errors"
"fmt"
"strconv"
@@ -31,13 +32,13 @@ type tokenArgs struct {
ExpiresDate string // Calculated, a TTL is passed from command-line.
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
err := c.parseTokenArgs()
if err != nil {
return err
}
- response, err := c.getPersonalAccessToken()
+ response, err := c.getPersonalAccessToken(ctx)
if err != nil {
return err
}
@@ -76,11 +77,11 @@ func (c *Command) parseTokenArgs() error {
return nil
}
-func (c *Command) getPersonalAccessToken() (*personalaccesstoken.Response, error) {
+func (c *Command) getPersonalAccessToken(ctx context.Context) (*personalaccesstoken.Response, error) {
client, err := personalaccesstoken.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetPersonalAccessToken(c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
+ return client.GetPersonalAccessToken(ctx, c.Args, c.TokenArgs.Name, &c.TokenArgs.Scopes, c.TokenArgs.ExpiresDate)
}
diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go
index bc748ab..5970142 100644
--- a/internal/command/personalaccesstoken/personalaccesstoken_test.go
+++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go
@@ -2,6 +2,7 @@ package personalaccesstoken
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -170,7 +171,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
if tc.expectedError == "" {
assert.NoError(t, err)
diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go
index 8bee484..2a0c146 100644
--- a/internal/command/receivepack/gitalycall_test.go
+++ b/internal/command/receivepack/gitalycall_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -42,7 +43,7 @@ func TestReceivePack(t *testing.T) {
hook := testhelper.SetupLogger()
- err = cmd.Execute()
+ err = cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "ReceivePack: "+userId+" "+repo, output.String())
diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go
index 7271264..4d5c686 100644
--- a/internal/command/receivepack/receivepack.go
+++ b/internal/command/receivepack/receivepack.go
@@ -1,6 +1,8 @@
package receivepack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: true,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go
index a4632b4..44cb680 100644
--- a/internal/command/receivepack/receivepack_test.go
+++ b/internal/command/receivepack/receivepack_test.go
@@ -2,6 +2,7 @@ package receivepack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -18,7 +19,7 @@ func TestForbiddenAccess(t *testing.T) {
cmd, _, cleanup := setup(t, "disallowed", requests)
defer cleanup()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
@@ -26,7 +27,7 @@ func TestCustomReceivePack(t *testing.T) {
cmd, output, cleanup := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t))
defer cleanup()
- require.NoError(t, cmd.Execute())
+ require.NoError(t, cmd.Execute(context.Background()))
require.Equal(t, "customoutput", output.String())
}
diff --git a/internal/command/shared/accessverifier/accessverifier.go b/internal/command/shared/accessverifier/accessverifier.go
index 5d2d709..9fcdde4 100644
--- a/internal/command/shared/accessverifier/accessverifier.go
+++ b/internal/command/shared/accessverifier/accessverifier.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
@@ -18,13 +19,13 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Verify(action commandargs.CommandType, repo string) (*Response, error) {
+func (c *Command) Verify(ctx context.Context, action commandargs.CommandType, repo string) (*Response, error) {
client, err := accessverifier.NewClient(c.Config)
if err != nil {
return nil, err
}
- response, err := client.Verify(c.Args, action, repo)
+ response, err := client.Verify(ctx, c.Args, action, repo)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/accessverifier/accessverifier_test.go b/internal/command/shared/accessverifier/accessverifier_test.go
index 998e622..8ad87b8 100644
--- a/internal/command/shared/accessverifier/accessverifier_test.go
+++ b/internal/command/shared/accessverifier/accessverifier_test.go
@@ -2,6 +2,7 @@ package accessverifier
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -65,7 +66,7 @@ func TestMissingUser(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "2"}
- _, err := cmd.Verify(action, repo)
+ _, err := cmd.Verify(context.Background(), action, repo)
require.Equal(t, "missing user", err.Error())
}
@@ -75,7 +76,7 @@ func TestConsoleMessages(t *testing.T) {
defer cleanup()
cmd.Args = &commandargs.Shell{GitlabKeyId: "1"}
- cmd.Verify(action, repo)
+ cmd.Verify(context.Background(), action, repo)
require.Equal(t, "remote: \nremote: console\nremote: message\nremote: \n", errBuf.String())
require.Empty(t, outBuf.String())
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go
index 2ba1091..0675d36 100644
--- a/internal/command/shared/customaction/customaction.go
+++ b/internal/command/shared/customaction/customaction.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"errors"
"gitlab.com/gitlab-org/gitlab-shell/client"
@@ -34,7 +35,7 @@ type Command struct {
EOFSent bool
}
-func (c *Command) Execute(response *accessverifier.Response) error {
+func (c *Command) Execute(ctx context.Context, response *accessverifier.Response) error {
data := response.Payload.Data
apiEndpoints := data.ApiEndpoints
@@ -42,10 +43,10 @@ func (c *Command) Execute(response *accessverifier.Response) error {
return errors.New("Custom action error: Empty API endpoints")
}
- return c.processApiEndpoints(response)
+ return c.processApiEndpoints(ctx, response)
}
-func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
+func (c *Command) processApiEndpoints(ctx context.Context, response *accessverifier.Response) error {
client, err := gitlabnet.GetClient(c.Config)
if err != nil {
@@ -64,7 +65,7 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
log.WithFields(fields).Info("Performing custom action")
- response, err := c.performRequest(client, endpoint, request)
+ response, err := c.performRequest(ctx, client, endpoint, request)
if err != nil {
return err
}
@@ -95,8 +96,8 @@ func (c *Command) processApiEndpoints(response *accessverifier.Response) error {
return nil
}
-func (c *Command) performRequest(client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
- response, err := client.DoRequest(http.MethodPost, endpoint, request)
+func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetClient, endpoint string, request *Request) (*Response, error) {
+ response, err := client.DoRequest(ctx, http.MethodPost, endpoint, request)
if err != nil {
return nil, err
}
diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go
index 46c5f32..119da5b 100644
--- a/internal/command/shared/customaction/customaction_test.go
+++ b/internal/command/shared/customaction/customaction_test.go
@@ -2,6 +2,7 @@ package customaction
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -78,7 +79,7 @@ func TestExecuteEOFSent(t *testing.T) {
EOFSent: true,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
@@ -148,7 +149,7 @@ func TestExecuteNoEOFSent(t *testing.T) {
EOFSent: false,
}
- require.NoError(t, cmd.Execute(response))
+ require.NoError(t, cmd.Execute(context.Background(), response))
// expect printing of info message, "custom" string from the first request
// and "output" string from the second request
diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go
index 2f13cc5..f0a9e7b 100644
--- a/internal/command/twofactorrecover/twofactorrecover.go
+++ b/internal/command/twofactorrecover/twofactorrecover.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"fmt"
"strings"
@@ -16,9 +17,9 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
if c.canContinue() {
- c.displayRecoveryCodes()
+ c.displayRecoveryCodes(ctx)
} else {
fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
}
@@ -38,8 +39,8 @@ func (c *Command) canContinue() bool {
return answer == "yes"
}
-func (c *Command) displayRecoveryCodes() {
- codes, err := c.getRecoveryCodes()
+func (c *Command) displayRecoveryCodes(ctx context.Context) {
+ codes, err := c.getRecoveryCodes(ctx)
if err == nil {
messageWithCodes :=
@@ -54,12 +55,12 @@ func (c *Command) displayRecoveryCodes() {
}
}
-func (c *Command) getRecoveryCodes() ([]string, error) {
+func (c *Command) getRecoveryCodes(ctx context.Context) ([]string, error) {
client, err := twofactorrecover.NewClient(c.Config)
if err != nil {
return nil, err
}
- return client.GetRecoveryCodes(c.Args)
+ return client.GetRecoveryCodes(ctx, c.Args)
}
diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go
index d2f931b..ea6abd6 100644
--- a/internal/command/twofactorrecover/twofactorrecover_test.go
+++ b/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -2,6 +2,7 @@ package twofactorrecover
import (
"bytes"
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -127,7 +128,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, output.String())
diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go
index eaeb2b7..f74093a 100644
--- a/internal/command/uploadarchive/gitalycall_test.go
+++ b/internal/command/uploadarchive/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/sirupsen/logrus"
@@ -38,7 +39,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadArchive: "+repo, output.String())
diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go
index 9d4fbe0..178b42b 100644
--- a/internal/command/uploadarchive/uploadarchive.go
+++ b/internal/command/uploadarchive/uploadarchive.go
@@ -1,6 +1,8 @@
package uploadarchive
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -14,14 +16,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -29,8 +31,8 @@ func (c *Command) Execute() error {
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go
index 7b03009..5426569 100644
--- a/internal/command/uploadarchive/uploadarchive_test.go
+++ b/internal/command/uploadarchive/uploadarchive_test.go
@@ -2,6 +2,7 @@ package uploadarchive
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go
index d6762a2..22189b8 100644
--- a/internal/command/uploadpack/gitalycall_test.go
+++ b/internal/command/uploadpack/gitalycall_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -37,7 +38,7 @@ func TestUploadPack(t *testing.T) {
hook := testhelper.SetupLogger()
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "UploadPack: "+repo, output.String())
diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go
index 56814d7..fca3823 100644
--- a/internal/command/uploadpack/uploadpack.go
+++ b/internal/command/uploadpack/uploadpack.go
@@ -1,6 +1,8 @@
package uploadpack
import (
+ "context"
+
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/shared/accessverifier"
@@ -15,14 +17,14 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
-func (c *Command) Execute() error {
+func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
}
repo := args[1]
- response, err := c.verifyAccess(repo)
+ response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
}
@@ -33,14 +35,14 @@ func (c *Command) Execute() error {
ReadWriter: c.ReadWriter,
EOFSent: false,
}
- return customAction.Execute(response)
+ return customAction.Execute(ctx, response)
}
return c.performGitalyCall(response)
}
-func (c *Command) verifyAccess(repo string) (*accessverifier.Response, error) {
+func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
- return cmd.Verify(c.Args.CommandType, repo)
+ return cmd.Verify(ctx, c.Args.CommandType, repo)
}
diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go
index 7ea8e5d..20edb57 100644
--- a/internal/command/uploadpack/uploadpack_test.go
+++ b/internal/command/uploadpack/uploadpack_test.go
@@ -2,6 +2,7 @@ package uploadpack
import (
"bytes"
+ "context"
"testing"
"github.com/stretchr/testify/require"
@@ -26,6 +27,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
- err := cmd.Execute()
+ err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}