diff options
Diffstat (limited to 'internal/gitlabnet')
-rw-r--r-- | internal/gitlabnet/accessverifier/client.go | 18 | ||||
-rw-r--r-- | internal/gitlabnet/accessverifier/client_test.go | 72 |
2 files changed, 89 insertions, 1 deletions
diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go index bce32cf..c46a16f 100644 --- a/internal/gitlabnet/accessverifier/client.go +++ b/internal/gitlabnet/accessverifier/client.go @@ -3,6 +3,7 @@ package accessverifier import ( "context" "fmt" + "net" "net/http" pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" @@ -85,7 +86,7 @@ func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action com request.KeyId = args.GitlabKeyId } - request.CheckIp = args.Env.RemoteAddr + request.CheckIp = parseIP(args.Env.RemoteAddr) response, err := c.client.Post(ctx, "/allowed", request) if err != nil { @@ -116,3 +117,18 @@ func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) { func (r *Response) IsCustomAction() bool { return r.StatusCode == http.StatusMultipleChoices } + +func parseIP(remoteAddr string) string { + // The remoteAddr field can be filled by: + // 1. An IP address via the SSH_CONNECTION environment variable + // 2. A host:port combination via the PROXY protocol + ip, _, err := net.SplitHostPort(remoteAddr) + + // If we don't have a port or can't parse this address for some reason, + // just return the original string. + if err != nil { + return remoteAddr + } + + return ip +} diff --git a/internal/gitlabnet/accessverifier/client_test.go b/internal/gitlabnet/accessverifier/client_test.go index 13e2d2c..6e426c9 100644 --- a/internal/gitlabnet/accessverifier/client_test.go +++ b/internal/gitlabnet/accessverifier/client_test.go @@ -14,6 +14,7 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/client/testserver" "gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" ) @@ -180,6 +181,51 @@ func TestErrorResponses(t *testing.T) { } } +func TestCheckIP(t *testing.T) { + testCases := []struct { + desc string + remoteAddr string + expectedCheckIp string + }{ + { + desc: "IPv4 address", + remoteAddr: "18.245.0.42", + expectedCheckIp: "18.245.0.42", + }, + { + desc: "IPv6 address", + remoteAddr: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + expectedCheckIp: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + }, + { + desc: "Host and port", + remoteAddr: "18.245.0.42:6345", + expectedCheckIp: "18.245.0.42", + }, + { + desc: "IPv6 host and port", + remoteAddr: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:80", + expectedCheckIp: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + }, + { + desc: "Bad remote addr", + remoteAddr: "[127.0", + expectedCheckIp: "[127.0", + }, + } + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client := setupWithApiInspector(t, + func(r *Request) { + require.Equal(t, tc.expectedCheckIp, r.CheckIp) + }) + + sshEnv := sshenv.Env{RemoteAddr: tc.remoteAddr} + client.Verify(context.Background(), &commandargs.Shell{Env: sshEnv}, uploadPackAction, repo) + }) + } +} + type testResponse struct { body []byte status int @@ -225,3 +271,29 @@ func setup(t *testing.T, userResponses, keyResponses map[string]testResponse) *C return client } + +func setupWithApiInspector(t *testing.T, inspector func(*Request)) *Client { + t.Helper() + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/allowed", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + require.NoError(t, err) + + var requestBody *Request + err = json.Unmarshal(b, &requestBody) + require.NoError(t, err) + + inspector(requestBody) + }, + }, + } + + url := testserver.StartSocketHttpServer(t, requests) + + client, err := NewClient(&config.Config{GitlabUrl: url}) + require.NoError(t, err) + + return client +} |