summaryrefslogtreecommitdiff
path: root/internal/gitlabnet
diff options
context:
space:
mode:
authorStan Hu <stanhu@gmail.com>2020-09-19 03:34:49 -0700
committerStan Hu <stanhu@gmail.com>2020-09-19 14:00:45 -0700
commit0590d9198f653ff2170e0f26790056bef4f056fe (patch)
treedc0d68866ea16ba4f74d441c3aa2048b12fb9e95 /internal/gitlabnet
parentf100e7e83943b3bb5db232f5bf79a616fdba88f1 (diff)
downloadgitlab-shell-sh-extract-context-from-env.tar.gz
Make it possible to propagate correlation ID across processessh-extract-context-from-env
Previously, gitlab-shell did not pass a context through the application. Correlation IDs were generated down the call stack, since we don't pass the context around from the start execution. This has several potential downsides: 1. It's easier for programming mistakes to be made in future which lead to multiple correlation IDs being generated for a single request. 2. Correlation IDs cannot be passed in from upstream requests 3. Other advantages of context passing, such as distributed tracing is not possible. This commit changes the behavior: 1. Extract the correlation ID from the environment at the start of the application. 2. If no correlation ID exists, generate a random one. 3. Pass the correlation ID to the GitLabNet API requests. This change also enables other clients of GitLabNet (e.g. Gitaly) to pass along the correlation ID in the internal API requests (https://gitlab.com/gitlab-org/gitaly/-/issues/2725). Fixes https://gitlab.com/gitlab-org/gitlab-shell/-/issues/474
Diffstat (limited to 'internal/gitlabnet')
-rw-r--r--internal/gitlabnet/accessverifier/client.go5
-rw-r--r--internal/gitlabnet/accessverifier/client_test.go9
-rw-r--r--internal/gitlabnet/authorizedkeys/client.go5
-rw-r--r--internal/gitlabnet/authorizedkeys/client_test.go5
-rw-r--r--internal/gitlabnet/discover/client.go9
-rw-r--r--internal/gitlabnet/discover/client_test.go9
-rw-r--r--internal/gitlabnet/healthcheck/client.go5
-rw-r--r--internal/gitlabnet/healthcheck/client_test.go3
-rw-r--r--internal/gitlabnet/lfsauthenticate/client.go5
-rw-r--r--internal/gitlabnet/lfsauthenticate/client_test.go5
-rw-r--r--internal/gitlabnet/personalaccesstoken/client.go11
-rw-r--r--internal/gitlabnet/personalaccesstoken/client_test.go9
-rw-r--r--internal/gitlabnet/twofactorrecover/client.go11
-rw-r--r--internal/gitlabnet/twofactorrecover/client_test.go9
14 files changed, 57 insertions, 43 deletions
diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go
index 00b9d76..7e120e0 100644
--- a/internal/gitlabnet/accessverifier/client.go
+++ b/internal/gitlabnet/accessverifier/client.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"fmt"
"net/http"
@@ -77,7 +78,7 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{client: client}, nil
}
-func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) {
+func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action commandargs.CommandType, repo string) (*Response, error) {
request := &Request{Action: action, Repo: repo, Protocol: protocol, Changes: anyChanges}
if args.GitlabUsername != "" {
@@ -88,7 +89,7 @@ func (c *Client) Verify(args *commandargs.Shell, action commandargs.CommandType,
request.CheckIp = sshenv.LocalAddr()
- response, err := c.client.Post("/allowed", request)
+ response, err := c.client.Post(ctx, "/allowed", request)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/accessverifier/client_test.go b/internal/gitlabnet/accessverifier/client_test.go
index 7ddbb5e..3681968 100644
--- a/internal/gitlabnet/accessverifier/client_test.go
+++ b/internal/gitlabnet/accessverifier/client_test.go
@@ -1,6 +1,7 @@
package accessverifier
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -73,7 +74,7 @@ func TestSuccessfulResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- result, err := client.Verify(tc.args, receivePackAction, repo)
+ result, err := client.Verify(context.Background(), tc.args, receivePackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse(tc.who)
@@ -87,7 +88,7 @@ func TestGeoPushGetCustomAction(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "custom"}
- result, err := client.Verify(args, receivePackAction, repo)
+ result, err := client.Verify(context.Background(), args, receivePackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse("user-1")
@@ -110,7 +111,7 @@ func TestGeoPullGetCustomAction(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "custom"}
- result, err := client.Verify(args, uploadPackAction, repo)
+ result, err := client.Verify(context.Background(), args, uploadPackAction, repo)
require.NoError(t, err)
response := buildExpectedResponse("user-1")
@@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
- resp, err := client.Verify(args, receivePackAction, repo)
+ resp, err := client.Verify(context.Background(), args, receivePackAction, repo)
require.EqualError(t, err, tc.expectedError)
require.Nil(t, resp)
diff --git a/internal/gitlabnet/authorizedkeys/client.go b/internal/gitlabnet/authorizedkeys/client.go
index e4fec28..0a00034 100644
--- a/internal/gitlabnet/authorizedkeys/client.go
+++ b/internal/gitlabnet/authorizedkeys/client.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"fmt"
"net/url"
@@ -32,13 +33,13 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetByKey(key string) (*Response, error) {
+func (c *Client) GetByKey(ctx context.Context, key string) (*Response, error) {
path, err := pathWithKey(key)
if err != nil {
return nil, err
}
- response, err := c.client.Get(path)
+ response, err := c.client.Get(ctx, path)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/authorizedkeys/client_test.go b/internal/gitlabnet/authorizedkeys/client_test.go
index c9c76a1..e72840c 100644
--- a/internal/gitlabnet/authorizedkeys/client_test.go
+++ b/internal/gitlabnet/authorizedkeys/client_test.go
@@ -1,6 +1,7 @@
package authorizedkeys
import (
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -48,7 +49,7 @@ func TestGetByKey(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
- result, err := client.GetByKey("key")
+ result, err := client.GetByKey(context.Background(), "key")
require.NoError(t, err)
require.Equal(t, &Response{Id: 1, Key: "public-key"}, result)
}
@@ -86,7 +87,7 @@ func TestGetByKeyErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- resp, err := client.GetByKey(tc.key)
+ resp, err := client.GetByKey(context.Background(), tc.key)
require.EqualError(t, err, tc.expectedError)
require.Nil(t, resp)
diff --git a/internal/gitlabnet/discover/client.go b/internal/gitlabnet/discover/client.go
index d1e1906..cc7f516 100644
--- a/internal/gitlabnet/discover/client.go
+++ b/internal/gitlabnet/discover/client.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"fmt"
"net/http"
"net/url"
@@ -31,7 +32,7 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) {
+func (c *Client) GetByCommandArgs(ctx context.Context, args *commandargs.Shell) (*Response, error) {
params := url.Values{}
if args.GitlabUsername != "" {
params.Add("username", args.GitlabUsername)
@@ -43,13 +44,13 @@ func (c *Client) GetByCommandArgs(args *commandargs.Shell) (*Response, error) {
return nil, fmt.Errorf("who='' is invalid")
}
- return c.getResponse(params)
+ return c.getResponse(ctx, params)
}
-func (c *Client) getResponse(params url.Values) (*Response, error) {
+func (c *Client) getResponse(ctx context.Context, params url.Values) (*Response, error) {
path := "/discover?" + params.Encode()
- response, err := c.client.Get(path)
+ response, err := c.client.Get(ctx, path)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/discover/client_test.go b/internal/gitlabnet/discover/client_test.go
index 96b3162..cb46dd7 100644
--- a/internal/gitlabnet/discover/client_test.go
+++ b/internal/gitlabnet/discover/client_test.go
@@ -1,6 +1,7 @@
package discover
import (
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -62,7 +63,7 @@ func TestGetByKeyId(t *testing.T) {
params := url.Values{}
params.Add("key_id", "1")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
}
@@ -73,7 +74,7 @@ func TestGetByUsername(t *testing.T) {
params := url.Values{}
params.Add("username", "jane-doe")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
}
@@ -84,7 +85,7 @@ func TestMissingUser(t *testing.T) {
params := url.Values{}
params.Add("username", "missing")
- result, err := client.getResponse(params)
+ result, err := client.getResponse(context.Background(), params)
assert.NoError(t, err)
assert.True(t, result.IsAnonymous())
}
@@ -119,7 +120,7 @@ func TestErrorResponses(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
params := url.Values{}
params.Add("username", tc.fakeUsername)
- resp, err := client.getResponse(params)
+ resp, err := client.getResponse(context.Background(), params)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)
diff --git a/internal/gitlabnet/healthcheck/client.go b/internal/gitlabnet/healthcheck/client.go
index 09b45af..f148504 100644
--- a/internal/gitlabnet/healthcheck/client.go
+++ b/internal/gitlabnet/healthcheck/client.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"fmt"
"net/http"
@@ -34,8 +35,8 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) Check() (*Response, error) {
- resp, err := c.client.Get(checkPath)
+func (c *Client) Check(ctx context.Context) (*Response, error) {
+ resp, err := c.client.Get(ctx, checkPath)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/healthcheck/client_test.go b/internal/gitlabnet/healthcheck/client_test.go
index c66ddbd..81ae209 100644
--- a/internal/gitlabnet/healthcheck/client_test.go
+++ b/internal/gitlabnet/healthcheck/client_test.go
@@ -1,6 +1,7 @@
package healthcheck
import (
+ "context"
"encoding/json"
"net/http"
"testing"
@@ -33,7 +34,7 @@ func TestCheck(t *testing.T) {
client, cleanup := setup(t)
defer cleanup()
- result, err := client.Check()
+ result, err := client.Check(context.Background())
require.NoError(t, err)
require.Equal(t, testResponse, result)
}
diff --git a/internal/gitlabnet/lfsauthenticate/client.go b/internal/gitlabnet/lfsauthenticate/client.go
index fffc225..834cbe1 100644
--- a/internal/gitlabnet/lfsauthenticate/client.go
+++ b/internal/gitlabnet/lfsauthenticate/client.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"fmt"
"net/http"
"strings"
@@ -40,7 +41,7 @@ func NewClient(config *config.Config, args *commandargs.Shell) (*Client, error)
return &Client{config: config, client: client, args: args}, nil
}
-func (c *Client) Authenticate(operation, repo, userId string) (*Response, error) {
+func (c *Client) Authenticate(ctx context.Context, operation, repo, userId string) (*Response, error) {
request := &Request{Operation: operation, Repo: repo}
if c.args.GitlabKeyId != "" {
request.KeyId = c.args.GitlabKeyId
@@ -48,7 +49,7 @@ func (c *Client) Authenticate(operation, repo, userId string) (*Response, error)
request.UserId = strings.TrimPrefix(userId, "user-")
}
- response, err := c.client.Post("/lfs_authenticate", request)
+ response, err := c.client.Post(ctx, "/lfs_authenticate", request)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/lfsauthenticate/client_test.go b/internal/gitlabnet/lfsauthenticate/client_test.go
index 82e364b..2bd0451 100644
--- a/internal/gitlabnet/lfsauthenticate/client_test.go
+++ b/internal/gitlabnet/lfsauthenticate/client_test.go
@@ -1,6 +1,7 @@
package lfsauthenticate
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -85,7 +86,7 @@ func TestFailedRequests(t *testing.T) {
operation := tc.args.SshArgs[2]
- _, err = client.Authenticate(operation, repo, "")
+ _, err = client.Authenticate(context.Background(), operation, repo, "")
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
@@ -119,7 +120,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, err := NewClient(&config.Config{GitlabUrl: url}, args)
require.NoError(t, err)
- response, err := client.Authenticate(operation, repo, "")
+ response, err := client.Authenticate(context.Background(), operation, repo, "")
require.NoError(t, err)
expectedResponse := &Response{
diff --git a/internal/gitlabnet/personalaccesstoken/client.go b/internal/gitlabnet/personalaccesstoken/client.go
index 588bead..abbd395 100644
--- a/internal/gitlabnet/personalaccesstoken/client.go
+++ b/internal/gitlabnet/personalaccesstoken/client.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"errors"
"fmt"
"net/http"
@@ -42,13 +43,13 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetPersonalAccessToken(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) {
- requestBody, err := c.getRequestBody(args, name, scopes, expiresAt)
+func (c *Client) GetPersonalAccessToken(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*Response, error) {
+ requestBody, err := c.getRequestBody(ctx, args, name, scopes, expiresAt)
if err != nil {
return nil, err
}
- response, err := c.client.Post("/personal_access_token", requestBody)
+ response, err := c.client.Post(ctx, "/personal_access_token", requestBody)
if err != nil {
return nil, err
}
@@ -70,7 +71,7 @@ func parse(hr *http.Response) (*Response, error) {
return response, nil
}
-func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) {
+func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell, name string, scopes *[]string, expiresAt string) (*RequestBody, error) {
client, err := discover.NewClient(c.config)
if err != nil {
return nil, err
@@ -83,7 +84,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell, name string, scopes *[]
return requestBody, nil
}
- userInfo, err := client.GetByCommandArgs(args)
+ userInfo, err := client.GetByCommandArgs(ctx, args)
if err != nil {
return nil, err
}
diff --git a/internal/gitlabnet/personalaccesstoken/client_test.go b/internal/gitlabnet/personalaccesstoken/client_test.go
index de45975..140a7b2 100644
--- a/internal/gitlabnet/personalaccesstoken/client_test.go
+++ b/internal/gitlabnet/personalaccesstoken/client_test.go
@@ -1,6 +1,7 @@
package personalaccesstoken
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -90,7 +91,7 @@ func TestGetPersonalAccessTokenByKeyId(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: "0"}
result, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"read_api", "read_repository"}, "",
+ context.Background(), args, "newtoken", &[]string{"read_api", "read_repository"}, "",
)
assert.NoError(t, err)
response := &Response{
@@ -109,7 +110,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) {
args := &commandargs.Shell{GitlabUsername: "jane-doe"}
result, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.NoError(t, err)
response := &Response{true, "YXuxvUgCEmeePY3G1YAa", []string{"api"}, "", ""}
@@ -122,7 +123,7 @@ func TestMissingUser(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: "1"}
_, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.Equal(t, "missing user", err.Error())
}
@@ -157,7 +158,7 @@ func TestErrorResponses(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
resp, err := client.GetPersonalAccessToken(
- args, "newtoken", &[]string{"api"}, "",
+ context.Background(), args, "newtoken", &[]string{"api"}, "",
)
assert.EqualError(t, err, tc.expectedError)
diff --git a/internal/gitlabnet/twofactorrecover/client.go b/internal/gitlabnet/twofactorrecover/client.go
index d22daca..456f892 100644
--- a/internal/gitlabnet/twofactorrecover/client.go
+++ b/internal/gitlabnet/twofactorrecover/client.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"errors"
"fmt"
"net/http"
@@ -37,14 +38,14 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
-func (c *Client) GetRecoveryCodes(args *commandargs.Shell) ([]string, error) {
- requestBody, err := c.getRequestBody(args)
+func (c *Client) GetRecoveryCodes(ctx context.Context, args *commandargs.Shell) ([]string, error) {
+ requestBody, err := c.getRequestBody(ctx, args)
if err != nil {
return nil, err
}
- response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
+ response, err := c.client.Post(ctx, "/two_factor_recovery_codes", requestBody)
if err != nil {
return nil, err
}
@@ -66,7 +67,7 @@ func parse(hr *http.Response) ([]string, error) {
return response.RecoveryCodes, nil
}
-func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) {
+func (c *Client) getRequestBody(ctx context.Context, args *commandargs.Shell) (*RequestBody, error) {
client, err := discover.NewClient(c.config)
if err != nil {
@@ -77,7 +78,7 @@ func (c *Client) getRequestBody(args *commandargs.Shell) (*RequestBody, error) {
if args.GitlabKeyId != "" {
requestBody = &RequestBody{KeyId: args.GitlabKeyId}
} else {
- userInfo, err := client.GetByCommandArgs(args)
+ userInfo, err := client.GetByCommandArgs(ctx, args)
if err != nil {
return nil, err
diff --git a/internal/gitlabnet/twofactorrecover/client_test.go b/internal/gitlabnet/twofactorrecover/client_test.go
index 372afec..46291aa 100644
--- a/internal/gitlabnet/twofactorrecover/client_test.go
+++ b/internal/gitlabnet/twofactorrecover/client_test.go
@@ -1,6 +1,7 @@
package twofactorrecover
import (
+ "context"
"encoding/json"
"io/ioutil"
"net/http"
@@ -85,7 +86,7 @@ func TestGetRecoveryCodesByKeyId(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabKeyId: "0"}
- result, err := client.GetRecoveryCodes(args)
+ result, err := client.GetRecoveryCodes(context.Background(), args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 1", "codes 1"}, result)
}
@@ -95,7 +96,7 @@ func TestGetRecoveryCodesByUsername(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabUsername: "jane-doe"}
- result, err := client.GetRecoveryCodes(args)
+ result, err := client.GetRecoveryCodes(context.Background(), args)
assert.NoError(t, err)
assert.Equal(t, []string{"recovery 2", "codes 2"}, result)
}
@@ -105,7 +106,7 @@ func TestMissingUser(t *testing.T) {
defer cleanup()
args := &commandargs.Shell{GitlabKeyId: "1"}
- _, err := client.GetRecoveryCodes(args)
+ _, err := client.GetRecoveryCodes(context.Background(), args)
assert.Equal(t, "missing user", err.Error())
}
@@ -138,7 +139,7 @@ func TestErrorResponses(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
args := &commandargs.Shell{GitlabKeyId: tc.fakeId}
- resp, err := client.GetRecoveryCodes(args)
+ resp, err := client.GetRecoveryCodes(context.Background(), args)
assert.EqualError(t, err, tc.expectedError)
assert.Nil(t, resp)