summaryrefslogtreecommitdiff
path: root/go/internal
diff options
context:
space:
mode:
Diffstat (limited to 'go/internal')
-rw-r--r--go/internal/command/discover/discover.go37
-rw-r--r--go/internal/command/discover/discover_test.go130
-rw-r--r--go/internal/gitlabnet/client.go77
-rw-r--r--go/internal/gitlabnet/client_test.go131
-rw-r--r--go/internal/gitlabnet/discover/client.go72
-rw-r--r--go/internal/gitlabnet/discover/client_test.go86
-rw-r--r--go/internal/gitlabnet/socketclient.go46
-rw-r--r--go/internal/gitlabnet/testserver/testserver.go56
8 files changed, 634 insertions, 1 deletions
diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go
index 63a7a32..ab04cbd 100644
--- a/go/internal/command/discover/discover.go
+++ b/go/internal/command/discover/discover.go
@@ -2,9 +2,12 @@ package discover
import (
"fmt"
+ "io"
+ "os"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
)
type Command struct {
@@ -12,6 +15,38 @@ type Command struct {
Args *commandargs.CommandArgs
}
+var (
+ output io.Writer = os.Stdout
+)
+
func (c *Command) Execute() error {
- return fmt.Errorf("No feature is implemented yet")
+ response, err := c.getUserInfo()
+ if err != nil {
+ return fmt.Errorf("Failed to get username: %v", err)
+ }
+
+ if response.IsAnonymous() {
+ fmt.Fprintf(output, "Welcome to GitLab, Anonymous!\n")
+ } else {
+ fmt.Fprintf(output, "Welcome to GitLab, @%s!\n", response.Username)
+ }
+
+ return nil
+}
+
+func (c *Command) getUserInfo() (*discover.Response, error) {
+ client, err := discover.NewClient(c.Config)
+ if err != nil {
+ return nil, err
+ }
+
+ if c.Args.GitlabKeyId != "" {
+ return client.GetByKeyId(c.Args.GitlabKeyId)
+ } else if c.Args.GitlabUsername != "" {
+ return client.GetByUsername(c.Args.GitlabUsername)
+ } else {
+ // There was no 'who' information, this matches the ruby error
+ // message.
+ return nil, fmt.Errorf("who='' is invalid")
+ }
}
diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go
new file mode 100644
index 0000000..752e76e
--- /dev/null
+++ b/go/internal/command/discover/discover_test.go
@@ -0,0 +1,130 @@
+package discover
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+var (
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" {
+ body := map[string]interface{}{
+ "id": 2,
+ "username": "alex-doe",
+ "name": "Alex Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken_message" {
+ body := map[string]string{
+ "message": "Forbidden!",
+ }
+ w.WriteHeader(http.StatusForbidden)
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "broken" {
+ w.WriteHeader(http.StatusInternalServerError)
+ } else {
+ fmt.Fprint(w, "null")
+ }
+ },
+ },
+ }
+)
+
+func TestExecute(t *testing.T) {
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.CommandArgs
+ expectedOutput string
+ }{
+ {
+ desc: "With a known username",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "alex-doe"},
+ expectedOutput: "Welcome to GitLab, @alex-doe!\n",
+ },
+ {
+ desc: "With a known key id",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "1"},
+ expectedOutput: "Welcome to GitLab, @alex-doe!\n",
+ },
+ {
+ desc: "With an unknown key",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"},
+ expectedOutput: "Welcome to GitLab, Anonymous!\n",
+ },
+ {
+ desc: "With an unknown username",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "unknown"},
+ expectedOutput: "Welcome to GitLab, Anonymous!\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ buffer := &bytes.Buffer{}
+ output = buffer
+ cmd := &Command{Config: testConfig, Args: tc.arguments}
+
+ err := cmd.Execute()
+
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expectedOutput, buffer.String())
+ })
+ }
+}
+
+func TestFailingExecute(t *testing.T) {
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.CommandArgs
+ expectedError string
+ }{
+ {
+ desc: "With missing arguments",
+ arguments: &commandargs.CommandArgs{},
+ expectedError: "Failed to get username: who='' is invalid",
+ },
+ {
+ desc: "When the API returns an error",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "broken_message"},
+ expectedError: "Failed to get username: Forbidden!",
+ },
+ {
+ desc: "When the API fails",
+ arguments: &commandargs.CommandArgs{GitlabUsername: "broken"},
+ expectedError: "Failed to get username: Internal API error (500)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ cmd := &Command{Config: testConfig, Args: tc.arguments}
+
+ err := cmd.Execute()
+
+ assert.EqualError(t, err, tc.expectedError)
+ })
+ }
+
+}
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go
new file mode 100644
index 0000000..abc218f
--- /dev/null
+++ b/go/internal/gitlabnet/client.go
@@ -0,0 +1,77 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+)
+
+const (
+ internalApiPath = "/api/v4/internal"
+ secretHeaderName = "Gitlab-Shared-Secret"
+)
+
+type GitlabClient interface {
+ Get(path string) (*http.Response, error)
+ // TODO: implement posts
+ // Post(path string) (http.Response, error)
+}
+
+type ErrorResponse struct {
+ Message string `json:"message"`
+}
+
+func GetClient(config *config.Config) (GitlabClient, error) {
+ url := config.GitlabUrl
+ if strings.HasPrefix(url, UnixSocketProtocol) {
+ return buildSocketClient(config), nil
+ }
+
+ return nil, fmt.Errorf("Unsupported protocol")
+}
+
+func normalizePath(path string) string {
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+
+ if !strings.HasPrefix(path, internalApiPath) {
+ path = internalApiPath + path
+ }
+ return path
+}
+
+func parseError(resp *http.Response) error {
+ if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
+ return nil
+ }
+ defer resp.Body.Close()
+ parsedResponse := &ErrorResponse{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return fmt.Errorf("Internal API error (%v)", resp.StatusCode)
+ } else {
+ return fmt.Errorf(parsedResponse.Message)
+ }
+
+}
+
+func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) {
+ encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret))
+ request.Header.Set(secretHeaderName, encodedSecret)
+
+ response, err := client.Do(request)
+ if err != nil {
+ return nil, fmt.Errorf("Internal API unreachable")
+ }
+
+ if err := parseError(response); err != nil {
+ return nil, err
+ }
+
+ return response, nil
+}
diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go
new file mode 100644
index 0000000..f69f284
--- /dev/null
+++ b/go/internal/gitlabnet/client_test.go
@@ -0,0 +1,131 @@
+package gitlabnet
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+func TestClients(t *testing.T) {
+ requests := []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/hello",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "Hello")
+ },
+ },
+ {
+ Path: "/api/v4/internal/auth",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, r.Header.Get(secretHeaderName))
+ },
+ },
+ {
+ Path: "/api/v4/internal/error",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ body := map[string]string{
+ "message": "Don't do that",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ {
+ Path: "/api/v4/internal/broken",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ panic("Broken")
+ },
+ },
+ }
+ testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"}
+
+ testCases := []struct {
+ desc string
+ client GitlabClient
+ server func([]testserver.TestRequestHandler) (func(), error)
+ }{
+ {
+ desc: "Socket client",
+ client: buildSocketClient(testConfig),
+ server: testserver.StartSocketHttpServer,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ cleanup, err := tc.server(requests)
+ defer cleanup()
+ require.NoError(t, err)
+
+ testBrokenRequest(t, tc.client)
+ testSuccessfulGet(t, tc.client)
+ testMissing(t, tc.client)
+ testErrorMessage(t, tc.client)
+ testAuthenticationHeader(t, tc.client)
+ })
+ }
+}
+
+func testSuccessfulGet(t *testing.T, client GitlabClient) {
+ t.Run("Successful get", func(t *testing.T) {
+ response, err := client.Get("/hello")
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, string(responseBody), "Hello")
+ })
+}
+
+func testMissing(t *testing.T, client GitlabClient) {
+ t.Run("Missing error", func(t *testing.T) {
+ response, err := client.Get("/missing")
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+ })
+}
+
+func testErrorMessage(t *testing.T, client GitlabClient) {
+ t.Run("Error with message", func(t *testing.T) {
+ response, err := client.Get("/error")
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
+}
+
+func testBrokenRequest(t *testing.T, client GitlabClient) {
+ t.Run("Broken request", func(t *testing.T) {
+ response, err := client.Get("/broken")
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
+}
+
+func testAuthenticationHeader(t *testing.T, client GitlabClient) {
+ t.Run("Authentication headers", func(t *testing.T) {
+ response, err := client.Get("/auth")
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
+}
diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go
new file mode 100644
index 0000000..4e65d25
--- /dev/null
+++ b/go/internal/gitlabnet/discover/client.go
@@ -0,0 +1,72 @@
+package discover
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+)
+
+type Client struct {
+ config *config.Config
+ client gitlabnet.GitlabClient
+}
+
+type Response struct {
+ UserId int64 `json:"id"`
+ Name string `json:"name"`
+ Username string `json:"username"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetByKeyId(keyId string) (*Response, error) {
+ params := url.Values{}
+ params.Add("key_id", keyId)
+
+ return c.getResponse(params)
+}
+
+func (c *Client) GetByUsername(username string) (*Response, error) {
+ params := url.Values{}
+ params.Add("username", username)
+
+ return c.getResponse(params)
+}
+
+func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
+ defer resp.Body.Close()
+ parsedResponse := &Response{}
+
+ if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil {
+ return nil, err
+ } else {
+ return parsedResponse, nil
+ }
+
+}
+
+func (c *Client) getResponse(params url.Values) (*Response, error) {
+ path := "/discover?" + params.Encode()
+ response, err := c.client.Get(path)
+
+ if err != nil {
+ return nil, err
+ }
+
+ return c.parseResponse(response)
+}
+
+func (r *Response) IsAnonymous() bool {
+ return r.UserId < 1
+}
diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go
new file mode 100644
index 0000000..6c87d07
--- /dev/null
+++ b/go/internal/gitlabnet/discover/client_test.go
@@ -0,0 +1,86 @@
+package discover
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ testConfig *config.Config
+ requests []testserver.TestRequestHandler
+)
+
+func init() {
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Query().Get("key_id") == "1" {
+ body := map[string]interface{}{
+ "id": 2,
+ "username": "alex-doe",
+ "name": "Alex Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else if r.URL.Query().Get("username") == "jane-doe" {
+ body := map[string]interface{}{
+ "id": 1,
+ "username": "jane-doe",
+ "name": "Jane Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ } else {
+ fmt.Fprint(w, "null")
+
+ }
+
+ },
+ },
+ }
+}
+
+func TestGetByKeyId(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByKeyId("1")
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result)
+}
+
+func TestGetByUsername(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByUsername("jane-doe")
+ assert.NoError(t, err)
+ assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result)
+}
+
+func TestMissingUser(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ result, err := client.GetByUsername("missing")
+ assert.NoError(t, err)
+ assert.True(t, result.IsAnonymous())
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+
+ client, err := NewClient(testConfig)
+ require.NoError(t, err)
+
+ return client, cleanup
+}
diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go
new file mode 100644
index 0000000..3bd7c70
--- /dev/null
+++ b/go/internal/gitlabnet/socketclient.go
@@ -0,0 +1,46 @@
+package gitlabnet
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+)
+
+const (
+ // We need to set the base URL to something starting with HTTP, the host
+ // itself is ignored as we're talking over a socket.
+ socketBaseUrl = "http://unix"
+ UnixSocketProtocol = "http+unix://"
+)
+
+type GitlabSocketClient struct {
+ httpClient *http.Client
+ config *config.Config
+}
+
+func buildSocketClient(config *config.Config) *GitlabSocketClient {
+ path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol)
+ httpClient := &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
+ return net.Dial("unix", path)
+ },
+ },
+ }
+
+ return &GitlabSocketClient{httpClient: httpClient, config: config}
+}
+
+func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
+ path = normalizePath(path)
+
+ request, err := http.NewRequest("GET", socketBaseUrl+path, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return doRequest(c.httpClient, c.config, request)
+}
diff --git a/go/internal/gitlabnet/testserver/testserver.go b/go/internal/gitlabnet/testserver/testserver.go
new file mode 100644
index 0000000..9640fd7
--- /dev/null
+++ b/go/internal/gitlabnet/testserver/testserver.go
@@ -0,0 +1,56 @@
+package testserver
+
+import (
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+)
+
+var (
+ tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
+ TestSocket = path.Join(tempDir, "internal.sock")
+)
+
+type TestRequestHandler struct {
+ Path string
+ Handler func(w http.ResponseWriter, r *http.Request)
+}
+
+func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) {
+ if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil {
+ return nil, err
+ }
+
+ socketListener, err := net.Listen("unix", TestSocket)
+ if err != nil {
+ return nil, err
+ }
+
+ server := http.Server{
+ Handler: buildHandler(handlers),
+ // We'll put this server through some nasty stuff we don't want
+ // in our test output
+ ErrorLog: log.New(ioutil.Discard, "", 0),
+ }
+ go server.Serve(socketListener)
+
+ return cleanupSocket, nil
+}
+
+func cleanupSocket() {
+ os.RemoveAll(tempDir)
+}
+
+func buildHandler(handlers []TestRequestHandler) http.Handler {
+ h := http.NewServeMux()
+
+ for _, handler := range handlers {
+ h.HandleFunc(handler.Path, handler.Handler)
+ }
+
+ return h
+}