summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/gitlab-sshd/main.go1
-rw-r--r--internal/config/config.go6
-rw-r--r--internal/sshd/sshd.go52
-rw-r--r--internal/sshd/sshd_test.go71
4 files changed, 115 insertions, 15 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go
index 7cecbf5..e524023 100644
--- a/cmd/gitlab-sshd/main.go
+++ b/cmd/gitlab-sshd/main.go
@@ -76,6 +76,7 @@ func main() {
monitoring.Start(
monitoring.WithListenerAddress(cfg.Server.WebListen),
monitoring.WithBuildInformation(Version, BuildTime),
+ monitoring.WithServeMux(server.MonitoringServeMux()),
),
)
}()
diff --git a/internal/config/config.go b/internal/config/config.go
index c58ea7d..71e7840 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -26,7 +26,9 @@ type ServerConfig struct {
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
WebListen string `yaml:"web_listen,omitempty"`
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
- GracePeriodSeconds uint64 `yaml:"grace_period"`
+ GracePeriodSeconds uint64 `yaml:"grace_period"`
+ ReadinessProbe string `yaml:"readiness_probe"`
+ LivenessProbe string `yaml:"liveness_probe"`
HostKeyFiles []string `yaml:"host_key_files,omitempty"`
}
@@ -72,6 +74,8 @@ var (
WebListen: "localhost:9122",
ConcurrentSessionsLimit: 10,
GracePeriodSeconds: 10,
+ ReadinessProbe: "/start",
+ LivenessProbe: "/health",
HostKeyFiles: []string{
"/run/secrets/ssh-hostkeys/ssh_host_rsa_key",
"/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key",
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index ef401dc..ac4ebf8 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -10,6 +10,7 @@ import (
"strconv"
"time"
"sync"
+ "net/http"
log "github.com/sirupsen/logrus"
@@ -21,10 +22,20 @@ import (
"gitlab.com/gitlab-org/labkit/correlation"
)
+type status int
+
+const(
+ StatusStarting status = iota
+ StatusReady
+ StatusOnShutdown
+ StatusClosed
+)
+
type Server struct {
Config *config.Config
- onShutdown bool
+ status status
+ statusMu sync.Mutex
wg sync.WaitGroup
listener net.Listener
}
@@ -43,11 +54,29 @@ func (s *Server) Shutdown() error {
return nil
}
- s.onShutdown = true
+ s.changeStatus(StatusOnShutdown)
return s.listener.Close()
}
+func (s *Server) MonitoringServeMux() *http.ServeMux {
+ mux := http.NewServeMux()
+
+ mux.HandleFunc(s.Config.Server.ReadinessProbe, func(w http.ResponseWriter, r *http.Request) {
+ if s.getStatus() == StatusReady {
+ w.WriteHeader(http.StatusOK)
+ } else {
+ w.WriteHeader(http.StatusServiceUnavailable)
+ }
+ })
+
+ mux.HandleFunc(s.Config.Server.LivenessProbe, func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ })
+
+ return mux
+}
+
func (s *Server) listen() error {
sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
if err != nil {
@@ -73,10 +102,12 @@ func (s *Server) serve(ctx context.Context) error {
return err
}
+ s.changeStatus(StatusReady)
+
for {
nconn, err := s.listener.Accept()
if err != nil {
- if s.onShutdown {
+ if s.getStatus() == StatusOnShutdown {
break
}
@@ -90,9 +121,24 @@ func (s *Server) serve(ctx context.Context) error {
s.wg.Wait()
+ s.changeStatus(StatusClosed)
+
return nil
}
+func (s *Server) changeStatus(st status) {
+ s.statusMu.Lock()
+ s.status = st
+ s.statusMu.Unlock()
+}
+
+func (s *Server) getStatus() status {
+ s.statusMu.Lock()
+ defer s.statusMu.Unlock()
+
+ return s.status
+}
+
func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
authorizedKeysClient, err := authorizedkeys.NewClient(s.Config)
if err != nil {
diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go
index d1891ec..9187140 100644
--- a/internal/sshd/sshd_test.go
+++ b/internal/sshd/sshd_test.go
@@ -4,6 +4,8 @@ import (
"testing"
"context"
"path"
+ "net/http/httptest"
+ "time"
"github.com/stretchr/testify/require"
@@ -17,18 +19,55 @@ const serverUrl = "127.0.0.1:50000"
func TestShutdown(t *testing.T) {
s := setupServer(t)
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ go func() { require.NoError(t, s.ListenAndServe(context.Background())) }()
- done := make(chan bool, 1)
- go func() {
- require.NoError(t, s.serve(ctx))
- done <- true
- }()
+ verifyStatus(t, s, StatusReady)
+
+ s.wg.Add(1)
require.NoError(t, s.Shutdown())
+ verifyStatus(t, s, StatusOnShutdown)
+
+ s.wg.Done()
+
+ verifyStatus(t, s, StatusClosed)
+}
+
+func TestReadinessProbe(t *testing.T) {
+ s := &Server{Config: &config.Config{Server: config.DefaultServerConfig}}
+
+ require.Equal(t, StatusStarting, s.getStatus())
+
+ mux := s.MonitoringServeMux()
+
+ req := httptest.NewRequest("GET", "/start", nil)
+
+ r := httptest.NewRecorder()
+ mux.ServeHTTP(r, req)
+ require.Equal(t, 503, r.Result().StatusCode)
+
+ s.changeStatus(StatusReady)
- require.True(t, <-done, "the accepting loop must be interrupted")
+ r = httptest.NewRecorder()
+ mux.ServeHTTP(r, req)
+ require.Equal(t, 200, r.Result().StatusCode)
+
+ s.changeStatus(StatusOnShutdown)
+
+ r = httptest.NewRecorder()
+ mux.ServeHTTP(r, req)
+ require.Equal(t, 503, r.Result().StatusCode)
+}
+
+func TestLivenessProbe(t *testing.T) {
+ s := &Server{Config: &config.Config{Server: config.DefaultServerConfig}}
+ mux := s.MonitoringServeMux()
+
+ req := httptest.NewRequest("GET", "/health", nil)
+
+ r := httptest.NewRecorder()
+ mux.ServeHTTP(r, req)
+ require.Equal(t, 200, r.Result().StatusCode)
}
func setupServer(t *testing.T) *Server {
@@ -42,8 +81,18 @@ func setupServer(t *testing.T) *Server {
cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg}
- s := &Server{Config: cfg}
- require.NoError(t, s.listen())
+ return &Server{Config: cfg}
+}
+
+func verifyStatus(t *testing.T, s *Server, st status) {
+ for i := 5; i < 500; i+=50 {
+ if s.getStatus() == st {
+ break
+ }
+
+ // Sleep incrementally ~2s in total
+ time.Sleep(time.Duration(i) * time.Millisecond)
+ }
- return s
+ require.Equal(t, s.getStatus(), st)
}