diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/gitlabnet/accessverifier/client.go | 18 | ||||
-rw-r--r-- | internal/gitlabnet/client.go | 16 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 17 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 15 |
4 files changed, 46 insertions, 20 deletions
diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go index c46a16f..adeccd6 100644 --- a/internal/gitlabnet/accessverifier/client.go +++ b/internal/gitlabnet/accessverifier/client.go @@ -3,7 +3,6 @@ package accessverifier import ( "context" "fmt" - "net" "net/http" pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" @@ -86,7 +85,7 @@ func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action com request.KeyId = args.GitlabKeyId } - request.CheckIp = parseIP(args.Env.RemoteAddr) + request.CheckIp = gitlabnet.ParseIP(args.Env.RemoteAddr) response, err := c.client.Post(ctx, "/allowed", request) if err != nil { @@ -117,18 +116,3 @@ func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) { func (r *Response) IsCustomAction() bool { return r.StatusCode == http.StatusMultipleChoices } - -func parseIP(remoteAddr string) string { - // The remoteAddr field can be filled by: - // 1. An IP address via the SSH_CONNECTION environment variable - // 2. A host:port combination via the PROXY protocol - ip, _, err := net.SplitHostPort(remoteAddr) - - // If we don't have a port or can't parse this address for some reason, - // just return the original string. - if err != nil { - return remoteAddr - } - - return ip -} diff --git a/internal/gitlabnet/client.go b/internal/gitlabnet/client.go index 39c3320..9bcf6db 100644 --- a/internal/gitlabnet/client.go +++ b/internal/gitlabnet/client.go @@ -3,6 +3,7 @@ package gitlabnet import ( "encoding/json" "fmt" + "net" "net/http" "gitlab.com/gitlab-org/gitlab-shell/client" @@ -34,3 +35,18 @@ func ParseJSON(hr *http.Response, response interface{}) error { return nil } + +func ParseIP(remoteAddr string) string { + // The remoteAddr field can be filled by: + // 1. An IP address via the SSH_CONNECTION environment variable + // 2. A host:port combination via the PROXY protocol + ip, _, err := net.SplitHostPort(remoteAddr) + + // If we don't have a port or can't parse this address for some reason, + // just return the original string. + if err != nil { + return remoteAddr + } + + return ip +} diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 43c4d7b..d275193 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -12,7 +12,9 @@ import ( "github.com/pires/go-proxyproto" "golang.org/x/crypto/ssh" + "gitlab.com/gitlab-org/gitlab-shell/client" "gitlab.com/gitlab-org/gitlab-shell/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet" "gitlab.com/gitlab-org/gitlab-shell/internal/metrics" "gitlab.com/gitlab-org/labkit/correlation" @@ -145,13 +147,26 @@ func (s *Server) getStatus() status { return s.status } +func contextWithValues(parent context.Context, nconn net.Conn) context.Context { + ctx := correlation.ContextWithCorrelation(parent, correlation.SafeRandomID()) + + // If we're dealing with a PROXY connection, register the original requester's IP + mconn, ok := nconn.(*proxyproto.Conn) + if ok { + ip := gitlabnet.ParseIP(mconn.Raw().RemoteAddr().String()) + ctx = context.WithValue(ctx, client.OriginalRemoteIPContextKey{}, ip) + } + + return ctx +} + func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { defer s.wg.Done() metrics.SshdConnectionsInFlight.Inc() defer metrics.SshdConnectionsInFlight.Dec() - ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) + ctx, cancel := context.WithCancel(contextWithValues(ctx, nconn)) defer cancel() go func() { <-ctx.Done() diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index 8f52125..e3fbeeb 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -27,6 +27,7 @@ const ( var ( correlationId = "" + xForwardedFor = "" ) func TestListenAndServe(t *testing.T) { @@ -63,6 +64,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin }, DestinationAddr: target, } + xForwardedFor = "127.0.0.1" + defer func() { + xForwardedFor = "" // Cleanup for other test cases + }() testCases := []struct { desc string @@ -132,9 +137,9 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin require.NoError(t, err) } - sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t)) + sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t)) if sshConn != nil { - sshConn.Close() + defer sshConn.Close() } if tc.isRejected { @@ -142,6 +147,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin require.Regexp(t, "ssh: handshake failed", err.Error()) } else { require.NoError(t, err) + client := ssh.NewClient(sshConn, sshChans, sshRequs) + defer client.Close() + + holdSession(t, client) } }) } @@ -306,6 +315,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex correlationId = r.Header.Get("X-Request-Id") require.NotEmpty(t, correlationId) + require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For")) fmt.Fprint(w, `{"id": 1000, "key": "key"}`) }, @@ -313,6 +323,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex Path: "/api/v4/internal/discover", Handler: func(w http.ResponseWriter, r *http.Request) { require.Equal(t, correlationId, r.Header.Get("X-Request-Id")) + require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For")) fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`) }, |