summaryrefslogtreecommitdiff
path: root/internal/handler
diff options
context:
space:
mode:
Diffstat (limited to 'internal/handler')
-rw-r--r--internal/handler/exec.go12
-rw-r--r--internal/handler/exec_test.go26
2 files changed, 19 insertions, 19 deletions
diff --git a/internal/handler/exec.go b/internal/handler/exec.go
index 5ead63e..ac59dab 100644
--- a/internal/handler/exec.go
+++ b/internal/handler/exec.go
@@ -61,10 +61,10 @@ func (gc *GitalyCommand) RunGitalyCommand(handler GitalyHandlerFunc) error {
// PrepareContext wraps a given context with a correlation ID and logs the command to
// be run.
-func (gc *GitalyCommand) PrepareContext(ctx context.Context, repository *pb.Repository, response *accessverifier.Response, protocol string) (context.Context, context.CancelFunc) {
+func (gc *GitalyCommand) PrepareContext(ctx context.Context, repository *pb.Repository, response *accessverifier.Response, env sshenv.Env) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx)
- gc.LogExecution(repository, response, protocol)
+ gc.LogExecution(repository, response, env)
if response.CorrelationID != "" {
ctx = correlation.ContextWithCorrelation(ctx, response.CorrelationID)
@@ -78,13 +78,13 @@ func (gc *GitalyCommand) PrepareContext(ctx context.Context, repository *pb.Repo
md.Append("key_type", response.KeyType)
md.Append("user_id", response.UserId)
md.Append("username", response.Username)
- md.Append("remote_ip", sshenv.LocalAddr())
+ md.Append("remote_ip", env.RemoteAddr)
ctx = metadata.NewOutgoingContext(ctx, md)
return ctx, cancel
}
-func (gc *GitalyCommand) LogExecution(repository *pb.Repository, response *accessverifier.Response, protocol string) {
+func (gc *GitalyCommand) LogExecution(repository *pb.Repository, response *accessverifier.Response, env sshenv.Env) {
fields := log.Fields{
"command": gc.ServiceName,
"correlation_id": response.CorrelationID,
@@ -92,8 +92,8 @@ func (gc *GitalyCommand) LogExecution(repository *pb.Repository, response *acces
"gl_repository": repository.GlRepository,
"user_id": response.UserId,
"username": response.Username,
- "git_protocol": protocol,
- "remote_ip": sshenv.LocalAddr(),
+ "git_protocol": env.GitProtocolVersion,
+ "remote_ip": env.RemoteAddr,
"gl_key_type": response.KeyType,
"gl_key_id": response.KeyId,
}
diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go
index 0dbd538..915bf5a 100644
--- a/internal/handler/exec_test.go
+++ b/internal/handler/exec_test.go
@@ -12,7 +12,7 @@ import (
pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
- "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
)
func makeHandler(t *testing.T, err error) func(context.Context, *grpc.ClientConn) (int32, error) {
@@ -89,12 +89,12 @@ func TestGetConnMetadata(t *testing.T) {
func TestPrepareContext(t *testing.T) {
tests := []struct {
- name string
- gc *GitalyCommand
- sshConnectionEnv string
- repo *pb.Repository
- response *accessverifier.Response
- want map[string]string
+ name string
+ gc *GitalyCommand
+ env sshenv.Env
+ repo *pb.Repository
+ response *accessverifier.Response
+ want map[string]string
}{
{
name: "client_identity",
@@ -102,7 +102,11 @@ func TestPrepareContext(t *testing.T) {
Config: &config.Config{},
Address: "tcp://localhost:9999",
},
- sshConnectionEnv: "10.0.0.1 1234 127.0.0.1 5678",
+ env: sshenv.Env{
+ GitProtocolVersion: "protocol",
+ IsSSHConnection: true,
+ RemoteAddr: "10.0.0.1",
+ },
repo: &pb.Repository{
StorageName: "default",
RelativePath: "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git",
@@ -128,13 +132,9 @@ func TestPrepareContext(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- cleanup, err := testhelper.Setenv("SSH_CONNECTION", tt.sshConnectionEnv)
- require.NoError(t, err)
- defer cleanup()
-
ctx := context.Background()
- ctx, cancel := tt.gc.PrepareContext(ctx, tt.repo, tt.response, "protocol")
+ ctx, cancel := tt.gc.PrepareContext(ctx, tt.repo, tt.response, tt.env)
defer cancel()
md, exists := metadata.FromOutgoingContext(ctx)