summaryrefslogtreecommitdiff
path: root/go
diff options
context:
space:
mode:
Diffstat (limited to 'go')
-rw-r--r--go/internal/gitlabnet/client.go12
-rw-r--r--go/internal/gitlabnet/client_test.go30
2 files changed, 42 insertions, 0 deletions
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go
index dacb1d6..26c24d4 100644
--- a/go/internal/gitlabnet/client.go
+++ b/go/internal/gitlabnet/client.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
+ "os"
"strings"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
@@ -109,6 +110,9 @@ func (c *GitlabClient) DoRequest(method, path string, data interface{}) (*http.R
request.Header.Set(secretHeaderName, encodedSecret)
request.Header.Add("Content-Type", "application/json")
+ ip := ipAddr()
+ request.Header.Add("X_FORWARDED_FOR", ip)
+
request.Close = true
response, err := c.httpClient.Do(request)
@@ -123,6 +127,14 @@ func (c *GitlabClient) DoRequest(method, path string, data interface{}) (*http.R
return response, nil
}
+func ipAddr() string {
+ address := os.Getenv("SSH_CONNECTION")
+ if address != "" {
+ return strings.Fields(address)[0]
+ }
+ return address
+}
+
func ParseJSON(hr *http.Response, response interface{}) error {
if err := json.NewDecoder(hr.Body).Decode(response); err != nil {
return ParsingError
diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go
index e8499dc..debe618 100644
--- a/go/internal/gitlabnet/client_test.go
+++ b/go/internal/gitlabnet/client_test.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io/ioutil"
"net/http"
+ "os"
"path"
"testing"
@@ -51,6 +52,13 @@ func TestClients(t *testing.T) {
},
},
{
+ Path: "/api/v4/internal/with_ip",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ header := r.Header.Get("X_FORWARDED_FOR")
+ require.Equal(t, header, "127.0.0.1")
+ },
+ },
+ {
Path: "/api/v4/internal/error",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@@ -110,6 +118,7 @@ func TestClients(t *testing.T) {
testMissing(t, client)
testErrorMessage(t, client)
testAuthenticationHeader(t, client)
+ testXForwardedForHeader(t, client)
})
}
}
@@ -217,3 +226,24 @@ func testAuthenticationHeader(t *testing.T, client *GitlabClient) {
assert.Equal(t, "sssh, it's a secret", string(header))
})
}
+
+func testXForwardedForHeader(t *testing.T, client *GitlabClient) {
+ t.Run("X-Forwarded-For for GET", func(t *testing.T) {
+ os.Setenv("SSH_CONNECTION", "127.0.0.1 0")
+ response, err := client.Get("/with_ip")
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+ })
+
+ t.Run("X-Forwarded-For for POST", func(t *testing.T) {
+ data := map[string]string{"key": "value"}
+ os.Setenv("SSH_CONNECTION", "127.0.0.1 0")
+ response, err := client.Post("/with_ip", data)
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+ })
+}