summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/command/shared/customaction/customaction.go28
-rw-r--r--internal/command/shared/customaction/customaction_test.go4
-rw-r--r--internal/pktline/pktline.go14
-rw-r--r--internal/pktline/pktline_test.go34
4 files changed, 74 insertions, 6 deletions
diff --git a/internal/command/shared/customaction/customaction.go b/internal/command/shared/customaction/customaction.go
index 0675d36..d91f8ab 100644
--- a/internal/command/shared/customaction/customaction.go
+++ b/internal/command/shared/customaction/customaction.go
@@ -112,10 +112,32 @@ func (c *Command) performRequest(ctx context.Context, client *client.GitlabNetCl
}
func (c *Command) readFromStdin() ([]byte, error) {
- output := new(bytes.Buffer)
- _, err := io.Copy(output, c.ReadWriter.In)
+ var output []byte
+ var needsPackData bool
+
+ scanner := pktline.NewScanner(c.ReadWriter.In)
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ output = append(output, line...)
+
+ if pktline.IsFlush(line) {
+ break
+ }
- return output.Bytes(), err
+ if !needsPackData && !pktline.IsRefRemoval(line) {
+ needsPackData = true
+ }
+ }
+
+ if needsPackData {
+ packData := new(bytes.Buffer)
+ _, err := io.Copy(packData, c.ReadWriter.In)
+
+ output = append(output, packData.Bytes()...)
+ return output, err
+ } else {
+ return output, nil
+ }
}
func (c *Command) readFromStdinNoEOF() []byte {
diff --git a/internal/command/shared/customaction/customaction_test.go b/internal/command/shared/customaction/customaction_test.go
index 87ae2e4..d3e794c 100644
--- a/internal/command/shared/customaction/customaction_test.go
+++ b/internal/command/shared/customaction/customaction_test.go
@@ -46,7 +46,7 @@ func TestExecuteEOFSent(t *testing.T) {
require.NoError(t, json.Unmarshal(b, &request))
require.Equal(t, request.Data.UserId, who)
- require.Equal(t, "input", string(request.Output))
+ require.Equal(t, "0009input", string(request.Output))
err = json.NewEncoder(w).Encode(Response{Result: []byte("output")})
require.NoError(t, err)
@@ -58,7 +58,7 @@ func TestExecuteEOFSent(t *testing.T) {
outBuf := &bytes.Buffer{}
errBuf := &bytes.Buffer{}
- input := bytes.NewBufferString("input")
+ input := bytes.NewBufferString("0009input")
response := &accessverifier.Response{
Who: who,
diff --git a/internal/pktline/pktline.go b/internal/pktline/pktline.go
index 35fceb2..c091d82 100644
--- a/internal/pktline/pktline.go
+++ b/internal/pktline/pktline.go
@@ -8,6 +8,7 @@ import (
"bytes"
"fmt"
"io"
+ "regexp"
"strconv"
)
@@ -16,6 +17,8 @@ const (
pktDelim = "0001"
)
+var branchRemovalPktRegexp = regexp.MustCompile(`\A[a-f0-9]{4}[a-f0-9]{40} 0{40} `)
+
// NewScanner returns a bufio.Scanner that splits on Git pktline boundaries
func NewScanner(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
@@ -24,7 +27,16 @@ func NewScanner(r io.Reader) *bufio.Scanner {
return scanner
}
-// IsDone detects the special flush packet '0009done\n'
+func IsRefRemoval(pkt []byte) bool {
+ return branchRemovalPktRegexp.Match(pkt)
+}
+
+// IsFlush detects the special flush packet '0000'
+func IsFlush(pkt []byte) bool {
+ return bytes.Equal(pkt, []byte("0000"))
+}
+
+// IsDone detects the special done packet '0009done\n'
func IsDone(pkt []byte) bool {
return bytes.Equal(pkt, PktDone())
}
diff --git a/internal/pktline/pktline_test.go b/internal/pktline/pktline_test.go
index 6910c1e..20a5bf2 100644
--- a/internal/pktline/pktline_test.go
+++ b/internal/pktline/pktline_test.go
@@ -68,6 +68,40 @@ func TestScanner(t *testing.T) {
}
}
+func TestIsRefRemoval(t *testing.T) {
+ testCases := []struct {
+ in string
+ isRemoval bool
+ }{
+ {in: "003f7217a7c7e582c46cec22a130adf4b9d7d950fba0 7d1665144a3a975c05f1f43902ddaf084e784dbe refs/heads/debug", isRemoval: false},
+ {in: "003f0000000000000000000000000000000000000000 7d1665144a3a975c05f1f43902ddaf084e784dbe refs/heads/debug", isRemoval: false},
+ {in: "003f7217a7c7e582c46cec22a130adf4b9d7d950fba0 0000000000000000000000000000000000000000 refs/heads/debug", isRemoval: true},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.in, func(t *testing.T) {
+ require.Equal(t, tc.isRemoval, IsRefRemoval([]byte(tc.in)))
+ })
+ }
+}
+
+func TestIsFlush(t *testing.T) {
+ testCases := []struct {
+ in string
+ flush bool
+ }{
+ {in: "0008abcd", flush: false},
+ {in: "invalid packet", flush: false},
+ {in: "0000", flush: true},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.in, func(t *testing.T) {
+ require.Equal(t, tc.flush, IsFlush([]byte(tc.in)))
+ })
+ }
+}
+
func TestIsDone(t *testing.T) {
testCases := []struct {
in string