diff options
author | Igor Drozdov <idrozdov@gitlab.com> | 2022-07-01 11:02:59 +0000 |
---|---|---|
committer | Igor Drozdov <idrozdov@gitlab.com> | 2022-07-01 11:02:59 +0000 |
commit | 0d7ef238cb8c05eabaec85e62bec70a40147d1df (patch) | |
tree | 9179705f9e8b6ee309d456323fbaedaa70141c7e /internal/sshd | |
parent | 01f4e022c04b29b896eb383e6e6a33f96a6beeb1 (diff) | |
parent | 9b60ce49460876d0e599f2fec65f02856930dbcd (diff) | |
download | gitlab-shell-0d7ef238cb8c05eabaec85e62bec70a40147d1df.tar.gz |
Merge branch 'sshd-forwarded-for' into 'main'
Pass original IP from PROXY requests to internal API calls
See merge request gitlab-org/gitlab-shell!665
Diffstat (limited to 'internal/sshd')
-rw-r--r-- | internal/sshd/sshd.go | 17 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 15 |
2 files changed, 29 insertions, 3 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 43c4d7b..d275193 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -12,7 +12,9 @@ import ( "github.com/pires/go-proxyproto" "golang.org/x/crypto/ssh" + "gitlab.com/gitlab-org/gitlab-shell/client" "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" "gitlab.com/gitlab-org/gitlab-shell/internal/metrics" "gitlab.com/gitlab-org/labkit/correlation" @@ -145,13 +147,26 @@ func (s *Server) getStatus() status { return s.status } +func contextWithValues(parent context.Context, nconn net.Conn) context.Context { + ctx := correlation.ContextWithCorrelation(parent, correlation.SafeRandomID()) + + // If we're dealing with a PROXY connection, register the original requester's IP + mconn, ok := nconn.(*proxyproto.Conn) + if ok { + ip := gitlabnet.ParseIP(mconn.Raw().RemoteAddr().String()) + ctx = context.WithValue(ctx, client.OriginalRemoteIPContextKey{}, ip) + } + + return ctx +} + func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { defer s.wg.Done() metrics.SshdConnectionsInFlight.Inc() defer metrics.SshdConnectionsInFlight.Dec() - ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) + ctx, cancel := context.WithCancel(contextWithValues(ctx, nconn)) defer cancel() go func() { <-ctx.Done() diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index 8f52125..e3fbeeb 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -27,6 +27,7 @@ const ( var ( correlationId = "" + xForwardedFor = "" ) func TestListenAndServe(t *testing.T) { @@ -63,6 +64,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin }, DestinationAddr: target, } + xForwardedFor = "127.0.0.1" + defer func() { + xForwardedFor = "" // Cleanup for other test cases + }() testCases := []struct { desc string @@ -132,9 +137,9 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin require.NoError(t, err) } - sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t)) + sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t)) if sshConn != nil { - sshConn.Close() + defer sshConn.Close() } if tc.isRejected { @@ -142,6 +147,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin require.Regexp(t, "ssh: handshake failed", err.Error()) } else { require.NoError(t, err) + client := ssh.NewClient(sshConn, sshChans, sshRequs) + defer client.Close() + + holdSession(t, client) } }) } @@ -306,6 +315,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex correlationId = r.Header.Get("X-Request-Id") require.NotEmpty(t, correlationId) + require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For")) fmt.Fprint(w, `{"id": 1000, "key": "key"}`) }, @@ -313,6 +323,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex Path: "/api/v4/internal/discover", Handler: func(w http.ResponseWriter, r *http.Request) { require.Equal(t, correlationId, r.Header.Get("X-Request-Id")) + require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For")) fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`) }, |