summaryrefslogtreecommitdiff
path: root/go/internal/gitlabnet
diff options
context:
space:
mode:
Diffstat (limited to 'go/internal/gitlabnet')
-rw-r--r--go/internal/gitlabnet/client.go77
-rw-r--r--go/internal/gitlabnet/client_test.go131
-rw-r--r--go/internal/gitlabnet/discover/client.go76
-rw-r--r--go/internal/gitlabnet/discover/client_test.go131
-rw-r--r--go/internal/gitlabnet/socketclient.go46
-rw-r--r--go/internal/gitlabnet/testserver/testserver.go56
6 files changed, 517 insertions, 0 deletions
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go
new file mode 100644
index 0000000..abc218f
--- /dev/null
+++ b/go/internal/gitlabnet/client.go
@@ -0,0 +1,77 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+)
+
+const (
+ internalApiPath = "/api/v4/internal"
+ secretHeaderName = "Gitlab-Shared-Secret"
+)
+
+type GitlabClient interface {
+ Get(path string) (*http.Response, error)
+ // TODO: implement posts
+ // Post(path string) (http.Response, error)
+}
+
+type ErrorResponse struct {
+ Message string `json:"message"`
+}
+
+func GetClient(config *config.Config) (GitlabClient, error) {
+ url := config.GitlabUrl
+ if strings.HasPrefix(url, UnixSocketProtocol) {
+ return buildSocketClient(config), nil
+ }
+
+ return nil, fmt.Errorf("Unsupported protocol")
+}
+
+func normalizePath(path string) string {
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+
+ if !strings.HasPrefix(path, internalApiPath) {
+ path = internalApiPath + path
+ }
+ return path
+}
+
+func parseError(resp *http.Response) error {
+ if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
+ return nil
+ }
+ defer resp.Body.Close()
+ parsedResponse := &ErrorResponse{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
+ } else {
+ return fmt.Errorf(parsedResponse.Message)
+ }
+
+}
+
+func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) {
+ encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret))
+ request.Header.Set(secretHeaderName, encodedSecret)
+
+ response, err := client.Do(request)
+ if err != nil {
+ return nil, fmt.Errorf("Internal API unreachable")
+ }
+
+ if err := parseError(response); err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go
new file mode 100644
index 0000000..f69f284
--- /dev/null
+++ b/go/internal/gitlabnet/client_test.go
@@ -0,0 +1,131 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+func TestClients(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/hello",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "Hello")
+ },
+ },
+ {
+ Path: "/api/v4/internal/auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, r.Header.Get(secretHeaderName))
+ },
+ },
+ {
+ Path: "/api/v4/internal/error",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ body := map[string]string{
+ "message": "Don't do that",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ {
+ Path: "/api/v4/internal/broken",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ panic("Broken")
+ },
+ },
+ }
+ testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
+
+ testCases := []struct {
+ desc string
+ client GitlabClient
+ server func([]testserver.TestRequestHandler) (func(), error)
+ }{
+ {
+ desc: "Socket client",
+ client: buildSocketClient(testConfig),
+ server: testserver.StartSocketHttpServer,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ cleanup, err := tc.server(requests)
+ defer cleanup()
+ require.NoError(t, err)
+
+ testBrokenRequest(t, tc.client)
+ testSuccessfulGet(t, tc.client)
+ testMissing(t, tc.client)
+ testErrorMessage(t, tc.client)
+ testAuthenticationHeader(t, tc.client)
+ })
+ }
+}
+
+func testSuccessfulGet(t *testing.T, client GitlabClient) {
+ t.Run("Successful get", func(t *testing.T) {
+ response, err := client.Get("/hello")
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, string(responseBody), "Hello")
+ })
+}
+
+func testMissing(t *testing.T, client GitlabClient) {
+ t.Run("Missing error", func(t *testing.T) {
+ response, err := client.Get("/missing")
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+ })
+}
+
+func testErrorMessage(t *testing.T, client GitlabClient) {
+ t.Run("Error with message", func(t *testing.T) {
+ response, err := client.Get("/error")
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
+}
+
+func testBrokenRequest(t *testing.T, client GitlabClient) {
+ t.Run("Broken request", func(t *testing.T) {
+ response, err := client.Get("/broken")
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
+}
+
+func testAuthenticationHeader(t *testing.T, client GitlabClient) {
+ t.Run("Authentication headers", func(t *testing.T) {
+ response, err := client.Get("/auth")
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
+}
diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go
new file mode 100644
index 0000000..8df78fb
--- /dev/null
+++ b/go/internal/gitlabnet/discover/client.go
@@ -0,0 +1,76 @@
+package discover
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+)
+
+type Client struct {
+ config *config.Config
+ client gitlabnet.GitlabClient
+}
+
+type Response struct {
+ UserId int64 `json:"id"`
+ Name string `json:"name"`
+ Username string `json:"username"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetByKeyId(keyId string) (*Response, error) {
+ params := url.Values{}
+ params.Add("key_id", keyId)
+
+ return c.getResponse(params)
+}
+
+func (c *Client) GetByUsername(username string) (*Response, error) {
+ params := url.Values{}
+ params.Add("username", username)
+
+ return c.getResponse(params)
+}
+
+func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
+ parsedResponse := &Response{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return nil, err
+ } else {
+ return parsedResponse, nil
+ }
+}
+
+func (c *Client) getResponse(params url.Values) (*Response, error) {
+ path := "/discover?" + params.Encode()
+ response, err := c.client.Get(path)
+
+ if err != nil {
+ return nil, err
+ }
+
+ defer response.Body.Close()
+ parsedResponse, err := c.parseResponse(response)
+ if err != nil {
+ return nil, fmt.Errorf("Parsing failed")
+ }
+
+ return parsedResponse, nil
+}
+
+func (r *Response) IsAnonymous() bool {
+ return r.UserId < 1
+}
diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go
new file mode 100644
index 0000000..e88cedd
--- /dev/null
+++ b/go/internal/gitlabnet/discover/client_test.go
@@ -0,0 +1,131 @@
+package discover
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ testConfig *config.Config
+ requests []testserver.TestRequestHandler
+)
+
+func init() {
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key_id") == "1" {
+ body := &Response{
+ UserId: 2,
+ Username: "alex-doe",
+ Name: "Alex Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "jane-doe" {
+ body := &Response{
+ UserId: 1,
+ Username: "jane-doe",
+ Name: "Jane Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_message" {
+ w.WriteHeader(http.StatusForbidden)
+ body := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_json" {
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ } else if r.URL.Query().Get("username") == "broken_empty" {
+ w.WriteHeader(http.StatusForbidden)
+ } else {
+ fmt.Fprint(w, "null")
+ }
+ },
+ },
+ }
+}
+
+func TestGetByKeyId(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByKeyId("1")
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
+}
+
+func TestGetByUsername(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByUsername("jane-doe")
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
+}
+
+func TestMissingUser(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByUsername("missing")
+ assert.NoError(t, err)
+ assert.True(t, result.IsAnonymous())
+}
+
+func TestErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ fakeUsername string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ fakeUsername: "broken_message",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ fakeUsername: "broken_json",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "An error response without message",
+ fakeUsername: "broken_empty",
+ expectedError: "Internal API error (403)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ resp, err := client.GetByUsername(tc.fakeUsername)
+
+ assert.EqualError(t, err, tc.expectedError)
+ assert.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+
+ client, err := NewClient(testConfig)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go
new file mode 100644
index 0000000..3bd7c70
--- /dev/null
+++ b/go/internal/gitlabnet/socketclient.go
@@ -0,0 +1,46 @@
+package gitlabnet
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+)
+
+const (
+ // We need to set the base URL to something starting with HTTP, the host
+ // itself is ignored as we're talking over a socket.
+ socketBaseUrl = "http://unix"
+ UnixSocketProtocol = "http+unix://"
+)
+
+type GitlabSocketClient struct {
+ httpClient *http.Client
+ config *config.Config
+}
+
+func buildSocketClient(config *config.Config) *GitlabSocketClient {
+ path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
+ return net.Dial("unix", path)
+ },
+ },
+ }
+
+ return &GitlabSocketClient{httpClient: httpClient, config: config}
+}
+
+func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
+ path = normalizePath(path)
+
+ request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return doRequest(c.httpClient, c.config, request)
+}
diff --git a/go/internal/gitlabnet/testserver/testserver.go b/go/internal/gitlabnet/testserver/testserver.go
new file mode 100644
index 0000000..9640fd7
--- /dev/null
+++ b/go/internal/gitlabnet/testserver/testserver.go
@@ -0,0 +1,56 @@
+package testserver
+
+import (
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+)
+
+var (
+ tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
+ TestSocket = path.Join(tempDir, "internal.sock")
+)
+
+type TestRequestHandler struct {
+ Path string
+ Handler func(w http.ResponseWriter, r *http.Request)
+}
+
+func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
+ if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil {
+ return nil, err
+ }
+
+ socketListener, err := net.Listen("unix", TestSocket)
+ if err != nil {
+ return nil, err
+ }
+
+ server := http.Server{
+ Handler: buildHandler(handlers),
+ // We'll put this server through some nasty stuff we don't want
+ // in our test output
+ ErrorLog: log.New(ioutil.Discard, "", 0),
+ }
+ go server.Serve(socketListener)
+
+ return cleanupSocket, nil
+}
+
+func cleanupSocket() {
+ os.RemoveAll(tempDir)
+}
+
+func buildHandler(handlers []TestRequestHandler) http.Handler {
+ h := http.NewServeMux()
+
+ for _, handler := range handlers {
+ h.HandleFunc(handler.Path, handler.Handler)
+ }
+
+ return h
+}