diff options
Diffstat (limited to 'go')
-rw-r--r-- | go/internal/gitlabnet/client.go | 12 | ||||
-rw-r--r-- | go/internal/gitlabnet/client_test.go | 30 |
2 files changed, 42 insertions, 0 deletions
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go index dacb1d6..26c24d4 100644 --- a/go/internal/gitlabnet/client.go +++ b/go/internal/gitlabnet/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "os" "strings" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" @@ -109,6 +110,9 @@ func (c *GitlabClient) DoRequest(method, path string, data interface{}) (*http.R request.Header.Set(secretHeaderName, encodedSecret) request.Header.Add("Content-Type", "application/json") + ip := ipAddr() + request.Header.Add("X_FORWARDED_FOR", ip) + request.Close = true response, err := c.httpClient.Do(request) @@ -123,6 +127,14 @@ func (c *GitlabClient) DoRequest(method, path string, data interface{}) (*http.R return response, nil } +func ipAddr() string { + address := os.Getenv("SSH_CONNECTION") + if address != "" { + return strings.Fields(address)[0] + } + return address +} + func ParseJSON(hr *http.Response, response interface{}) error { if err := json.NewDecoder(hr.Body).Decode(response); err != nil { return ParsingError diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go index e8499dc..debe618 100644 --- a/go/internal/gitlabnet/client_test.go +++ b/go/internal/gitlabnet/client_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "net/http" + "os" "path" "testing" @@ -51,6 +52,13 @@ func TestClients(t *testing.T) { }, }, { + Path: "/api/v4/internal/with_ip", + Handler: func(w http.ResponseWriter, r *http.Request) { + header := r.Header.Get("X_FORWARDED_FOR") + require.Equal(t, header, "127.0.0.1") + }, + }, + { Path: "/api/v4/internal/error", Handler: func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -110,6 +118,7 @@ func TestClients(t *testing.T) { testMissing(t, client) testErrorMessage(t, client) testAuthenticationHeader(t, client) + testXForwardedForHeader(t, client) }) } } @@ -217,3 +226,24 @@ func testAuthenticationHeader(t *testing.T, client *GitlabClient) { assert.Equal(t, "sssh, it's a secret", string(header)) }) } + +func testXForwardedForHeader(t *testing.T, client *GitlabClient) { + t.Run("X-Forwarded-For for GET", func(t *testing.T) { + os.Setenv("SSH_CONNECTION", "127.0.0.1 0") + response, err := client.Get("/with_ip") + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + }) + + t.Run("X-Forwarded-For for POST", func(t *testing.T) { + data := map[string]string{"key": "value"} + os.Setenv("SSH_CONNECTION", "127.0.0.1 0") + response, err := client.Post("/with_ip", data) + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + }) +} |