diff options
Diffstat (limited to 'internal/handler')
-rw-r--r-- | internal/handler/exec.go | 89 | ||||
-rw-r--r-- | internal/handler/exec_test.go | 26 |
2 files changed, 52 insertions, 63 deletions
diff --git a/internal/handler/exec.go b/internal/handler/exec.go index 723d655..1a7716e 100644 --- a/internal/handler/exec.go +++ b/internal/handler/exec.go @@ -6,20 +6,21 @@ import ( "strconv" "strings" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" gitalyauth "gitlab.com/gitlab-org/gitaly/auth" "gitlab.com/gitlab-org/gitaly/client" pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier" "gitlab.com/gitlab-org/gitlab-shell/internal/sshenv" "gitlab.com/gitlab-org/labkit/correlation" grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc" - "gitlab.com/gitlab-org/labkit/tracing" - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" + grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" ) // GitalyHandlerFunc implementations are responsible for making @@ -27,12 +28,6 @@ import ( // and returning an error from the Gitaly call. type GitalyHandlerFunc func(ctx context.Context, client *grpc.ClientConn) (int32, error) -type GitalyConn struct { - ctx context.Context - conn *grpc.ClientConn - close func() -} - type GitalyCommand struct { Config *config.Config ServiceName string @@ -45,14 +40,14 @@ type GitalyCommand struct { // through GitLab-Shell. It ensures that logging, tracing and other // common concerns are configured before executing the `handler`. func (gc *GitalyCommand) RunGitalyCommand(ctx context.Context, handler GitalyHandlerFunc) error { - gitalyConn, err := getConn(ctx, gc) + conn, err := getConn(ctx, gc) if err != nil { return err } + defer conn.Close() - defer gitalyConn.close() - - _, err = handler(gitalyConn.ctx, gitalyConn.conn) + childCtx := withOutgoingMetadata(ctx, gc.Features) + _, err = handler(childCtx, conn) return err } @@ -106,23 +101,43 @@ func withOutgoingMetadata(ctx context.Context, features map[string]string) conte return metadata.NewOutgoingContext(ctx, md) } -func getConn(ctx context.Context, gc *GitalyCommand) (*GitalyConn, error) { +func getConn(ctx context.Context, gc *GitalyCommand) (*grpc.ClientConn, error) { if gc.Address == "" { return nil, fmt.Errorf("no gitaly_address given") } + serviceName := correlation.ExtractClientNameFromContext(ctx) + if serviceName == "" { + log.Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown") + + serviceName = "gitlab-shell-unknown" + } + + serviceName = fmt.Sprintf("%s-%s", serviceName, gc.ServiceName) + connOpts := client.DefaultDialOpts - connOpts = append(connOpts, + connOpts = append( + connOpts, grpc.WithStreamInterceptor( - grpccorrelation.StreamClientCorrelationInterceptor( - grpccorrelation.WithClientName(executable.GitlabShell), + grpc_middleware.ChainStreamClient( + grpctracing.StreamClientTracingInterceptor(), + grpc_prometheus.StreamClientInterceptor, + grpccorrelation.StreamClientCorrelationInterceptor( + grpccorrelation.WithClientName(serviceName), + ), ), ), + grpc.WithUnaryInterceptor( - grpccorrelation.UnaryClientCorrelationInterceptor( - grpccorrelation.WithClientName(executable.GitlabShell), + grpc_middleware.ChainUnaryClient( + grpctracing.UnaryClientTracingInterceptor(), + grpc_prometheus.UnaryClientInterceptor, + grpccorrelation.UnaryClientCorrelationInterceptor( + grpccorrelation.WithClientName(serviceName), + ), ), - )) + ), + ) if gc.Token != "" { connOpts = append(connOpts, @@ -130,35 +145,5 @@ func getConn(ctx context.Context, gc *GitalyCommand) (*GitalyConn, error) { ) } - // Configure distributed tracing - serviceName := fmt.Sprintf("gitlab-shell-%v", gc.ServiceName) - closer := tracing.Initialize( - tracing.WithServiceName(serviceName), - - // For GitLab-Shell, we explicitly initialize tracing from a config file - // instead of the default environment variable (using GITLAB_TRACING) - // This decision was made owing to the difficulty in passing environment - // variables into GitLab-Shell processes. - // - // Processes are spawned as children of the SSH daemon, which tightly - // controls environment variables; doing this means we don't have to - // enable PermitUserEnvironment - tracing.WithConnectionString(gc.Config.GitlabTracing), - ) - - childCtx, finished := tracing.ExtractFromEnv(ctx) - childCtx = withOutgoingMetadata(childCtx, gc.Features) - - conn, err := client.DialContext(childCtx, gc.Address, connOpts) - if err != nil { - return nil, err - } - - finish := func() { - finished() - closer.Close() - conn.Close() - } - - return &GitalyConn{ctx: childCtx, conn: conn, close: finish}, nil + return client.DialContext(ctx, gc.Address, connOpts) } diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go index 6f84709..9b8fee8 100644 --- a/internal/handler/exec_test.go +++ b/internal/handler/exec_test.go @@ -45,7 +45,7 @@ func TestMissingGitalyAddress(t *testing.T) { require.EqualError(t, err, "no gitaly_address given") } -func TestGetConnMetadata(t *testing.T) { +func TestRunGitalyCommandMetadata(t *testing.T) { tests := []struct { name string gc *GitalyCommand @@ -70,19 +70,23 @@ func TestGetConnMetadata(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - conn, err := getConn(context.Background(), tt.gc) - require.NoError(t, err) + cmd := tt.gc - md, exists := metadata.FromOutgoingContext(conn.ctx) - require.True(t, exists) - require.Equal(t, len(tt.want), md.Len()) + err := cmd.RunGitalyCommand(context.Background(), func(ctx context.Context, _ *grpc.ClientConn) (int32, error) { + md, exists := metadata.FromOutgoingContext(ctx) + require.True(t, exists) + require.Equal(t, len(tt.want), md.Len()) - for k, v := range tt.want { - values := md.Get(k) - require.Equal(t, 1, len(values)) - require.Equal(t, v, values[0]) - } + for k, v := range tt.want { + values := md.Get(k) + require.Equal(t, 1, len(values)) + require.Equal(t, v, values[0]) + } + return 0, nil + }) + + require.NoError(t, err) }) } } |