summaryrefslogtreecommitdiff
path: root/internal/sshd/sshd.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/sshd.go')
-rw-r--r--internal/sshd/sshd.go98
1 files changed, 76 insertions, 22 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index b04366e..ef401dc 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -9,6 +9,7 @@ import (
"net"
"strconv"
"time"
+ "sync"
log "github.com/sirupsen/logrus"
@@ -20,28 +21,87 @@ import (
"gitlab.com/gitlab-org/labkit/correlation"
)
-func Run(ctx context.Context, cfg *config.Config) error {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
- if err != nil {
- return fmt.Errorf("failed to initialize GitLab client: %w", err)
+type Server struct {
+ Config *config.Config
+
+ onShutdown bool
+ wg sync.WaitGroup
+ listener net.Listener
+}
+
+func (s *Server) ListenAndServe(ctx context.Context) error {
+ if err := s.listen(); err != nil {
+ return err
}
+ defer s.listener.Close()
+
+ return s.serve(ctx)
+}
+
+func (s *Server) Shutdown() error {
+ if s.listener == nil {
+ return nil
+ }
+
+ s.onShutdown = true
+
+ return s.listener.Close()
+}
- sshListener, err := net.Listen("tcp", cfg.Server.Listen)
+func (s *Server) listen() error {
+ sshListener, err := net.Listen("tcp", s.Config.Server.Listen)
if err != nil {
return fmt.Errorf("failed to listen for connection: %w", err)
}
- if cfg.Server.ProxyProtocol {
+
+ if s.Config.Server.ProxyProtocol {
sshListener = &proxyproto.Listener{Listener: sshListener}
log.Info("Proxy protocol is enabled")
}
- defer sshListener.Close()
log.Infof("Listening on %v", sshListener.Addr().String())
+ s.listener = sshListener
+
+ return nil
+}
+
+func (s *Server) serve(ctx context.Context) error {
+ sshCfg, err := s.initConfig(ctx)
+ if err != nil {
+ return err
+ }
+
+ for {
+ nconn, err := s.listener.Accept()
+ if err != nil {
+ if s.onShutdown {
+ break
+ }
+
+ log.Warnf("Failed to accept connection: %v\n", err)
+ continue
+ }
+
+ s.wg.Add(1)
+ go s.handleConn(ctx, sshCfg, nconn)
+ }
+
+ s.wg.Wait()
+
+ return nil
+}
+
+func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(s.Config)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != cfg.User {
+ if conn.User() != s.Config.User {
return nil, errors.New("unknown user")
}
if key.Type() == ssh.KeyAlgoDSA {
@@ -64,7 +124,7 @@ func Run(ctx context.Context, cfg *config.Config) error {
}
var loadedHostKeys uint
- for _, filename := range cfg.Server.HostKeyFiles {
+ for _, filename := range s.Config.Server.HostKeyFiles {
keyRaw, err := ioutil.ReadFile(filename)
if err != nil {
log.Warnf("Failed to read host key %v: %v", filename, err)
@@ -79,23 +139,17 @@ func Run(ctx context.Context, cfg *config.Config) error {
sshCfg.AddHostKey(key)
}
if loadedHostKeys == 0 {
- return fmt.Errorf("No host keys could be loaded, aborting")
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
}
- for {
- nconn, err := sshListener.Accept()
- if err != nil {
- log.Warnf("Failed to accept connection: %v\n", err)
- continue
- }
-
- go handleConn(ctx, cfg, sshCfg, nconn)
- }
+ return sshCfg, nil
}
-func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) {
+
+func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
+ defer s.wg.Done()
defer nconn.Close()
// Prevent a panic in a single connection from taking out the whole server
@@ -116,10 +170,10 @@ func handleConn(ctx context.Context, cfg *config.Config, sshCfg *ssh.ServerConfi
go ssh.DiscardRequests(reqs)
- conn := newConnection(cfg.Server.ConcurrentSessionsLimit, remoteAddr)
+ conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
session := &session{
- cfg: cfg,
+ cfg: s.Config,
channel: channel,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
remoteAddr: remoteAddr,