diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/command/shared/customaction/customaction.go | 28 | ||||
-rw-r--r-- | internal/command/shared/customaction/customaction_test.go | 4 | ||||
-rw-r--r-- | internal/pktline/pktline.go | 14 | ||||
-rw-r--r-- | internal/pktline/pktline_test.go | 34 |
4 files changed, 74 insertions, 6 deletions
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go index 0675d36..d91f8ab 100644 --- a/internal/command/shared/customaction/customaction.go +++ b/internal/command/shared/customaction/customaction.go @@ -112,10 +112,32 @@ func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetCl } func (c *Command) readFromStdin() ([]byte, error) { - output := new(bytes.Buffer) - _, err := io.Copy(output, c.ReadWriter.In) + var output []byte + var needsPackData bool + + scanner := pktline.NewScanner(c.ReadWriter.In) + for scanner.Scan() { + line := scanner.Bytes() + output = append(output, line...) + + if pktline.IsFlush(line) { + break + } - return output.Bytes(), err + if !needsPackData && !pktline.IsRefRemoval(line) { + needsPackData = true + } + } + + if needsPackData { + packData := new(bytes.Buffer) + _, err := io.Copy(packData, c.ReadWriter.In) + + output = append(output, packData.Bytes()...) + return output, err + } else { + return output, nil + } } func (c *Command) readFromStdinNoEOF() []byte { diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go index 87ae2e4..d3e794c 100644 --- a/internal/command/shared/customaction/customaction_test.go +++ b/internal/command/shared/customaction/customaction_test.go @@ -46,7 +46,7 @@ func TestExecuteEOFSent(t *testing.T) { require.NoError(t, json.Unmarshal(b, &request)) require.Equal(t, request.Data.UserId, who) - require.Equal(t, "input", string(request.Output)) + require.Equal(t, "0009input", string(request.Output)) err = json.NewEncoder(w).Encode(Response{Result: []byte("output")}) require.NoError(t, err) @@ -58,7 +58,7 @@ func TestExecuteEOFSent(t *testing.T) { outBuf := &bytes.Buffer{} errBuf := &bytes.Buffer{} - input := bytes.NewBufferString("input") + input := bytes.NewBufferString("0009input") response := &accessverifier.Response{ Who: who, diff --git a/internal/pktline/pktline.go b/internal/pktline/pktline.go index 35fceb2..c091d82 100644 --- a/internal/pktline/pktline.go +++ b/internal/pktline/pktline.go @@ -8,6 +8,7 @@ import ( "bytes" "fmt" "io" + "regexp" "strconv" ) @@ -16,6 +17,8 @@ const ( pktDelim = "0001" ) +var branchRemovalPktRegexp = regexp.MustCompile(`\A[a-f0-9]{4}[a-f0-9]{40} 0{40} `) + // NewScanner returns a bufio.Scanner that splits on Git pktline boundaries func NewScanner(r io.Reader) *bufio.Scanner { scanner := bufio.NewScanner(r) @@ -24,7 +27,16 @@ func NewScanner(r io.Reader) *bufio.Scanner { return scanner } -// IsDone detects the special flush packet '0009done\n' +func IsRefRemoval(pkt []byte) bool { + return branchRemovalPktRegexp.Match(pkt) +} + +// IsFlush detects the special flush packet '0000' +func IsFlush(pkt []byte) bool { + return bytes.Equal(pkt, []byte("0000")) +} + +// IsDone detects the special done packet '0009done\n' func IsDone(pkt []byte) bool { return bytes.Equal(pkt, PktDone()) } diff --git a/internal/pktline/pktline_test.go b/internal/pktline/pktline_test.go index 6910c1e..20a5bf2 100644 --- a/internal/pktline/pktline_test.go +++ b/internal/pktline/pktline_test.go @@ -68,6 +68,40 @@ func TestScanner(t *testing.T) { } } +func TestIsRefRemoval(t *testing.T) { + testCases := []struct { + in string + isRemoval bool + }{ + {in: "003f7217a7c7e582c46cec22a130adf4b9d7d950fba0 7d1665144a3a975c05f1f43902ddaf084e784dbe refs/heads/debug", isRemoval: false}, + {in: "003f0000000000000000000000000000000000000000 7d1665144a3a975c05f1f43902ddaf084e784dbe refs/heads/debug", isRemoval: false}, + {in: "003f7217a7c7e582c46cec22a130adf4b9d7d950fba0 0000000000000000000000000000000000000000 refs/heads/debug", isRemoval: true}, + } + + for _, tc := range testCases { + t.Run(tc.in, func(t *testing.T) { + require.Equal(t, tc.isRemoval, IsRefRemoval([]byte(tc.in))) + }) + } +} + +func TestIsFlush(t *testing.T) { + testCases := []struct { + in string + flush bool + }{ + {in: "0008abcd", flush: false}, + {in: "invalid packet", flush: false}, + {in: "0000", flush: true}, + } + + for _, tc := range testCases { + t.Run(tc.in, func(t *testing.T) { + require.Equal(t, tc.flush, IsFlush([]byte(tc.in))) + }) + } +} + func TestIsDone(t *testing.T) { testCases := []struct { in string |