summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
Diffstat (limited to 'client')
-rw-r--r--client/client_test.go21
-rw-r--r--client/gitlabnet.go28
-rw-r--r--client/httpclient_test.go7
-rw-r--r--client/httpsclient_test.go5
4 files changed, 29 insertions, 32 deletions
diff --git a/client/client_test.go b/client/client_test.go
index e92093a..e0650b2 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -78,7 +79,7 @@ func TestClients(t *testing.T) {
func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
t.Run("Successful get", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
hook := testhelper.SetupLogger()
data := map[string]string{"key": "value"}
- response, err := client.Post("/post_endpoint", data)
+ response, err := client.Post(context.Background(), "/post_endpoint", data)
require.NoError(t, err)
require.NotNil(t, response)
@@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/missing")
+ response, err := client.Get(context.Background(), "/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/missing", map[string]string{})
+ response, err := client.Post(context.Background(), "/missing", map[string]string{})
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
@@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
func testErrorMessage(t *testing.T, client *GitlabNetClient) {
t.Run("Error with message for GET", func(t *testing.T) {
- response, err := client.Get("/error")
+ response, err := client.Get(context.Background(), "/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
t.Run("Error with message for POST", func(t *testing.T) {
- response, err := client.Post("/error", map[string]string{})
+ response, err := client.Post(context.Background(), "/error", map[string]string{})
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
@@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Get("/broken")
+ response, err := client.Get(context.Background(), "/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
- response, err := client.Post("/broken", map[string]string{})
+ response, err := client.Post(context.Background(), "/broken", map[string]string{})
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
@@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
- response, err := client.Get("/auth")
+ response, err := client.Get(context.Background(), "/auth")
require.NoError(t, err)
require.NotNil(t, response)
@@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
})
t.Run("Authentication headers for POST", func(t *testing.T) {
- response, err := client.Post("/auth", map[string]string{})
+ response, err := client.Post(context.Background(), "/auth", map[string]string{})
require.NoError(t, err)
require.NotNil(t, response)
diff --git a/client/gitlabnet.go b/client/gitlabnet.go
index 0657ca0..b908d04 100644
--- a/client/gitlabnet.go
+++ b/client/gitlabnet.go
@@ -11,8 +11,9 @@ import (
"strings"
"time"
- log "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/correlation"
+
+ log "github.com/sirupsen/logrus"
)
const (
@@ -59,7 +60,7 @@ func normalizePath(path string) string {
return path
}
-func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) {
+func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) {
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
@@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str
jsonReader = bytes.NewReader(jsonData)
}
- correlationID, err := correlation.RandomID()
- ctx := context.Background()
-
- if err != nil {
- log.WithError(err).Warn("unable to generate correlation ID")
- } else {
- ctx = correlation.ContextWithCorrelation(ctx, correlationID)
- }
-
request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader)
if err != nil {
return nil, "", err
}
+ correlationID := correlation.ExtractFromContext(ctx)
+
return request, correlationID, nil
}
@@ -102,16 +96,16 @@ func parseError(resp *http.Response) error {
}
-func (c *GitlabNetClient) Get(path string) (*http.Response, error) {
- return c.DoRequest(http.MethodGet, normalizePath(path), nil)
+func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil)
}
-func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) {
- return c.DoRequest(http.MethodPost, normalizePath(path), data)
+func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) {
+ return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data)
}
-func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
- request, correlationID, err := newRequest(method, c.httpClient.Host, path, data)
+func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) {
+ request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data)
if err != nil {
return nil, err
}
diff --git a/client/httpclient_test.go b/client/httpclient_test.go
index fce0cd5..97e1384 100644
--- a/client/httpclient_test.go
+++ b/client/httpclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"encoding/base64"
"fmt"
"io/ioutil"
@@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, username, password, requests)
defer cleanup()
- response, err := client.Get("/get_endpoint")
+ response, err := client.Get(context.Background(), "/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
- response, err = client.Post("/post_endpoint", nil)
+ response, err = client.Post(context.Background(), "/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
@@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, "", "", requests)
defer cleanup()
- _, err := client.Get("/empty_basic_auth")
+ _, err := client.Get(context.Background(), "/empty_basic_auth")
require.NoError(t, err)
}
diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go
index 1c7435f..0cf77e3 100644
--- a/client/httpsclient_test.go
+++ b/client/httpsclient_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"fmt"
"io/ioutil"
"net/http"
@@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned)
defer cleanup()
- response, err := client.Get("/hello")
+ response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
@@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false)
defer cleanup()
- _, err := client.Get("/hello")
+ _, err := client.Get(context.Background(), "/hello")
require.Error(t, err)
assert.Equal(t, err.Error(), "Internal API unreachable")