summaryrefslogtreecommitdiff
path: root/internal/sshd/sshd.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/sshd.go')
-rw-r--r--internal/sshd/sshd.go85
1 files changed, 45 insertions, 40 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index 8b49712..b918109 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -35,10 +35,40 @@ const (
type Server struct {
Config *config.Config
- status status
- statusMu sync.Mutex
- wg sync.WaitGroup
- listener net.Listener
+ status status
+ statusMu sync.Mutex
+ wg sync.WaitGroup
+ listener net.Listener
+ hostKeys []ssh.Signer
+ authorizedKeysClient *authorizedkeys.Client
+}
+
+func NewServer(cfg *config.Config) (*Server, error) {
+ authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ }
+
+ var hostKeys []ssh.Signer
+ for _, filename := range cfg.Server.HostKeyFiles {
+ keyRaw, err := ioutil.ReadFile(filename)
+ if err != nil {
+ log.WithError(err).Warnf("Failed to read host key %v", filename)
+ continue
+ }
+ key, err := ssh.ParsePrivateKey(keyRaw)
+ if err != nil {
+ log.WithError(err).Warnf("Failed to parse host key %v", filename)
+ continue
+ }
+
+ hostKeys = append(hostKeys, key)
+ }
+ if len(hostKeys) == 0 {
+ return nil, fmt.Errorf("No host keys could be loaded, aborting")
+ }
+
+ return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
@@ -47,7 +77,9 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
}
defer s.listener.Close()
- return s.serve(ctx)
+ s.serve(ctx)
+
+ return nil
}
func (s *Server) Shutdown() error {
@@ -100,12 +132,7 @@ func (s *Server) listen() error {
return nil
}
-func (s *Server) serve(ctx context.Context) error {
- sshCfg, err := s.initConfig(ctx)
- if err != nil {
- return err
- }
-
+func (s *Server) serve(ctx context.Context) {
s.changeStatus(StatusReady)
for {
@@ -120,14 +147,12 @@ func (s *Server) serve(ctx context.Context) error {
}
s.wg.Add(1)
- go s.handleConn(ctx, sshCfg, nconn)
+ go s.handleConn(ctx, nconn)
}
s.wg.Wait()
s.changeStatus(StatusClosed)
-
- return nil
}
func (s *Server) changeStatus(st status) {
@@ -143,12 +168,7 @@ func (s *Server) getStatus() status {
return s.status
}
-func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
- authorizedKeysClient, err := authorizedkeys.NewClient(s.Config)
- if err != nil {
- return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
- }
-
+func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig {
sshCfg := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if conn.User() != s.Config.User {
@@ -159,7 +179,7 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
- res, err := authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
+ res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
if err != nil {
return nil, err
}
@@ -173,29 +193,14 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) {
},
}
- var loadedHostKeys uint
- for _, filename := range s.Config.Server.HostKeyFiles {
- keyRaw, err := ioutil.ReadFile(filename)
- if err != nil {
- log.WithError(err).Warnf("Failed to read host key %v", filename)
- continue
- }
- key, err := ssh.ParsePrivateKey(keyRaw)
- if err != nil {
- log.WithError(err).Warnf("Failed to parse host key %v", filename)
- continue
- }
- loadedHostKeys++
+ for _, key := range s.hostKeys {
sshCfg.AddHostKey(key)
}
- if loadedHostKeys == 0 {
- return nil, fmt.Errorf("No host keys could be loaded, aborting")
- }
- return sshCfg, nil
+ return sshCfg
}
-func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) {
+func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
defer s.wg.Done()
@@ -211,7 +216,7 @@ func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg)
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx))
if err != nil {
log.WithError(err).Info("Failed to initialize SSH connection")
return