diff options
author | Igor <idrozdov@gitlab.com> | 2019-03-21 11:53:09 +0000 |
---|---|---|
committer | Nick Thomas <nick@gitlab.com> | 2019-03-21 11:53:09 +0000 |
commit | 98dbdfb758703428626d54b2a257565a44509a55 (patch) | |
tree | a3fdc408786fd0342bd3eb28ad841e70d3d7ac6e /go/internal/gitlabnet | |
parent | 81bed658f083a165e65b16f7ef86c18938349e33 (diff) | |
download | gitlab-shell-98dbdfb758703428626d54b2a257565a44509a55.tar.gz |
Provide go implementation for 2fa_recovery_codes command
Diffstat (limited to 'go/internal/gitlabnet')
-rw-r--r-- | go/internal/gitlabnet/client.go | 3 | ||||
-rw-r--r-- | go/internal/gitlabnet/client_test.go | 73 | ||||
-rw-r--r-- | go/internal/gitlabnet/discover/client.go | 13 | ||||
-rw-r--r-- | go/internal/gitlabnet/socketclient.go | 20 | ||||
-rw-r--r-- | go/internal/gitlabnet/twofactorrecover/client.go | 104 | ||||
-rw-r--r-- | go/internal/gitlabnet/twofactorrecover/client_test.go | 161 |
6 files changed, 368 insertions, 6 deletions
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go index abc218f..c2453e5 100644 --- a/go/internal/gitlabnet/client.go +++ b/go/internal/gitlabnet/client.go @@ -17,8 +17,7 @@ const ( type GitlabClient interface { Get(path string) (*http.Response, error) - // TODO: implement posts - // Post(path string) (http.Response, error) + Post(path string, data interface{}) (*http.Response, error) } type ErrorResponse struct { diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go index f69f284..c1d08a1 100644 --- a/go/internal/gitlabnet/client_test.go +++ b/go/internal/gitlabnet/client_test.go @@ -19,10 +19,25 @@ func TestClients(t *testing.T) { { Path: "/api/v4/internal/hello", Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + fmt.Fprint(w, "Hello") }, }, { + Path: "/api/v4/internal/post_endpoint", + Handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + fmt.Fprint(w, "Echo: "+string(b)) + }, + }, + { Path: "/api/v4/internal/auth", Handler: func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, r.Header.Get(secretHeaderName)) @@ -68,6 +83,7 @@ func TestClients(t *testing.T) { testBrokenRequest(t, tc.client) testSuccessfulGet(t, tc.client) + testSuccessfulPost(t, tc.client) testMissing(t, tc.client) testErrorMessage(t, tc.client) testAuthenticationHeader(t, tc.client) @@ -89,32 +105,66 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) { }) } +func testSuccessfulPost(t *testing.T, client GitlabClient) { + t.Run("Successful Post", func(t *testing.T) { + data := map[string]string{"key": "value"} + + response, err := client.Post("/post_endpoint", data) + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody)) + }) +} + func testMissing(t *testing.T, client GitlabClient) { - t.Run("Missing error", func(t *testing.T) { + t.Run("Missing error for GET", func(t *testing.T) { response, err := client.Get("/missing") assert.EqualError(t, err, "Internal API error (404)") assert.Nil(t, response) }) + + t.Run("Missing error for POST", func(t *testing.T) { + response, err := client.Post("/missing", map[string]string{}) + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + }) } func testErrorMessage(t *testing.T, client GitlabClient) { - t.Run("Error with message", func(t *testing.T) { + t.Run("Error with message for GET", func(t *testing.T) { response, err := client.Get("/error") assert.EqualError(t, err, "Don't do that") assert.Nil(t, response) }) + + t.Run("Error with message for POST", func(t *testing.T) { + response, err := client.Post("/error", map[string]string{}) + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) } func testBrokenRequest(t *testing.T, client GitlabClient) { - t.Run("Broken request", func(t *testing.T) { + t.Run("Broken request for GET", func(t *testing.T) { response, err := client.Get("/broken") assert.EqualError(t, err, "Internal API unreachable") assert.Nil(t, response) }) + + t.Run("Broken request for POST", func(t *testing.T) { + response, err := client.Post("/broken", map[string]string{}) + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) } func testAuthenticationHeader(t *testing.T, client GitlabClient) { - t.Run("Authentication headers", func(t *testing.T) { + t.Run("Authentication headers for GET", func(t *testing.T) { response, err := client.Get("/auth") defer response.Body.Close() @@ -128,4 +178,19 @@ func testAuthenticationHeader(t *testing.T, client GitlabClient) { require.NoError(t, err) assert.Equal(t, "sssh, it's a secret", string(header)) }) + + t.Run("Authentication headers for POST", func(t *testing.T) { + response, err := client.Post("/auth", map[string]string{}) + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) } diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go index 8df78fb..e84b1b4 100644 --- a/go/internal/gitlabnet/discover/client.go +++ b/go/internal/gitlabnet/discover/client.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" ) @@ -30,6 +31,18 @@ func NewClient(config *config.Config) (*Client, error) { return &Client{config: config, client: client}, nil } +func (c *Client) GetByCommandArgs(args *commandargs.CommandArgs) (*Response, error) { + if args.GitlabKeyId != "" { + return c.GetByKeyId(args.GitlabKeyId) + } else if args.GitlabUsername != "" { + return c.GetByUsername(args.GitlabUsername) + } else { + // There was no 'who' information, this matches the ruby error + // message. + return nil, fmt.Errorf("who='' is invalid") + } +} + func (c *Client) GetByKeyId(keyId string) (*Response, error) { params := url.Values{} params.Add("key_id", keyId) diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go index 3bd7c70..fd97535 100644 --- a/go/internal/gitlabnet/socketclient.go +++ b/go/internal/gitlabnet/socketclient.go @@ -1,7 +1,9 @@ package gitlabnet import ( + "bytes" "context" + "encoding/json" "net" "net/http" "strings" @@ -44,3 +46,21 @@ func (c *GitlabSocketClient) Get(path string) (*http.Response, error) { return doRequest(c.httpClient, c.config, request) } + +func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) { + path = normalizePath(path) + + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData)) + request.Header.Add("Content-Type", "application/json") + + if err != nil { + return nil, err + } + + return doRequest(c.httpClient, c.config, request) +} diff --git a/go/internal/gitlabnet/twofactorrecover/client.go b/go/internal/gitlabnet/twofactorrecover/client.go new file mode 100644 index 0000000..2e47c64 --- /dev/null +++ b/go/internal/gitlabnet/twofactorrecover/client.go @@ -0,0 +1,104 @@ +package twofactorrecover + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" +) + +type Client struct { + config *config.Config + client gitlabnet.GitlabClient +} + +type Response struct { + Success bool `json:"success"` + RecoveryCodes []string `json:"recovery_codes"` + Message string `json:"message"` +} + +type RequestBody struct { + KeyId string `json:"key_id,omitempty"` + UserId int64 `json:"user_id,omitempty"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) GetRecoveryCodes(args *commandargs.CommandArgs) ([]string, error) { + requestBody, err := c.getRequestBody(args) + + if err != nil { + return nil, err + } + + response, err := c.client.Post("/two_factor_recovery_codes", requestBody) + + if err != nil { + return nil, err + } + + defer response.Body.Close() + parsedResponse, err := c.parseResponse(response) + + if err != nil { + return nil, fmt.Errorf("Parsing failed") + } + + if parsedResponse.Success { + return parsedResponse.RecoveryCodes, nil + } else { + return nil, errors.New(parsedResponse.Message) + } +} + +func (c *Client) parseResponse(resp *http.Response) (*Response, error) { + parsedResponse := &Response{} + body, err := ioutil.ReadAll(resp.Body) + + if err != nil { + return nil, err + } + + if err := json.Unmarshal(body, parsedResponse); err != nil { + return nil, err + } else { + return parsedResponse, nil + } +} + +func (c *Client) getRequestBody(args *commandargs.CommandArgs) (*RequestBody, error) { + client, err := discover.NewClient(c.config) + + if err != nil { + return nil, err + } + + var requestBody *RequestBody + if args.GitlabKeyId != "" { + requestBody = &RequestBody{KeyId: args.GitlabKeyId} + } else { + userInfo, err := client.GetByCommandArgs(args) + + if err != nil { + return nil, err + } + + requestBody = &RequestBody{UserId: userInfo.UserId} + } + + return requestBody, nil +} diff --git a/go/internal/gitlabnet/twofactorrecover/client_test.go b/go/internal/gitlabnet/twofactorrecover/client_test.go new file mode 100644 index 0000000..5cbc011 --- /dev/null +++ b/go/internal/gitlabnet/twofactorrecover/client_test.go @@ -0,0 +1,161 @@ +package twofactorrecover + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +var ( + testConfig *config.Config + requests []testserver.TestRequestHandler +) + +func initialize(t *testing.T) { + testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket} + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/two_factor_recovery_codes", + Handler: func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + defer r.Body.Close() + + require.NoError(t, err) + + var requestBody *RequestBody + json.Unmarshal(b, &requestBody) + + switch requestBody.KeyId { + case "0": + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery 1", "codes 1"}, + } + json.NewEncoder(w).Encode(body) + case "1": + body := map[string]interface{}{ + "success": false, + "message": "missing user", + } + json.NewEncoder(w).Encode(body) + case "2": + w.WriteHeader(http.StatusForbidden) + body := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + json.NewEncoder(w).Encode(body) + case "3": + w.Write([]byte("{ \"message\": \"broken json!\"")) + case "4": + w.WriteHeader(http.StatusForbidden) + } + + if requestBody.UserId == 1 { + body := map[string]interface{}{ + "success": true, + "recovery_codes": [2]string{"recovery 2", "codes 2"}, + } + json.NewEncoder(w).Encode(body) + } + }, + }, + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + body := &discover.Response{ + UserId: 1, + Username: "jane-doe", + Name: "Jane Doe", + } + json.NewEncoder(w).Encode(body) + }, + }, + } +} + +func TestGetRecoveryCodesByKeyId(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.CommandArgs{GitlabKeyId: "0"} + result, err := client.GetRecoveryCodes(args) + assert.NoError(t, err) + assert.Equal(t, []string{"recovery 1", "codes 1"}, result) +} + +func TestGetRecoveryCodesByUsername(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.CommandArgs{GitlabUsername: "jane-doe"} + result, err := client.GetRecoveryCodes(args) + assert.NoError(t, err) + assert.Equal(t, []string{"recovery 2", "codes 2"}, result) +} + +func TestMissingUser(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + args := &commandargs.CommandArgs{GitlabKeyId: "1"} + _, err := client.GetRecoveryCodes(args) + assert.Equal(t, "missing user", err.Error()) +} + +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeId string + expectedError string + }{ + { + desc: "A response with an error message", + fakeId: "2", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeId: "3", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeId: "4", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + args := &commandargs.CommandArgs{GitlabKeyId: tc.fakeId} + resp, err := client.GetRecoveryCodes(args) + + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + }) + } +} + +func setup(t *testing.T) (*Client, func()) { + initialize(t) + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + + client, err := NewClient(testConfig) + require.NoError(t, err) + + return client, cleanup +} |