summaryrefslogtreecommitdiff
path: root/go/internal/command
diff options
context:
space:
mode:
authorIgor <idrozdov@gitlab.com>2019-03-21 11:53:09 +0000
committerNick Thomas <nick@gitlab.com>2019-03-21 11:53:09 +0000
commit98dbdfb758703428626d54b2a257565a44509a55 (patch)
treea3fdc408786fd0342bd3eb28ad841e70d3d7ac6e /go/internal/command
parent81bed658f083a165e65b16f7ef86c18938349e33 (diff)
downloadgitlab-shell-98dbdfb758703428626d54b2a257565a44509a55.tar.gz
Provide go implementation for 2fa_recovery_codes command
Diffstat (limited to 'go/internal/command')
-rw-r--r--go/internal/command/command.go7
-rw-r--r--go/internal/command/command_test.go14
-rw-r--r--go/internal/command/commandargs/command_args.go7
-rw-r--r--go/internal/command/discover/discover.go18
-rw-r--r--go/internal/command/discover/discover_test.go6
-rw-r--r--go/internal/command/fallback/fallback.go4
-rw-r--r--go/internal/command/readwriter/readwriter.go9
-rw-r--r--go/internal/command/reporting/reporter.go8
-rw-r--r--go/internal/command/twofactorrecover/twofactorrecover.go64
-rw-r--r--go/internal/command/twofactorrecover/twofactorrecover_test.go135
10 files changed, 243 insertions, 29 deletions
diff --git a/go/internal/command/command.go b/go/internal/command/command.go
index d4649de..b3bdcba 100644
--- a/go/internal/command/command.go
+++ b/go/internal/command/command.go
@@ -4,12 +4,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
)
type Command interface {
- Execute(*reporting.Reporter) error
+ Execute(*readwriter.ReadWriter) error
}
func New(arguments []string, config *config.Config) (Command, error) {
@@ -30,6 +31,8 @@ func buildCommand(args *commandargs.CommandArgs, config *config.Config) Command
switch args.CommandType {
case commandargs.Discover:
return &discover.Command{Config: config, Args: args}
+ case commandargs.TwoFactorRecover:
+ return &twofactorrecover.Command{Config: config, Args: args}
}
return nil
diff --git a/go/internal/command/command_test.go b/go/internal/command/command_test.go
index 02fc0d0..42c5112 100644
--- a/go/internal/command/command_test.go
+++ b/go/internal/command/command_test.go
@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/twofactorrecover"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/testhelper"
)
@@ -44,6 +45,19 @@ func TestNew(t *testing.T) {
},
expectedType: &fallback.Command{},
},
+ {
+ desc: "it returns a TwoFactorRecover command if the feature is enabled",
+ arguments: []string{},
+ config: &config.Config{
+ GitlabUrl: "http+unix://gitlab.socket",
+ Migration: config.MigrationConfig{Enabled: true, Features: []string{"2fa_recovery_codes"}},
+ },
+ environment: map[string]string{
+ "SSH_CONNECTION": "1",
+ "SSH_ORIGINAL_COMMAND": "2fa_recovery_codes",
+ },
+ expectedType: &twofactorrecover.Command{},
+ },
}
for _, tc := range testCases {
diff --git a/go/internal/command/commandargs/command_args.go b/go/internal/command/commandargs/command_args.go
index 9e679d3..e801889 100644
--- a/go/internal/command/commandargs/command_args.go
+++ b/go/internal/command/commandargs/command_args.go
@@ -9,7 +9,8 @@ import (
type CommandType string
const (
- Discover CommandType = "discover"
+ Discover CommandType = "discover"
+ TwoFactorRecover CommandType = "2fa_recovery_codes"
)
var (
@@ -79,4 +80,8 @@ func (c *CommandArgs) parseCommand(commandString string) {
if commandString == "" {
c.CommandType = Discover
}
+
+ if CommandType(commandString) == TwoFactorRecover {
+ c.CommandType = TwoFactorRecover
+ }
}
diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go
index 8ad2868..9bb442f 100644
--- a/go/internal/command/discover/discover.go
+++ b/go/internal/command/discover/discover.go
@@ -4,7 +4,7 @@ import (
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover"
)
@@ -14,16 +14,16 @@ type Command struct {
Args *commandargs.CommandArgs
}
-func (c *Command) Execute(reporter *reporting.Reporter) error {
+func (c *Command) Execute(readWriter *readwriter.ReadWriter) error {
response, err := c.getUserInfo()
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
if response.IsAnonymous() {
- fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n")
+ fmt.Fprintf(readWriter.Out, "Welcome to GitLab, Anonymous!\n")
} else {
- fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username)
+ fmt.Fprintf(readWriter.Out, "Welcome to GitLab, @%s!\n", response.Username)
}
return nil
@@ -35,13 +35,5 @@ func (c *Command) getUserInfo() (*discover.Response, error) {
return nil, err
}
- if c.Args.GitlabKeyId != "" {
- return client.GetByKeyId(c.Args.GitlabKeyId)
- } else if c.Args.GitlabUsername != "" {
- return client.GetByUsername(c.Args.GitlabUsername)
- } else {
- // There was no 'who' information, this matches the ruby error
- // message.
- return nil, fmt.Errorf("who='' is invalid")
- }
+ return client.GetByCommandArgs(c.Args)
}
diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go
index ec6f931..a57f07e 100644
--- a/go/internal/command/discover/discover_test.go
+++ b/go/internal/command/discover/discover_test.go
@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
)
@@ -82,7 +82,7 @@ func TestExecute(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{}
- err := cmd.Execute(&reporting.Reporter{Out: buffer})
+ err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
assert.NoError(t, err)
assert.Equal(t, tc.expectedOutput, buffer.String())
@@ -122,7 +122,7 @@ func TestFailingExecute(t *testing.T) {
cmd := &Command{Config: testConfig, Args: tc.arguments}
buffer := &bytes.Buffer{}
- err := cmd.Execute(&reporting.Reporter{Out: buffer})
+ err := cmd.Execute(&readwriter.ReadWriter{Out: buffer})
assert.Empty(t, buffer.String())
assert.EqualError(t, err, tc.expectedError)
diff --git a/go/internal/command/fallback/fallback.go b/go/internal/command/fallback/fallback.go
index a2c73ed..6e6d526 100644
--- a/go/internal/command/fallback/fallback.go
+++ b/go/internal/command/fallback/fallback.go
@@ -5,7 +5,7 @@ import (
"path/filepath"
"syscall"
- "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
)
type Command struct{}
@@ -14,7 +14,7 @@ var (
binDir = filepath.Dir(os.Args[0])
)
-func (c *Command) Execute(_ *reporting.Reporter) error {
+func (c *Command) Execute(_ *readwriter.ReadWriter) error {
rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby")
execErr := syscall.Exec(rubyCmd, os.Args, os.Environ())
return execErr
diff --git a/go/internal/command/readwriter/readwriter.go b/go/internal/command/readwriter/readwriter.go
new file mode 100644
index 0000000..da18d30
--- /dev/null
+++ b/go/internal/command/readwriter/readwriter.go
@@ -0,0 +1,9 @@
+package readwriter
+
+import "io"
+
+type ReadWriter struct {
+ Out io.Writer
+ In io.Reader
+ ErrOut io.Writer
+}
diff --git a/go/internal/command/reporting/reporter.go b/go/internal/command/reporting/reporter.go
deleted file mode 100644
index 74bca59..0000000
--- a/go/internal/command/reporting/reporter.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package reporting
-
-import "io"
-
-type Reporter struct {
- Out io.Writer
- ErrOut io.Writer
-}
diff --git a/go/internal/command/twofactorrecover/twofactorrecover.go b/go/internal/command/twofactorrecover/twofactorrecover.go
new file mode 100644
index 0000000..e77a334
--- /dev/null
+++ b/go/internal/command/twofactorrecover/twofactorrecover.go
@@ -0,0 +1,64 @@
+package twofactorrecover
+
+import (
+ "fmt"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover"
+)
+
+type Command struct {
+ Config *config.Config
+ Args *commandargs.CommandArgs
+}
+
+func (c *Command) Execute(readWriter *readwriter.ReadWriter) error {
+ if c.canContinue(readWriter) {
+ c.displayRecoveryCodes(readWriter)
+ } else {
+ fmt.Fprintln(readWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
+ }
+
+ return nil
+}
+
+func (c *Command) canContinue(readWriter *readwriter.ReadWriter) bool {
+ question :=
+ "Are you sure you want to generate new two-factor recovery codes?\n" +
+ "Any existing recovery codes you saved will be invalidated. (yes/no)"
+ fmt.Fprintln(readWriter.Out, question)
+
+ var answer string
+ fmt.Fscanln(readWriter.In, &answer)
+
+ return answer == "yes"
+}
+
+func (c *Command) displayRecoveryCodes(readWriter *readwriter.ReadWriter) {
+ codes, err := c.getRecoveryCodes()
+
+ if err == nil {
+ messageWithCodes :=
+ "\nYour two-factor authentication recovery codes are:\n\n" +
+ strings.Join(codes, "\n") +
+ "\n\nDuring sign in, use one of the codes above when prompted for\n" +
+ "your two-factor code. Then, visit your Profile Settings and add\n" +
+ "a new device so you do not lose access to your account again.\n"
+ fmt.Fprint(readWriter.Out, messageWithCodes)
+ } else {
+ fmt.Fprintf(readWriter.Out, "\nAn error occurred while trying to generate new recovery codes.\n%v\n", err)
+ }
+}
+
+func (c *Command) getRecoveryCodes() ([]string, error) {
+ client, err := twofactorrecover.NewClient(c.Config)
+
+ if err != nil {
+ return nil, err
+ }
+
+ return client.GetRecoveryCodes(c.Args)
+}
diff --git a/go/internal/command/twofactorrecover/twofactorrecover_test.go b/go/internal/command/twofactorrecover/twofactorrecover_test.go
new file mode 100644
index 0000000..908ee13
--- /dev/null
+++ b/go/internal/command/twofactorrecover/twofactorrecover_test.go
@@ -0,0 +1,135 @@
+package twofactorrecover
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/readwriter"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/config"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver"
+ "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/twofactorrecover"
+)
+
+var (
+ testConfig *config.Config
+ requests []testserver.TestRequestHandler
+)
+
+func setup(t *testing.T) {
+ testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket}
+ requests = []testserver.TestRequestHandler{
+ {
+ Path: "/api/v4/internal/two_factor_recovery_codes",
+ Handler: func(w http.ResponseWriter, r *http.Request) {
+ b, err := ioutil.ReadAll(r.Body)
+ defer r.Body.Close()
+
+ require.NoError(t, err)
+
+ var requestBody *twofactorrecover.RequestBody
+ json.Unmarshal(b, &requestBody)
+
+ switch requestBody.KeyId {
+ case "1":
+ body := map[string]interface{}{
+ "success": true,
+ "recovery_codes": [2]string{"recovery", "codes"},
+ }
+ json.NewEncoder(w).Encode(body)
+ case "forbidden":
+ body := map[string]interface{}{
+ "success": false,
+ "message": "Forbidden!",
+ }
+ json.NewEncoder(w).Encode(body)
+ case "broken":
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ },
+ },
+ }
+}
+
+const (
+ question = "Are you sure you want to generate new two-factor recovery codes?\n" +
+ "Any existing recovery codes you saved will be invalidated. (yes/no)\n\n"
+ errorHeader = "An error occurred while trying to generate new recovery codes.\n"
+)
+
+func TestExecute(t *testing.T) {
+ setup(t)
+
+ cleanup, err := testserver.StartSocketHttpServer(requests)
+ require.NoError(t, err)
+ defer cleanup()
+
+ testCases := []struct {
+ desc string
+ arguments *commandargs.CommandArgs
+ answer string
+ expectedOutput string
+ }{
+ {
+ desc: "With a known key id",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "1"},
+ answer: "yes\n",
+ expectedOutput: question +
+ "Your two-factor authentication recovery codes are:\n\nrecovery\ncodes\n\n" +
+ "During sign in, use one of the codes above when prompted for\n" +
+ "your two-factor code. Then, visit your Profile Settings and add\n" +
+ "a new device so you do not lose access to your account again.\n",
+ },
+ {
+ desc: "With bad response",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Parsing failed\n",
+ },
+ {
+ desc: "With API returns an error",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "forbidden"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Forbidden!\n",
+ },
+ {
+ desc: "With API fails",
+ arguments: &commandargs.CommandArgs{GitlabKeyId: "broken"},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "Internal API error (500)\n",
+ },
+ {
+ desc: "With missing arguments",
+ arguments: &commandargs.CommandArgs{},
+ answer: "yes\n",
+ expectedOutput: question + errorHeader + "who='' is invalid\n",
+ },
+ {
+ desc: "With negative answer",
+ arguments: &commandargs.CommandArgs{},
+ answer: "no\n",
+ expectedOutput: question +
+ "New recovery codes have *not* been generated. Existing codes will remain valid.\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ output := &bytes.Buffer{}
+ input := bytes.NewBufferString(tc.answer)
+
+ cmd := &Command{Config: testConfig, Args: tc.arguments}
+
+ err := cmd.Execute(&readwriter.ReadWriter{Out: output, In: input})
+
+ assert.NoError(t, err)
+ assert.Equal(t, tc.expectedOutput, output.String())
+ })
+ }
+}