summaryrefslogtreecommitdiff
path: root/internal/sshd/sshd_test.go
diff options
context:
space:
mode:
authorIgor Drozdov <idrozdov@gitlab.com>2022-07-01 11:02:59 +0000
committerIgor Drozdov <idrozdov@gitlab.com>2022-07-01 11:02:59 +0000
commit0d7ef238cb8c05eabaec85e62bec70a40147d1df (patch)
tree9179705f9e8b6ee309d456323fbaedaa70141c7e /internal/sshd/sshd_test.go
parent01f4e022c04b29b896eb383e6e6a33f96a6beeb1 (diff)
parent9b60ce49460876d0e599f2fec65f02856930dbcd (diff)
downloadgitlab-shell-0d7ef238cb8c05eabaec85e62bec70a40147d1df.tar.gz
Merge branch 'sshd-forwarded-for' into 'main'
Pass original IP from PROXY requests to internal API calls See merge request gitlab-org/gitlab-shell!665
Diffstat (limited to 'internal/sshd/sshd_test.go')
-rw-r--r--internal/sshd/sshd_test.go15
1 files changed, 13 insertions, 2 deletions
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index 8f52125..e3fbeeb 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -27,6 +27,7 @@ const (
var (
correlationId = ""
+ xForwardedFor = ""
)
func TestListenAndServe(t *testing.T) {
@@ -63,6 +64,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
},
DestinationAddr: target,
}
+ xForwardedFor = "127.0.0.1"
+ defer func() {
+ xForwardedFor = "" // Cleanup for other test cases
+ }()
testCases := []struct {
desc string
@@ -132,9 +137,9 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
require.NoError(t, err)
}
- sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
+ sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
if sshConn != nil {
- sshConn.Close()
+ defer sshConn.Close()
}
if tc.isRejected {
@@ -142,6 +147,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
require.Regexp(t, "ssh: handshake failed", err.Error())
} else {
require.NoError(t, err)
+ client := ssh.NewClient(sshConn, sshChans, sshRequs)
+ defer client.Close()
+
+ holdSession(t, client)
}
})
}
@@ -306,6 +315,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex
correlationId = r.Header.Get("X-Request-Id")
require.NotEmpty(t, correlationId)
+ require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
fmt.Fprint(w, `{"id": 1000, "key": "key"}`)
},
@@ -313,6 +323,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, correlationId, r.Header.Get("X-Request-Id"))
+ require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`)
},