summaryrefslogtreecommitdiff
path: root/client/testserver
diff options
context:
space:
mode:
Diffstat (limited to 'client/testserver')
-rw-r--r--client/testserver/gitalyserver.go85
-rw-r--r--client/testserver/testserver.go81
2 files changed, 166 insertions, 0 deletions
diff --git a/client/testserver/gitalyserver.go b/client/testserver/gitalyserver.go
new file mode 100644
index 0000000..4bf14f3
--- /dev/null
+++ b/client/testserver/gitalyserver.go
@@ -0,0 +1,85 @@
+package testserver
+
+import (
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ pb "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+type TestGitalyServer struct{ ReceivedMD metadata.MD }
+
+func (s *TestGitalyServer) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
+
+ response := []byte("ReceivePack: " + req.GlId + " " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHReceivePackResponse{Stdout: response})
+
+ return nil
+}
+
+func (s *TestGitalyServer) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
+
+ response := []byte("UploadPack: " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHUploadPackResponse{Stdout: response})
+
+ return nil
+}
+
+func (s *TestGitalyServer) SSHUploadArchive(stream pb.SSHService_SSHUploadArchiveServer) error {
+ req, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ s.ReceivedMD, _ = metadata.FromIncomingContext(stream.Context())
+
+ response := []byte("UploadArchive: " + req.Repository.GlRepository)
+ stream.Send(&pb.SSHUploadArchiveResponse{Stdout: response})
+
+ return nil
+}
+
+func StartGitalyServer(t *testing.T) (string, *TestGitalyServer, func()) {
+ tempDir, _ := ioutil.TempDir("", "gitlab-shell-test-api")
+ gitalySocketPath := path.Join(tempDir, "gitaly.sock")
+
+ err := os.MkdirAll(filepath.Dir(gitalySocketPath), 0700)
+ require.NoError(t, err)
+
+ server := grpc.NewServer()
+
+ listener, err := net.Listen("unix", gitalySocketPath)
+ require.NoError(t, err)
+
+ testServer := TestGitalyServer{}
+ pb.RegisterSSHServiceServer(server, &testServer)
+
+ go server.Serve(listener)
+
+ gitalySocketUrl := "unix:" + gitalySocketPath
+ cleanup := func() {
+ server.Stop()
+ os.RemoveAll(tempDir)
+ }
+
+ return gitalySocketUrl, &testServer, cleanup
+}
diff --git a/client/testserver/testserver.go b/client/testserver/testserver.go
new file mode 100644
index 0000000..377e331
--- /dev/null
+++ b/client/testserver/testserver.go
@@ -0,0 +1,81 @@
+package testserver
+
+import (
+ "crypto/tls"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
+)
+
+var (
+ tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api")
+ testSocket = path.Join(tempDir, "internal.sock")
+)
+
+type TestRequestHandler struct {
+ Path string
+ Handler func(w http.ResponseWriter, r *http.Request)
+}
+
+func StartSocketHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ err := os.MkdirAll(filepath.Dir(testSocket), 0700)
+ require.NoError(t, err)
+
+ socketListener, err := net.Listen("unix", testSocket)
+ require.NoError(t, err)
+
+ server := http.Server{
+ Handler: buildHandler(handlers),
+ // We'll put this server through some nasty stuff we don't want
+ // in our test output
+ ErrorLog: log.New(ioutil.Discard, "", 0),
+ }
+ go server.Serve(socketListener)
+
+ url := "http+unix://" + testSocket
+
+ return url, cleanupSocket
+}
+
+func StartHttpServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ server := httptest.NewServer(buildHandler(handlers))
+
+ return server.URL, server.Close
+}
+
+func StartHttpsServer(t *testing.T, handlers []TestRequestHandler) (string, func()) {
+ crt := path.Join(testhelper.TestRoot, "certs/valid/server.crt")
+ key := path.Join(testhelper.TestRoot, "certs/valid/server.key")
+
+ server := httptest.NewUnstartedServer(buildHandler(handlers))
+ cer, err := tls.LoadX509KeyPair(crt, key)
+ require.NoError(t, err)
+
+ server.TLS = &tls.Config{Certificates: []tls.Certificate{cer}}
+ server.StartTLS()
+
+ return server.URL, server.Close
+}
+
+func cleanupSocket() {
+ os.RemoveAll(tempDir)
+}
+
+func buildHandler(handlers []TestRequestHandler) http.Handler {
+ h := http.NewServeMux()
+
+ for _, handler := range handlers {
+ h.HandleFunc(handler.Path, handler.Handler)
+ }
+
+ return h
+}