summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/gitlabnet/accessverifier/client.go18
-rw-r--r--internal/gitlabnet/client.go16
-rw-r--r--internal/sshd/sshd.go17
-rw-r--r--internal/sshd/sshd_test.go15
4 files changed, 46 insertions, 20 deletions
diff --git a/internal/gitlabnet/accessverifier/client.go b/internal/gitlabnet/accessverifier/client.go
index c46a16f..adeccd6 100644
--- a/internal/gitlabnet/accessverifier/client.go
+++ b/internal/gitlabnet/accessverifier/client.go
@@ -3,7 +3,6 @@ package accessverifier
import (
"context"
"fmt"
- "net"
"net/http"
pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
@@ -86,7 +85,7 @@ func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action com
request.KeyId = args.GitlabKeyId
}
- request.CheckIp = parseIP(args.Env.RemoteAddr)
+ request.CheckIp = gitlabnet.ParseIP(args.Env.RemoteAddr)
response, err := c.client.Post(ctx, "/allowed", request)
if err != nil {
@@ -117,18 +116,3 @@ func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) {
func (r *Response) IsCustomAction() bool {
return r.StatusCode == http.StatusMultipleChoices
}
-
-func parseIP(remoteAddr string) string {
- // The remoteAddr field can be filled by:
- // 1. An IP address via the SSH_CONNECTION environment variable
- // 2. A host:port combination via the PROXY protocol
- ip, _, err := net.SplitHostPort(remoteAddr)
-
- // If we don't have a port or can't parse this address for some reason,
- // just return the original string.
- if err != nil {
- return remoteAddr
- }
-
- return ip
-}
diff --git a/internal/gitlabnet/client.go b/internal/gitlabnet/client.go
index 39c3320..9bcf6db 100644
--- a/internal/gitlabnet/client.go
+++ b/internal/gitlabnet/client.go
@@ -3,6 +3,7 @@ package gitlabnet
import (
"encoding/json"
"fmt"
+ "net"
"net/http"
"gitlab.com/gitlab-org/gitlab-shell/client"
@@ -34,3 +35,18 @@ func ParseJSON(hr *http.Response, response interface{}) error {
return nil
}
+
+func ParseIP(remoteAddr string) string {
+ // The remoteAddr field can be filled by:
+ // 1. An IP address via the SSH_CONNECTION environment variable
+ // 2. A host:port combination via the PROXY protocol
+ ip, _, err := net.SplitHostPort(remoteAddr)
+
+ // If we don't have a port or can't parse this address for some reason,
+ // just return the original string.
+ if err != nil {
+ return remoteAddr
+ }
+
+ return ip
+}
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 43c4d7b..d275193 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -12,7 +12,9 @@ import (
"github.com/pires/go-proxyproto"
"golang.org/x/crypto/ssh"
+ "gitlab.com/gitlab-org/gitlab-shell/client"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
"gitlab.com/gitlab-org/labkit/correlation"
@@ -145,13 +147,26 @@ func (s *Server) getStatus() status {
return s.status
}
+func contextWithValues(parent context.Context, nconn net.Conn) context.Context {
+ ctx := correlation.ContextWithCorrelation(parent, correlation.SafeRandomID())
+
+ // If we're dealing with a PROXY connection, register the original requester's IP
+ mconn, ok := nconn.(*proxyproto.Conn)
+ if ok {
+ ip := gitlabnet.ParseIP(mconn.Raw().RemoteAddr().String())
+ ctx = context.WithValue(ctx, client.OriginalRemoteIPContextKey{}, ip)
+ }
+
+ return ctx
+}
+
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
defer s.wg.Done()
metrics.SshdConnectionsInFlight.Inc()
defer metrics.SshdConnectionsInFlight.Dec()
- ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
+ ctx, cancel := context.WithCancel(contextWithValues(ctx, nconn))
defer cancel()
go func() {
<-ctx.Done()
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index 8f52125..e3fbeeb 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -27,6 +27,7 @@ const (
var (
correlationId = ""
+ xForwardedFor = ""
)
func TestListenAndServe(t *testing.T) {
@@ -63,6 +64,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
},
DestinationAddr: target,
}
+ xForwardedFor = "127.0.0.1"
+ defer func() {
+ xForwardedFor = "" // Cleanup for other test cases
+ }()
testCases := []struct {
desc string
@@ -132,9 +137,9 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
require.NoError(t, err)
}
- sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
+ sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
if sshConn != nil {
- sshConn.Close()
+ defer sshConn.Close()
}
if tc.isRejected {
@@ -142,6 +147,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
require.Regexp(t, "ssh: handshake failed", err.Error())
} else {
require.NoError(t, err)
+ client := ssh.NewClient(sshConn, sshChans, sshRequs)
+ defer client.Close()
+
+ holdSession(t, client)
}
})
}
@@ -306,6 +315,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex
correlationId = r.Header.Get("X-Request-Id")
require.NotEmpty(t, correlationId)
+ require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
fmt.Fprint(w, `{"id": 1000, "key": "key"}`)
},
@@ -313,6 +323,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, correlationId, r.Header.Get("X-Request-Id"))
+ require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`)
},