summaryrefslogtreecommitdiff
path: root/internal/handler
diff options
context:
space:
mode:
Diffstat (limited to 'internal/handler')
-rw-r--r--internal/handler/exec.go89
-rw-r--r--internal/handler/exec_test.go26
2 files changed, 52 insertions, 63 deletions
diff --git a/internal/handler/exec.go b/internal/handler/exec.go
index 723d655..1a7716e 100644
--- a/internal/handler/exec.go
+++ b/internal/handler/exec.go
@@ -6,20 +6,21 @@ import (
"strconv"
"strings"
+ grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
+ grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
log "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
gitalyauth "gitlab.com/gitlab-org/gitaly/auth"
"gitlab.com/gitlab-org/gitaly/client"
pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/accessverifier"
"gitlab.com/gitlab-org/gitlab-shell/internal/sshenv"
"gitlab.com/gitlab-org/labkit/correlation"
grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
- "gitlab.com/gitlab-org/labkit/tracing"
- "google.golang.org/grpc"
- "google.golang.org/grpc/metadata"
+ grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
)
// GitalyHandlerFunc implementations are responsible for making
@@ -27,12 +28,6 @@ import (
// and returning an error from the Gitaly call.
type GitalyHandlerFunc func(ctx context.Context, client *grpc.ClientConn) (int32, error)
-type GitalyConn struct {
- ctx context.Context
- conn *grpc.ClientConn
- close func()
-}
-
type GitalyCommand struct {
Config *config.Config
ServiceName string
@@ -45,14 +40,14 @@ type GitalyCommand struct {
// through GitLab-Shell. It ensures that logging, tracing and other
// common concerns are configured before executing the `handler`.
func (gc *GitalyCommand) RunGitalyCommand(ctx context.Context, handler GitalyHandlerFunc) error {
- gitalyConn, err := getConn(ctx, gc)
+ conn, err := getConn(ctx, gc)
if err != nil {
return err
}
+ defer conn.Close()
- defer gitalyConn.close()
-
- _, err = handler(gitalyConn.ctx, gitalyConn.conn)
+ childCtx := withOutgoingMetadata(ctx, gc.Features)
+ _, err = handler(childCtx, conn)
return err
}
@@ -106,23 +101,43 @@ func withOutgoingMetadata(ctx context.Context, features map[string]string) conte
return metadata.NewOutgoingContext(ctx, md)
}
-func getConn(ctx context.Context, gc *GitalyCommand) (*GitalyConn, error) {
+func getConn(ctx context.Context, gc *GitalyCommand) (*grpc.ClientConn, error) {
if gc.Address == "" {
return nil, fmt.Errorf("no gitaly_address given")
}
+ serviceName := correlation.ExtractClientNameFromContext(ctx)
+ if serviceName == "" {
+ log.Warn("No gRPC service name specified, defaulting to gitlab-shell-unknown")
+
+ serviceName = "gitlab-shell-unknown"
+ }
+
+ serviceName = fmt.Sprintf("%s-%s", serviceName, gc.ServiceName)
+
connOpts := client.DefaultDialOpts
- connOpts = append(connOpts,
+ connOpts = append(
+ connOpts,
grpc.WithStreamInterceptor(
- grpccorrelation.StreamClientCorrelationInterceptor(
- grpccorrelation.WithClientName(executable.GitlabShell),
+ grpc_middleware.ChainStreamClient(
+ grpctracing.StreamClientTracingInterceptor(),
+ grpc_prometheus.StreamClientInterceptor,
+ grpccorrelation.StreamClientCorrelationInterceptor(
+ grpccorrelation.WithClientName(serviceName),
+ ),
),
),
+
grpc.WithUnaryInterceptor(
- grpccorrelation.UnaryClientCorrelationInterceptor(
- grpccorrelation.WithClientName(executable.GitlabShell),
+ grpc_middleware.ChainUnaryClient(
+ grpctracing.UnaryClientTracingInterceptor(),
+ grpc_prometheus.UnaryClientInterceptor,
+ grpccorrelation.UnaryClientCorrelationInterceptor(
+ grpccorrelation.WithClientName(serviceName),
+ ),
),
- ))
+ ),
+ )
if gc.Token != "" {
connOpts = append(connOpts,
@@ -130,35 +145,5 @@ func getConn(ctx context.Context, gc *GitalyCommand) (*GitalyConn, error) {
)
}
- // Configure distributed tracing
- serviceName := fmt.Sprintf("gitlab-shell-%v", gc.ServiceName)
- closer := tracing.Initialize(
- tracing.WithServiceName(serviceName),
-
- // For GitLab-Shell, we explicitly initialize tracing from a config file
- // instead of the default environment variable (using GITLAB_TRACING)
- // This decision was made owing to the difficulty in passing environment
- // variables into GitLab-Shell processes.
- //
- // Processes are spawned as children of the SSH daemon, which tightly
- // controls environment variables; doing this means we don't have to
- // enable PermitUserEnvironment
- tracing.WithConnectionString(gc.Config.GitlabTracing),
- )
-
- childCtx, finished := tracing.ExtractFromEnv(ctx)
- childCtx = withOutgoingMetadata(childCtx, gc.Features)
-
- conn, err := client.DialContext(childCtx, gc.Address, connOpts)
- if err != nil {
- return nil, err
- }
-
- finish := func() {
- finished()
- closer.Close()
- conn.Close()
- }
-
- return &GitalyConn{ctx: childCtx, conn: conn, close: finish}, nil
+ return client.DialContext(ctx, gc.Address, connOpts)
}
diff --git a/internal/handler/exec_test.go b/internal/handler/exec_test.go
index 6f84709..9b8fee8 100644
--- a/internal/handler/exec_test.go
+++ b/internal/handler/exec_test.go
@@ -45,7 +45,7 @@ func TestMissingGitalyAddress(t *testing.T) {
require.EqualError(t, err, "no gitaly_address given")
}
-func TestGetConnMetadata(t *testing.T) {
+func TestRunGitalyCommandMetadata(t *testing.T) {
tests := []struct {
name string
gc *GitalyCommand
@@ -70,19 +70,23 @@ func TestGetConnMetadata(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- conn, err := getConn(context.Background(), tt.gc)
- require.NoError(t, err)
+ cmd := tt.gc
- md, exists := metadata.FromOutgoingContext(conn.ctx)
- require.True(t, exists)
- require.Equal(t, len(tt.want), md.Len())
+ err := cmd.RunGitalyCommand(context.Background(), func(ctx context.Context, _ *grpc.ClientConn) (int32, error) {
+ md, exists := metadata.FromOutgoingContext(ctx)
+ require.True(t, exists)
+ require.Equal(t, len(tt.want), md.Len())
- for k, v := range tt.want {
- values := md.Get(k)
- require.Equal(t, 1, len(values))
- require.Equal(t, v, values[0])
- }
+ for k, v := range tt.want {
+ values := md.Get(k)
+ require.Equal(t, 1, len(values))
+ require.Equal(t, v, values[0])
+ }
+ return 0, nil
+ })
+
+ require.NoError(t, err)
})
}
}