diff options
Diffstat (limited to 'internal/command/command_test.go')
-rw-r--r-- | internal/command/command_test.go | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/internal/command/command_test.go b/internal/command/command_test.go index db55e7d..9160abf 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -2,6 +2,7 @@ package command import ( "errors" + "os" "testing" "github.com/stretchr/testify/require" @@ -20,6 +21,7 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/internal/config" "gitlab.com/gitlab-org/gitlab-shell/internal/executable" "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" + "gitlab.com/gitlab-org/labkit/correlation" ) var ( @@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) { }) } } + +func TestContextWithCorrelationID(t *testing.T) { + testCases := []struct { + name string + additionalEnv map[string]string + expectedCorrelationID string + }{ + { + name: "no CORRELATION_ID in environment", + }, + { + name: "CORRELATION_ID in environment", + additionalEnv: map[string]string{ + "CORRELATION_ID": "abc123", + }, + expectedCorrelationID: "abc123", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resetEnvironment := addAdditionalEnv(tc.additionalEnv) + defer resetEnvironment() + + ctx, finished := ContextWithCorrelationID() + require.NotNil(t, ctx, "ctx is nil") + require.NotNil(t, finished, "finished is nil") + correlationID := correlation.ExtractFromContext(ctx) + require.NotEmpty(t, correlationID) + + if tc.expectedCorrelationID != "" { + require.Equal(t, tc.expectedCorrelationID, correlationID) + } + defer finished() + }) + } +} + +// addAdditionalEnv will configure additional environment values +// and return a deferrable function to reset the environment to +// it's original state after the test +func addAdditionalEnv(envMap map[string]string) func() { + prevValues := map[string]string{} + unsetValues := []string{} + for k, v := range envMap { + value, exists := os.LookupEnv(k) + if exists { + prevValues[k] = value + } else { + unsetValues = append(unsetValues, k) + } + os.Setenv(k, v) + } + + return func() { + for k, v := range prevValues { + os.Setenv(k, v) + } + + for _, k := range unsetValues { + os.Unsetenv(k) + } + + } +} |