diff options
Diffstat (limited to 'internal/handler')
-rw-r--r-- | internal/handler/exec.go | 12 | ||||
-rw-r--r-- | internal/handler/exec_test.go | 26 |
2 files changed, 19 insertions, 19 deletions
diff --git a/internal/handler/exec.go b/internal/handler/exec.go index 5ead63e..ac59dab 100644 --- a/internal/handler/exec.go +++ b/internal/handler/exec.go @@ -61,10 +61,10 @@ func (gc *GitalyCommand) RunGitalyCommand(handler GitalyHandlerFunc) error { // PrepareContext wraps a given context with a correlation ID and logs the command to // be run. -func (gc *GitalyCommand) PrepareContext(ctx context.Context, repository *pb.Repository, response *accessverifier.Response, protocol string) (context.Context, context.CancelFunc) { +func (gc *GitalyCommand) PrepareContext(ctx context.Context, repository *pb.Repository, response *accessverifier.Response, env sshenv.Env) (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(ctx) - gc.LogExecution(repository, response, protocol) + gc.LogExecution(repository, response, env) if response.CorrelationID != "" { ctx = correlation.ContextWithCorrelation(ctx, response.CorrelationID) @@ -78,13 +78,13 @@ func (gc *GitalyCommand) PrepareContext(ctx context.Context, repository *pb.Repo md.Append("key_type", response.KeyType) md.Append("user_id", response.UserId) md.Append("username", response.Username) - md.Append("remote_ip", sshenv.LocalAddr()) + md.Append("remote_ip", env.RemoteAddr) ctx = metadata.NewOutgoingContext(ctx, md) return ctx, cancel } -func (gc *GitalyCommand) LogExecution(repository *pb.Repository, response *accessverifier.Response, protocol string) { +func (gc *GitalyCommand) LogExecution(repository *pb.Repository, response *accessverifier.Response, env sshenv.Env) { fields := log.Fields{ "command": gc.ServiceName, "correlation_id": response.CorrelationID, @@ -92,8 +92,8 @@ func (gc *GitalyCommand) LogExecution(repository *pb.Repository, response *acces "gl_repository": repository.GlRepository, "user_id": response.UserId, "username": response.Username, - "git_protocol": protocol, - "remote_ip": sshenv.LocalAddr(), + "git_protocol": env.GitProtocolVersion, + "remote_ip": env.RemoteAddr, "gl_key_type": response.KeyType, "gl_key_id": response.KeyId, } diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go index 0dbd538..915bf5a 100644 --- a/internal/handler/exec_test.go +++ b/internal/handler/exec_test.go @@ -12,7 +12,7 @@ import ( pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" - "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" + "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" ) func makeHandler(t *testing.T, err error) func(context.Context, *grpc.ClientConn) (int32, error) { @@ -89,12 +89,12 @@ func TestGetConnMetadata(t *testing.T) { func TestPrepareContext(t *testing.T) { tests := []struct { - name string - gc *GitalyCommand - sshConnectionEnv string - repo *pb.Repository - response *accessverifier.Response - want map[string]string + name string + gc *GitalyCommand + env sshenv.Env + repo *pb.Repository + response *accessverifier.Response + want map[string]string }{ { name: "client_identity", @@ -102,7 +102,11 @@ func TestPrepareContext(t *testing.T) { Config: &config.Config{}, Address: "tcp://localhost:9999", }, - sshConnectionEnv: "10.0.0.1 1234 127.0.0.1 5678", + env: sshenv.Env{ + GitProtocolVersion: "protocol", + IsSSHConnection: true, + RemoteAddr: "10.0.0.1", + }, repo: &pb.Repository{ StorageName: "default", RelativePath: "@hashed/5f/9c/5f9c4ab08cac7457e9111a30e4664920607ea2c115a1433d7be98e97e64244ca.git", @@ -128,13 +132,9 @@ func TestPrepareContext(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cleanup, err := testhelper.Setenv("SSH_CONNECTION", tt.sshConnectionEnv) - require.NoError(t, err) - defer cleanup() - ctx := context.Background() - ctx, cancel := tt.gc.PrepareContext(ctx, tt.repo, tt.response, "protocol") + ctx, cancel := tt.gc.PrepareContext(ctx, tt.repo, tt.response, tt.env) defer cancel() md, exists := metadata.FromOutgoingContext(ctx) |