diff options
Diffstat (limited to 'client')
-rw-r--r-- | client/client_test.go | 21 | ||||
-rw-r--r-- | client/gitlabnet.go | 28 | ||||
-rw-r--r-- | client/httpclient_test.go | 7 | ||||
-rw-r--r-- | client/httpsclient_test.go | 5 |
4 files changed, 29 insertions, 32 deletions
diff --git a/client/client_test.go b/client/client_test.go index e92093a..e0650b2 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -78,7 +79,7 @@ func TestClients(t *testing.T) { func testSuccessfulGet(t *testing.T, client *GitlabNetClient) { t.Run("Successful get", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Get("/hello") + response, err := client.Get(context.Background(), "/hello") require.NoError(t, err) require.NotNil(t, response) @@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) { hook := testhelper.SetupLogger() data := map[string]string{"key": "value"} - response, err := client.Post("/post_endpoint", data) + response, err := client.Post(context.Background(), "/post_endpoint", data) require.NoError(t, err) require.NotNil(t, response) @@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) { func testMissing(t *testing.T, client *GitlabNetClient) { t.Run("Missing error for GET", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Get("/missing") + response, err := client.Get(context.Background(), "/missing") assert.EqualError(t, err, "Internal API error (404)") assert.Nil(t, response) @@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) { t.Run("Missing error for POST", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Post("/missing", map[string]string{}) + response, err := client.Post(context.Background(), "/missing", map[string]string{}) assert.EqualError(t, err, "Internal API error (404)") assert.Nil(t, response) @@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) { func testErrorMessage(t *testing.T, client *GitlabNetClient) { t.Run("Error with message for GET", func(t *testing.T) { - response, err := client.Get("/error") + response, err := client.Get(context.Background(), "/error") assert.EqualError(t, err, "Don't do that") assert.Nil(t, response) }) t.Run("Error with message for POST", func(t *testing.T) { - response, err := client.Post("/error", map[string]string{}) + response, err := client.Post(context.Background(), "/error", map[string]string{}) assert.EqualError(t, err, "Don't do that") assert.Nil(t, response) }) @@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) { t.Run("Broken request for GET", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Get("/broken") + response, err := client.Get(context.Background(), "/broken") assert.EqualError(t, err, "Internal API unreachable") assert.Nil(t, response) @@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) { t.Run("Broken request for POST", func(t *testing.T) { hook := testhelper.SetupLogger() - response, err := client.Post("/broken", map[string]string{}) + response, err := client.Post(context.Background(), "/broken", map[string]string{}) assert.EqualError(t, err, "Internal API unreachable") assert.Nil(t, response) @@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) { func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) { t.Run("Authentication headers for GET", func(t *testing.T) { - response, err := client.Get("/auth") + response, err := client.Get(context.Background(), "/auth") require.NoError(t, err) require.NotNil(t, response) @@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) { }) t.Run("Authentication headers for POST", func(t *testing.T) { - response, err := client.Post("/auth", map[string]string{}) + response, err := client.Post(context.Background(), "/auth", map[string]string{}) require.NoError(t, err) require.NotNil(t, response) diff --git a/client/gitlabnet.go b/client/gitlabnet.go index 0657ca0..b908d04 100644 --- a/client/gitlabnet.go +++ b/client/gitlabnet.go @@ -11,8 +11,9 @@ import ( "strings" "time" - log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/correlation" + + log "github.com/sirupsen/logrus" ) const ( @@ -59,7 +60,7 @@ func normalizePath(path string) string { return path } -func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) { +func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) { var jsonReader io.Reader if data != nil { jsonData, err := json.Marshal(data) @@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str jsonReader = bytes.NewReader(jsonData) } - correlationID, err := correlation.RandomID() - ctx := context.Background() - - if err != nil { - log.WithError(err).Warn("unable to generate correlation ID") - } else { - ctx = correlation.ContextWithCorrelation(ctx, correlationID) - } - request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader) if err != nil { return nil, "", err } + correlationID := correlation.ExtractFromContext(ctx) + return request, correlationID, nil } @@ -102,16 +96,16 @@ func parseError(resp *http.Response) error { } -func (c *GitlabNetClient) Get(path string) (*http.Response, error) { - return c.DoRequest(http.MethodGet, normalizePath(path), nil) +func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) { + return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil) } -func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) { - return c.DoRequest(http.MethodPost, normalizePath(path), data) +func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) { + return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data) } -func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) { - request, correlationID, err := newRequest(method, c.httpClient.Host, path, data) +func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) { + request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data) if err != nil { return nil, err } diff --git a/client/httpclient_test.go b/client/httpclient_test.go index fce0cd5..97e1384 100644 --- a/client/httpclient_test.go +++ b/client/httpclient_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/base64" "fmt" "io/ioutil" @@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) { client, cleanup := setup(t, username, password, requests) defer cleanup() - response, err := client.Get("/get_endpoint") + response, err := client.Get(context.Background(), "/get_endpoint") require.NoError(t, err) testBasicAuthHeaders(t, response) - response, err = client.Post("/post_endpoint", nil) + response, err = client.Post(context.Background(), "/post_endpoint", nil) require.NoError(t, err) testBasicAuthHeaders(t, response) } @@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) { client, cleanup := setup(t, "", "", requests) defer cleanup() - _, err := client.Get("/empty_basic_auth") + _, err := client.Get(context.Background(), "/empty_basic_auth") require.NoError(t, err) } diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go index 1c7435f..0cf77e3 100644 --- a/client/httpsclient_test.go +++ b/client/httpsclient_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "io/ioutil" "net/http" @@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) { client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned) defer cleanup() - response, err := client.Get("/hello") + response, err := client.Get(context.Background(), "/hello") require.NoError(t, err) require.NotNil(t, response) @@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) { client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false) defer cleanup() - _, err := client.Get("/hello") + _, err := client.Get(context.Background(), "/hello") require.Error(t, err) assert.Equal(t, err.Error(), "Internal API unreachable") |