diff options
-rw-r--r-- | cmd/gitlab-sshd/main.go | 1 | ||||
-rw-r--r-- | internal/config/config.go | 6 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 52 | ||||
-rw-r--r-- | internal/sshd/sshd_test.go | 71 |
4 files changed, 115 insertions, 15 deletions
diff --git a/cmd/gitlab-sshd/main.go b/cmd/gitlab-sshd/main.go index 7cecbf5..e524023 100644 --- a/cmd/gitlab-sshd/main.go +++ b/cmd/gitlab-sshd/main.go @@ -76,6 +76,7 @@ func main() { monitoring.Start( monitoring.WithListenerAddress(cfg.Server.WebListen), monitoring.WithBuildInformation(Version, BuildTime), + monitoring.WithServeMux(server.MonitoringServeMux()), ), ) }() diff --git a/internal/config/config.go b/internal/config/config.go index c58ea7d..71e7840 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,7 +26,9 @@ type ServerConfig struct { ProxyProtocol bool `yaml:"proxy_protocol,omitempty"` WebListen string `yaml:"web_listen,omitempty"` ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"` - GracePeriodSeconds uint64 `yaml:"grace_period"` + GracePeriodSeconds uint64 `yaml:"grace_period"` + ReadinessProbe string `yaml:"readiness_probe"` + LivenessProbe string `yaml:"liveness_probe"` HostKeyFiles []string `yaml:"host_key_files,omitempty"` } @@ -72,6 +74,8 @@ var ( WebListen: "localhost:9122", ConcurrentSessionsLimit: 10, GracePeriodSeconds: 10, + ReadinessProbe: "/start", + LivenessProbe: "/health", HostKeyFiles: []string{ "/run/secrets/ssh-hostkeys/ssh_host_rsa_key", "/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key", diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index ef401dc..ac4ebf8 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -10,6 +10,7 @@ import ( "strconv" "time" "sync" + "net/http" log "github.com/sirupsen/logrus" @@ -21,10 +22,20 @@ import ( "gitlab.com/gitlab-org/labkit/correlation" ) +type status int + +const( + StatusStarting status = iota + StatusReady + StatusOnShutdown + StatusClosed +) + type Server struct { Config *config.Config - onShutdown bool + status status + statusMu sync.Mutex wg sync.WaitGroup listener net.Listener } @@ -43,11 +54,29 @@ func (s *Server) Shutdown() error { return nil } - s.onShutdown = true + s.changeStatus(StatusOnShutdown) return s.listener.Close() } +func (s *Server) MonitoringServeMux() *http.ServeMux { + mux := http.NewServeMux() + + mux.HandleFunc(s.Config.Server.ReadinessProbe, func(w http.ResponseWriter, r *http.Request) { + if s.getStatus() == StatusReady { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + } + }) + + mux.HandleFunc(s.Config.Server.LivenessProbe, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + return mux +} + func (s *Server) listen() error { sshListener, err := net.Listen("tcp", s.Config.Server.Listen) if err != nil { @@ -73,10 +102,12 @@ func (s *Server) serve(ctx context.Context) error { return err } + s.changeStatus(StatusReady) + for { nconn, err := s.listener.Accept() if err != nil { - if s.onShutdown { + if s.getStatus() == StatusOnShutdown { break } @@ -90,9 +121,24 @@ func (s *Server) serve(ctx context.Context) error { s.wg.Wait() + s.changeStatus(StatusClosed) + return nil } +func (s *Server) changeStatus(st status) { + s.statusMu.Lock() + s.status = st + s.statusMu.Unlock() +} + +func (s *Server) getStatus() status { + s.statusMu.Lock() + defer s.statusMu.Unlock() + + return s.status +} + func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { authorizedKeysClient, err := authorizedkeys.NewClient(s.Config) if err != nil { diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index d1891ec..9187140 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -4,6 +4,8 @@ import ( "testing" "context" "path" + "net/http/httptest" + "time" "github.com/stretchr/testify/require" @@ -17,18 +19,55 @@ const serverUrl = "127.0.0.1:50000" func TestShutdown(t *testing.T) { s := setupServer(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + go func() { require.NoError(t, s.ListenAndServe(context.Background())) }() - done := make(chan bool, 1) - go func() { - require.NoError(t, s.serve(ctx)) - done <- true - }() + verifyStatus(t, s, StatusReady) + + s.wg.Add(1) require.NoError(t, s.Shutdown()) + verifyStatus(t, s, StatusOnShutdown) + + s.wg.Done() + + verifyStatus(t, s, StatusClosed) +} + +func TestReadinessProbe(t *testing.T) { + s := &Server{Config: &config.Config{Server: config.DefaultServerConfig}} + + require.Equal(t, StatusStarting, s.getStatus()) + + mux := s.MonitoringServeMux() + + req := httptest.NewRequest("GET", "/start", nil) + + r := httptest.NewRecorder() + mux.ServeHTTP(r, req) + require.Equal(t, 503, r.Result().StatusCode) + + s.changeStatus(StatusReady) - require.True(t, <-done, "the accepting loop must be interrupted") + r = httptest.NewRecorder() + mux.ServeHTTP(r, req) + require.Equal(t, 200, r.Result().StatusCode) + + s.changeStatus(StatusOnShutdown) + + r = httptest.NewRecorder() + mux.ServeHTTP(r, req) + require.Equal(t, 503, r.Result().StatusCode) +} + +func TestLivenessProbe(t *testing.T) { + s := &Server{Config: &config.Config{Server: config.DefaultServerConfig}} + mux := s.MonitoringServeMux() + + req := httptest.NewRequest("GET", "/health", nil) + + r := httptest.NewRecorder() + mux.ServeHTTP(r, req) + require.Equal(t, 200, r.Result().StatusCode) } func setupServer(t *testing.T) *Server { @@ -42,8 +81,18 @@ func setupServer(t *testing.T) *Server { cfg := &config.Config{RootDir: "/tmp", GitlabUrl: url, Server: srvCfg} - s := &Server{Config: cfg} - require.NoError(t, s.listen()) + return &Server{Config: cfg} +} + +func verifyStatus(t *testing.T, s *Server, st status) { + for i := 5; i < 500; i+=50 { + if s.getStatus() == st { + break + } + + // Sleep incrementally ~2s in total + time.Sleep(time.Duration(i) * time.Millisecond) + } - return s + require.Equal(t, s.getStatus(), st) } |