diff options
Diffstat (limited to 'client')
-rw-r--r-- | client/client_test.go | 240 | ||||
-rw-r--r-- | client/gitlabnet.go | 140 | ||||
-rw-r--r-- | client/httpclient.go | 113 | ||||
-rw-r--r-- | client/httpclient_test.go | 105 | ||||
-rw-r--r-- | client/httpsclient_test.go | 115 | ||||
-rw-r--r-- | client/testserver/gitalyserver.go | 85 | ||||
-rw-r--r-- | client/testserver/testserver.go | 81 |
7 files changed, 879 insertions, 0 deletions
diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000..dfb1ca3 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,240 @@ +package client + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "path" + "strings" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestClients(t *testing.T) { + testDirCleanup, err := testhelper.PrepareTestRootDir() + require.NoError(t, err) + defer testDirCleanup() + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/hello", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + + fmt.Fprint(w, "Hello") + }, + }, + { + Path: "/api/v4/internal/post_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + fmt.Fprint(w, "Echo: "+string(b)) + }, + }, + { + Path: "/api/v4/internal/auth", + Handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, r.Header.Get(secretHeaderName)) + }, + }, + { + Path: "/api/v4/internal/error", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + body := map[string]string{ + "message": "Don't do that", + } + json.NewEncoder(w).Encode(body) + }, + }, + { + Path: "/api/v4/internal/broken", + Handler: func(w http.ResponseWriter, r *http.Request) { + panic("Broken") + }, + }, + } + + testCases := []struct { + desc string + caFile string + server func(*testing.T, []testserver.TestRequestHandler) (string, func()) + }{ + { + desc: "Socket client", + server: testserver.StartSocketHttpServer, + }, + { + desc: "Http client", + server: testserver.StartHttpServer, + }, + { + desc: "Https client", + caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), + server: testserver.StartHttpsServer, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + url, cleanup := tc.server(t, requests) + defer cleanup() + + secret := "sssh, it's a secret" + + httpClient := NewHTTPClient(url, tc.caFile, "", false, 1) + + client, err := NewGitlabNetClient("", "", secret, httpClient) + require.NoError(t, err) + + testBrokenRequest(t, client) + testSuccessfulGet(t, client) + testSuccessfulPost(t, client) + testMissing(t, client) + testErrorMessage(t, client) + testAuthenticationHeader(t, client) + }) + } +} + +func testSuccessfulGet(t *testing.T, client *GitlabNetClient) { + t.Run("Successful get", func(t *testing.T) { + hook := testhelper.SetupLogger() + response, err := client.Get("/hello") + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, string(responseBody), "Hello") + + assert.Equal(t, 1, len(hook.Entries)) + assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) + assert.True(t, strings.Contains(hook.LastEntry().Message, "method=GET")) + assert.True(t, strings.Contains(hook.LastEntry().Message, "Finished HTTP request")) + }) +} + +func testSuccessfulPost(t *testing.T, client *GitlabNetClient) { + t.Run("Successful Post", func(t *testing.T) { + hook := testhelper.SetupLogger() + data := map[string]string{"key": "value"} + + response, err := client.Post("/post_endpoint", data) + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody)) + + assert.Equal(t, 1, len(hook.Entries)) + assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) + assert.True(t, strings.Contains(hook.LastEntry().Message, "method=POST")) + assert.True(t, strings.Contains(hook.LastEntry().Message, "Finished HTTP request")) + }) +} + +func testMissing(t *testing.T, client *GitlabNetClient) { + t.Run("Missing error for GET", func(t *testing.T) { + hook := testhelper.SetupLogger() + response, err := client.Get("/missing") + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + + assert.Equal(t, 1, len(hook.Entries)) + assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) + assert.True(t, strings.Contains(hook.LastEntry().Message, "method=GET")) + assert.True(t, strings.Contains(hook.LastEntry().Message, "Internal API error")) + }) + + t.Run("Missing error for POST", func(t *testing.T) { + hook := testhelper.SetupLogger() + response, err := client.Post("/missing", map[string]string{}) + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + + assert.Equal(t, 1, len(hook.Entries)) + assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level) + assert.True(t, strings.Contains(hook.LastEntry().Message, "method=POST")) + assert.True(t, strings.Contains(hook.LastEntry().Message, "Internal API error")) + }) +} + +func testErrorMessage(t *testing.T, client *GitlabNetClient) { + t.Run("Error with message for GET", func(t *testing.T) { + response, err := client.Get("/error") + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) + + t.Run("Error with message for POST", func(t *testing.T) { + response, err := client.Post("/error", map[string]string{}) + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) +} + +func testBrokenRequest(t *testing.T, client *GitlabNetClient) { + t.Run("Broken request for GET", func(t *testing.T) { + response, err := client.Get("/broken") + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) + + t.Run("Broken request for POST", func(t *testing.T) { + response, err := client.Post("/broken", map[string]string{}) + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) +} + +func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) { + t.Run("Authentication headers for GET", func(t *testing.T) { + response, err := client.Get("/auth") + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) + + t.Run("Authentication headers for POST", func(t *testing.T) { + response, err := client.Post("/auth", map[string]string{}) + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) +} diff --git a/client/gitlabnet.go b/client/gitlabnet.go new file mode 100644 index 0000000..67c48c7 --- /dev/null +++ b/client/gitlabnet.go @@ -0,0 +1,140 @@ +package client + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + internalApiPath = "/api/v4/internal" + secretHeaderName = "Gitlab-Shared-Secret" +) + +type ErrorResponse struct { + Message string `json:"message"` +} + +type GitlabNetClient struct { + httpClient *HttpClient + user, password, secret string +} + +func NewGitlabNetClient( + user, + password, + secret string, + httpClient *HttpClient, +) (*GitlabNetClient, error) { + + if httpClient == nil { + return nil, fmt.Errorf("Unsupported protocol") + } + + return &GitlabNetClient{ + httpClient: httpClient, + user: user, + password: password, + secret: secret, + }, nil +} + +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + if !strings.HasPrefix(path, internalApiPath) { + path = internalApiPath + path + } + return path +} + +func newRequest(method, host, path string, data interface{}) (*http.Request, error) { + var jsonReader io.Reader + if data != nil { + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + jsonReader = bytes.NewReader(jsonData) + } + + request, err := http.NewRequest(method, host+path, jsonReader) + if err != nil { + return nil, err + } + + return request, nil +} + +func parseError(resp *http.Response) error { + if resp.StatusCode >= 200 && resp.StatusCode <= 399 { + return nil + } + defer resp.Body.Close() + parsedResponse := &ErrorResponse{} + + if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { + return fmt.Errorf("Internal API error (%v)", resp.StatusCode) + } else { + return fmt.Errorf(parsedResponse.Message) + } + +} + +func (c *GitlabNetClient) Get(path string) (*http.Response, error) { + return c.DoRequest(http.MethodGet, normalizePath(path), nil) +} + +func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) { + return c.DoRequest(http.MethodPost, normalizePath(path), data) +} + +func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) { + request, err := newRequest(method, c.httpClient.Host, path, data) + if err != nil { + return nil, err + } + + user, password := c.user, c.password + if user != "" && password != "" { + request.SetBasicAuth(user, password) + } + + encodedSecret := base64.StdEncoding.EncodeToString([]byte(c.secret)) + request.Header.Set(secretHeaderName, encodedSecret) + + request.Header.Add("Content-Type", "application/json") + request.Close = true + + start := time.Now() + response, err := c.httpClient.Do(request) + fields := log.Fields{ + "method": method, + "url": request.URL.String(), + "duration_ms": time.Since(start) / time.Millisecond, + } + + if err != nil { + log.WithError(err).WithFields(fields).Error("Internal API unreachable") + return nil, fmt.Errorf("Internal API unreachable") + } + + if err := parseError(response); err != nil { + log.WithError(err).WithFields(fields).Error("Internal API error") + return nil, err + } + + log.WithFields(fields).Info("Finished HTTP request") + + return response, nil +} diff --git a/client/httpclient.go b/client/httpclient.go new file mode 100644 index 0000000..ff0cc25 --- /dev/null +++ b/client/httpclient.go @@ -0,0 +1,113 @@ +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "net/http" + "path/filepath" + "strings" + "time" +) + +const ( + socketBaseUrl = "http://unix" + unixSocketProtocol = "http+unix://" + httpProtocol = "http://" + httpsProtocol = "https://" + defaultReadTimeoutSeconds = 300 +) + +type HttpClient struct { + *http.Client + Host string +} + +func NewHTTPClient(gitlabURL, caFile, caPath string, selfSignedCert bool, readTimeoutSeconds uint64) *HttpClient { + + var transport *http.Transport + var host string + if strings.HasPrefix(gitlabURL, unixSocketProtocol) { + transport, host = buildSocketTransport(gitlabURL) + } else if strings.HasPrefix(gitlabURL, httpProtocol) { + transport, host = buildHttpTransport(gitlabURL) + } else if strings.HasPrefix(gitlabURL, httpsProtocol) { + transport, host = buildHttpsTransport(caFile, caPath, selfSignedCert, gitlabURL) + } else { + return nil + } + + c := &http.Client{ + Transport: transport, + Timeout: readTimeout(readTimeoutSeconds), + } + + client := &HttpClient{Client: c, Host: host} + + return client +} + +func buildSocketTransport(gitlabURL string) (*http.Transport, string) { + socketPath := strings.TrimPrefix(gitlabURL, unixSocketProtocol) + transport := &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", socketPath) + }, + } + + return transport, socketBaseUrl +} + +func buildHttpsTransport(caFile, caPath string, selfSignedCert bool, gitlabURL string) (*http.Transport, string) { + certPool, err := x509.SystemCertPool() + + if err != nil { + certPool = x509.NewCertPool() + } + + if caFile != "" { + addCertToPool(certPool, caFile) + } + + if caPath != "" { + fis, _ := ioutil.ReadDir(caPath) + for _, fi := range fis { + if fi.IsDir() { + continue + } + + addCertToPool(certPool, filepath.Join(caPath, fi.Name())) + } + } + + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + InsecureSkipVerify: selfSignedCert, + }, + } + + return transport, gitlabURL +} + +func addCertToPool(certPool *x509.CertPool, fileName string) { + cert, err := ioutil.ReadFile(fileName) + if err == nil { + certPool.AppendCertsFromPEM(cert) + } +} + +func buildHttpTransport(gitlabURL string) (*http.Transport, string) { + return &http.Transport{}, gitlabURL +} + +func readTimeout(timeoutSeconds uint64) time.Duration { + if timeoutSeconds == 0 { + timeoutSeconds = defaultReadTimeoutSeconds + } + + return time.Duration(timeoutSeconds) * time.Second +} diff --git a/client/httpclient_test.go b/client/httpclient_test.go new file mode 100644 index 0000000..1f0a4ed --- /dev/null +++ b/client/httpclient_test.go @@ -0,0 +1,105 @@ +package client + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" +) + +func TestReadTimeout(t *testing.T) { + expectedSeconds := uint64(300) + + client := NewHTTPClient("http://localhost:3000", "", "", false, expectedSeconds) + + require.NotNil(t, client) + assert.Equal(t, time.Duration(expectedSeconds)*time.Second, client.Client.Timeout) +} + +const ( + username = "basic_auth_user" + password = "basic_auth_password" +) + +func TestBasicAuthSettings(t *testing.T) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/get_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + + fmt.Fprint(w, r.Header.Get("Authorization")) + }, + }, + { + Path: "/api/v4/internal/post_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + fmt.Fprint(w, r.Header.Get("Authorization")) + }, + }, + } + + client, cleanup := setup(t, username, password, requests) + defer cleanup() + + response, err := client.Get("/get_endpoint") + require.NoError(t, err) + testBasicAuthHeaders(t, response) + + response, err = client.Post("/post_endpoint", nil) + require.NoError(t, err) + testBasicAuthHeaders(t, response) +} + +func testBasicAuthHeaders(t *testing.T, response *http.Response) { + defer response.Body.Close() + + require.NotNil(t, response) + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + + headerParts := strings.Split(string(responseBody), " ") + assert.Equal(t, "Basic", headerParts[0]) + + credentials, err := base64.StdEncoding.DecodeString(headerParts[1]) + require.NoError(t, err) + + assert.Equal(t, username+":"+password, string(credentials)) +} + +func TestEmptyBasicAuthSettings(t *testing.T) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/empty_basic_auth", + Handler: func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "", r.Header.Get("Authorization")) + }, + }, + } + + client, cleanup := setup(t, "", "", requests) + defer cleanup() + + _, err := client.Get("/empty_basic_auth") + require.NoError(t, err) +} + +func setup(t *testing.T, username, password string, requests []testserver.TestRequestHandler) (*GitlabNetClient, func()) { + url, cleanup := testserver.StartHttpServer(t, requests) + + httpClient := NewHTTPClient(url, "", "", false, 1) + + client, err := NewGitlabNetClient(username, password, "", httpClient) + require.NoError(t, err) + + return client, cleanup +} diff --git a/client/httpsclient_test.go b/client/httpsclient_test.go new file mode 100644 index 0000000..6c3ae08 --- /dev/null +++ b/client/httpsclient_test.go @@ -0,0 +1,115 @@ +package client + +import ( + "fmt" + "io/ioutil" + "net/http" + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +func TestSuccessfulRequests(t *testing.T) { + testCases := []struct { + desc string + caFile, caPath string + selfSigned bool + }{ + { + desc: "Valid CaFile", + caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), + }, + { + desc: "Valid CaPath", + caPath: path.Join(testhelper.TestRoot, "certs/valid"), + }, + { + desc: "Self signed cert option enabled", + selfSigned: true, + }, + { + desc: "Invalid cert with self signed cert option enabled", + caFile: path.Join(testhelper.TestRoot, "certs/valid/server.crt"), + selfSigned: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned) + defer cleanup() + + response, err := client.Get("/hello") + require.NoError(t, err) + require.NotNil(t, response) + + defer response.Body.Close() + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, string(responseBody), "Hello") + }) + } +} + +func TestFailedRequests(t *testing.T) { + testCases := []struct { + desc string + caFile string + caPath string + }{ + { + desc: "Invalid CaFile", + caFile: path.Join(testhelper.TestRoot, "certs/invalid/server.crt"), + }, + { + desc: "Invalid CaPath", + caPath: path.Join(testhelper.TestRoot, "certs/invalid"), + }, + { + desc: "Empty config", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false) + defer cleanup() + + _, err := client.Get("/hello") + require.Error(t, err) + + assert.Equal(t, err.Error(), "Internal API unreachable") + }) + } +} + +func setupWithRequests(t *testing.T, caFile, caPath string, selfSigned bool) (*GitlabNetClient, func()) { + testDirCleanup, err := testhelper.PrepareTestRootDir() + require.NoError(t, err) + defer testDirCleanup() + + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/hello", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + + fmt.Fprint(w, "Hello") + }, + }, + } + + url, cleanup := testserver.StartHttpsServer(t, requests) + + httpClient := NewHTTPClient(url, caFile, caPath, selfSigned, 1) + + client, err := NewGitlabNetClient("", "", "", httpClient) + require.NoError(t, err) + + return client, cleanup +} diff --git a/client/testserver/gitalyserver.go b/client/testserver/gitalyserver.go new file mode 100644 index 0000000..4bf14f3 --- /dev/null +++ b/client/testserver/gitalyserver.go @@ -0,0 +1,85 @@ +package testserver + +import ( + "io/ioutil" + "net" + "os" + "path" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type TestGitalyServer struct{ ReceivedMD metadata.MD } + +func (s *TestGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + + s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context()) + + response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository) + stream.Send(&pb.SSHReceivePackResponse{Stdout: response}) + + return nil +} + +func (s *TestGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + + s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context()) + + response := []byte("UploadPack: " + req.Repository.GlRepository) + stream.Send(&pb.SSHUploadPackResponse{Stdout: response}) + + return nil +} + +func (s *TestGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error { + req, err := stream.Recv() + if err != nil { + return err + } + + s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context()) + + response := []byte("UploadArchive: " + req.Repository.GlRepository) + stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response}) + + return nil +} + +func StartGitalyServer(t *testing.T) (string, *TestGitalyServer, func()) { + tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api") + gitalySocketPath := path.Join(tempDir, "gitaly.sock") + + err := os.MkdirAll(filepath.Dir(gitalySocketPath), 0700) + require.NoError(t, err) + + server := grpc.NewServer() + + listener, err := net.Listen("unix", gitalySocketPath) + require.NoError(t, err) + + testServer := TestGitalyServer{} + pb.RegisterSSHServiceServer(server, &testServer) + + go server.Serve(listener) + + gitalySocketUrl := "unix:" + gitalySocketPath + cleanup := func() { + server.Stop() + os.RemoveAll(tempDir) + } + + return gitalySocketUrl, &testServer, cleanup +} diff --git a/client/testserver/testserver.go b/client/testserver/testserver.go new file mode 100644 index 0000000..377e331 --- /dev/null +++ b/client/testserver/testserver.go @@ -0,0 +1,81 @@ +package testserver + +import ( + "crypto/tls" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "os" + "path" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper" +) + +var ( + tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api") + testSocket = path.Join(tempDir, "internal.sock") +) + +type TestRequestHandler struct { + Path string + Handler func(w http.ResponseWriter, r *http.Request) +} + +func StartSocketHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { + err := os.MkdirAll(filepath.Dir(testSocket), 0700) + require.NoError(t, err) + + socketListener, err := net.Listen("unix", testSocket) + require.NoError(t, err) + + server := http.Server{ + Handler: buildHandler(handlers), + // We'll put this server through some nasty stuff we don't want + // in our test output + ErrorLog: log.New(ioutil.Discard, "", 0), + } + go server.Serve(socketListener) + + url := "http+unix://" + testSocket + + return url, cleanupSocket +} + +func StartHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { + server := httptest.NewServer(buildHandler(handlers)) + + return server.URL, server.Close +} + +func StartHttpsServer(t *testing.T, handlers []TestRequestHandler) (string, func()) { + crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt") + key := path.Join(testhelper.TestRoot, "certs/valid/server.key") + + server := httptest.NewUnstartedServer(buildHandler(handlers)) + cer, err := tls.LoadX509KeyPair(crt, key) + require.NoError(t, err) + + server.TLS = &tls.Config{Certificates: []tls.Certificate{cer}} + server.StartTLS() + + return server.URL, server.Close +} + +func cleanupSocket() { + os.RemoveAll(tempDir) +} + +func buildHandler(handlers []TestRequestHandler) http.Handler { + h := http.NewServeMux() + + for _, handler := range handlers { + h.HandleFunc(handler.Path, handler.Handler) + } + + return h +} |