summaryrefslogtreecommitdiff
path: root/go/internal/gitlabnet
diff options
context:
space:
mode:
authorIgor <idrozdov@gitlab.com>2019-03-21 11:53:09 +0000
committerNick Thomas <nick@gitlab.com>2019-03-21 11:53:09 +0000
commit98dbdfb758703428626d54b2a257565a44509a55 (patch)
treea3fdc408786fd0342bd3eb28ad841e70d3d7ac6e /go/internal/gitlabnet
parent81bed658f083a165e65b16f7ef86c18938349e33 (diff)
downloadgitlab-shell-98dbdfb758703428626d54b2a257565a44509a55.tar.gz
Provide go implementation for 2fa_recovery_codes command
Diffstat (limited to 'go/internal/gitlabnet')
-rw-r--r--go/internal/gitlabnet/client.go3
-rw-r--r--go/internal/gitlabnet/client_test.go73
-rw-r--r--go/internal/gitlabnet/discover/client.go13
-rw-r--r--go/internal/gitlabnet/socketclient.go20
-rw-r--r--go/internal/gitlabnet/twofactorrecover/client.go104
-rw-r--r--go/internal/gitlabnet/twofactorrecover/client_test.go161
6 files changed, 368 insertions, 6 deletions
diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go
index abc218f..c2453e5 100644
--- a/go/internal/gitlabnet/client.go
+++ b/go/internal/gitlabnet/client.go
@@ -17,8 +17,7 @@ const (
type GitlabClient interface {
Get(path string) (*http.Response, error)
- // TODO: implement posts
- // Post(path string) (http.Response, error)
+ Post(path string, data interface{}) (*http.Response, error)
}
type ErrorResponse struct {
diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go
index f69f284..c1d08a1 100644
--- a/go/internal/gitlabnet/client_test.go
+++ b/go/internal/gitlabnet/client_test.go
@@ -19,10 +19,25 @@ func TestClients(t *testing.T) {
{
Path: "/api/v4/internal/hello",
Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+
fmt.Fprint(w, "Hello")
},
},
{
+ Path: "/api/v4/internal/post_endpoint",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ fmt.Fprint(w, "Echo: "+string(b))
+ },
+ },
+ {
Path: "/api/v4/internal/auth",
Handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, r.Header.Get(secretHeaderName))
@@ -68,6 +83,7 @@ func TestClients(t *testing.T) {
testBrokenRequest(t, tc.client)
testSuccessfulGet(t, tc.client)
+ testSuccessfulPost(t, tc.client)
testMissing(t, tc.client)
testErrorMessage(t, tc.client)
testAuthenticationHeader(t, tc.client)
@@ -89,32 +105,66 @@ func testSuccessfulGet(t *testing.T, client GitlabClient) {
})
}
+func testSuccessfulPost(t *testing.T, client GitlabClient) {
+ t.Run("Successful Post", func(t *testing.T) {
+ data := map[string]string{"key": "value"}
+
+ response, err := client.Post("/post_endpoint", data)
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, "Echo: {\"key\":\"value\"}", string(responseBody))
+ })
+}
+
func testMissing(t *testing.T, client GitlabClient) {
- t.Run("Missing error", func(t *testing.T) {
+ t.Run("Missing error for GET", func(t *testing.T) {
response, err := client.Get("/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
})
+
+ t.Run("Missing error for POST", func(t *testing.T) {
+ response, err := client.Post("/missing", map[string]string{})
+ assert.EqualError(t, err, "Internal API error (404)")
+ assert.Nil(t, response)
+ })
}
func testErrorMessage(t *testing.T, client GitlabClient) {
- t.Run("Error with message", func(t *testing.T) {
+ t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
+
+ t.Run("Error with message for POST", func(t *testing.T) {
+ response, err := client.Post("/error", map[string]string{})
+ assert.EqualError(t, err, "Don't do that")
+ assert.Nil(t, response)
+ })
}
func testBrokenRequest(t *testing.T, client GitlabClient) {
- t.Run("Broken request", func(t *testing.T) {
+ t.Run("Broken request for GET", func(t *testing.T) {
response, err := client.Get("/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
})
+
+ t.Run("Broken request for POST", func(t *testing.T) {
+ response, err := client.Post("/broken", map[string]string{})
+ assert.EqualError(t, err, "Internal API unreachable")
+ assert.Nil(t, response)
+ })
}
func testAuthenticationHeader(t *testing.T, client GitlabClient) {
- t.Run("Authentication headers", func(t *testing.T) {
+ t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth")
defer response.Body.Close()
@@ -128,4 +178,19 @@ func testAuthenticationHeader(t *testing.T, client GitlabClient) {
require.NoError(t, err)
assert.Equal(t, "sssh, it's a secret", string(header))
})
+
+ t.Run("Authentication headers for POST", func(t *testing.T) {
+ response, err := client.Post("/auth", map[string]string{})
+ defer response.Body.Close()
+
+ require.NoError(t, err)
+ require.NotNil(t, response)
+
+ responseBody, err := ioutil.ReadAll(response.Body)
+ require.NoError(t, err)
+
+ header, err := base64.StdEncoding.DecodeString(string(responseBody))
+ require.NoError(t, err)
+ assert.Equal(t, "sssh, it's a secret", string(header))
+ })
}
diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go
index 8df78fb..e84b1b4 100644
--- a/go/internal/gitlabnet/discover/client.go
+++ b/go/internal/gitlabnet/discover/client.go
@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
)
@@ -30,6 +31,18 @@ func NewClient(config *config.Config) (*Client, error) {
return &Client{config: config, client: client}, nil
}
+func (c *Client) GetByCommandArgs(args *commandargs.CommandArgs) (*Response, error) {
+ if args.GitlabKeyId != "" {
+ return c.GetByKeyId(args.GitlabKeyId)
+ } else if args.GitlabUsername != "" {
+ return c.GetByUsername(args.GitlabUsername)
+ } else {
+ // There was no 'who' information, this matches the ruby error
+ // message.
+ return nil, fmt.Errorf("who='' is invalid")
+ }
+}
+
func (c *Client) GetByKeyId(keyId string) (*Response, error) {
params := url.Values{}
params.Add("key_id", keyId)
diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go
index 3bd7c70..fd97535 100644
--- a/go/internal/gitlabnet/socketclient.go
+++ b/go/internal/gitlabnet/socketclient.go
@@ -1,7 +1,9 @@
package gitlabnet
import (
+ "bytes"
"context"
+ "encoding/json"
"net"
"net/http"
"strings"
@@ -44,3 +46,21 @@ func (c *GitlabSocketClient) Get(path string) (*http.Response, error) {
return doRequest(c.httpClient, c.config, request)
}
+
+func (c *GitlabSocketClient) Post(path string, data interface{}) (*http.Response, error) {
+ path = normalizePath(path)
+
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return nil, err
+ }
+
+ request, err := http.NewRequest("POST", socketBaseUrl+path, bytes.NewReader(jsonData))
+ request.Header.Add("Content-Type", "application/json")
+
+ if err != nil {
+ return nil, err
+ }
+
+ return doRequest(c.httpClient, c.config, request)
+}
diff --git a/go/internal/gitlabnet/twofactorrecover/client.go b/go/internal/gitlabnet/twofactorrecover/client.go
new file mode 100644
index 0000000..2e47c64
--- /dev/null
+++ b/go/internal/gitlabnet/twofactorrecover/client.go
@@ -0,0 +1,104 @@
+package twofactorrecover
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
+)
+
+type Client struct {
+ config *config.Config
+ client gitlabnet.GitlabClient
+}
+
+type Response struct {
+ Success bool `json:"success"`
+ RecoveryCodes []string `json:"recovery_codes"`
+ Message string `json:"message"`
+}
+
+type RequestBody struct {
+ KeyId string `json:"key_id,omitempty"`
+ UserId int64 `json:"user_id,omitempty"`
+}
+
+func NewClient(config *config.Config) (*Client, error) {
+ client, err := gitlabnet.GetClient(config)
+ if err != nil {
+ return nil, fmt.Errorf("Error creating http client: %v", err)
+ }
+
+ return &Client{config: config, client: client}, nil
+}
+
+func (c *Client) GetRecoveryCodes(args *commandargs.CommandArgs) ([]string, error) {
+ requestBody, err := c.getRequestBody(args)
+
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := c.client.Post("/two_factor_recovery_codes", requestBody)
+
+ if err != nil {
+ return nil, err
+ }
+
+ defer response.Body.Close()
+ parsedResponse, err := c.parseResponse(response)
+
+ if err != nil {
+ return nil, fmt.Errorf("Parsing failed")
+ }
+
+ if parsedResponse.Success {
+ return parsedResponse.RecoveryCodes, nil
+ } else {
+ return nil, errors.New(parsedResponse.Message)
+ }
+}
+
+func (c *Client) parseResponse(resp *http.Response) (*Response, error) {
+ parsedResponse := &Response{}
+ body, err := ioutil.ReadAll(resp.Body)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if err := json.Unmarshal(body, parsedResponse); err != nil {
+ return nil, err
+ } else {
+ return parsedResponse, nil
+ }
+}
+
+func (c *Client) getRequestBody(args *commandargs.CommandArgs) (*RequestBody, error) {
+ client, err := discover.NewClient(c.config)
+
+ if err != nil {
+ return nil, err
+ }
+
+ var requestBody *RequestBody
+ if args.GitlabKeyId != "" {
+ requestBody = &RequestBody{KeyId: args.GitlabKeyId}
+ } else {
+ userInfo, err := client.GetByCommandArgs(args)
+
+ if err != nil {
+ return nil, err
+ }
+
+ requestBody = &RequestBody{UserId: userInfo.UserId}
+ }
+
+ return requestBody, nil
+}
diff --git a/go/internal/gitlabnet/twofactorrecover/client_test.go b/go/internal/gitlabnet/twofactorrecover/client_test.go
new file mode 100644
index 0000000..5cbc011
--- /dev/null
+++ b/go/internal/gitlabnet/twofactorrecover/client_test.go
@@ -0,0 +1,161 @@
+package twofactorrecover
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+)
+
+var (
+ testConfig *config.Config
+ requests []testserver.TestRequestHandler
+)
+
+func initialize(t *testing.T) {
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/two_factor_recovery_codes",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ var requestBody *RequestBody
+ json.Unmarshal(b, &requestBody)
+
+ switch requestBody.KeyId {
+ case "0":
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery 1", "codes 1"},
+ }
+ json.NewEncoder(w).Encode(body)
+ case "1":
+ body := map[string]interface{}{
+ "success": false,
+ "message": "missing user",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "2":
+ w.WriteHeader(http.StatusForbidden)
+ body := &gitlabnet.ErrorResponse{
+ Message: "Not allowed!",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "3":
+ w.Write([]byte("{ \"message\": \"broken json!\""))
+ case "4":
+ w.WriteHeader(http.StatusForbidden)
+ }
+
+ if requestBody.UserId == 1 {
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery 2", "codes 2"},
+ }
+ json.NewEncoder(w).Encode(body)
+ }
+ },
+ },
+ {
+ Path: "/api/v4/internal/discover",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ body := &discover.Response{
+ UserId: 1,
+ Username: "jane-doe",
+ Name: "Jane Doe",
+ }
+ json.NewEncoder(w).Encode(body)
+ },
+ },
+ }
+}
+
+func TestGetRecoveryCodesByKeyId(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.CommandArgs{GitlabKeyId: "0"}
+ result, err := client.GetRecoveryCodes(args)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"recovery 1", "codes 1"}, result)
+}
+
+func TestGetRecoveryCodesByUsername(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.CommandArgs{GitlabUsername: "jane-doe"}
+ result, err := client.GetRecoveryCodes(args)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"recovery 2", "codes 2"}, result)
+}
+
+func TestMissingUser(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ args := &commandargs.CommandArgs{GitlabKeyId: "1"}
+ _, err := client.GetRecoveryCodes(args)
+ assert.Equal(t, "missing user", err.Error())
+}
+
+func TestErrorResponses(t *testing.T) {
+ client, cleanup := setup(t)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ fakeId string
+ expectedError string
+ }{
+ {
+ desc: "A response with an error message",
+ fakeId: "2",
+ expectedError: "Not allowed!",
+ },
+ {
+ desc: "A response with bad JSON",
+ fakeId: "3",
+ expectedError: "Parsing failed",
+ },
+ {
+ desc: "An error response without message",
+ fakeId: "4",
+ expectedError: "Internal API error (403)",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ args := &commandargs.CommandArgs{GitlabKeyId: tc.fakeId}
+ resp, err := client.GetRecoveryCodes(args)
+
+ assert.EqualError(t, err, tc.expectedError)
+ assert.Nil(t, resp)
+ })
+ }
+}
+
+func setup(t *testing.T) (*Client, func()) {
+ initialize(t)
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+
+ client, err := NewClient(testConfig)
+ require.NoError(t, err)
+
+ return client, cleanup
+}