diff options
Diffstat (limited to 'internal/sshd/sshd.go')
-rw-r--r-- | internal/sshd/sshd.go | 85 |
1 files changed, 45 insertions, 40 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 8b49712..b918109 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -35,10 +35,40 @@ const ( type Server struct { Config *config.Config - status status - statusMu sync.Mutex - wg sync.WaitGroup - listener net.Listener + status status + statusMu sync.Mutex + wg sync.WaitGroup + listener net.Listener + hostKeys []ssh.Signer + authorizedKeysClient *authorizedkeys.Client +} + +func NewServer(cfg *config.Config) (*Server, error) { + authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + } + + var hostKeys []ssh.Signer + for _, filename := range cfg.Server.HostKeyFiles { + keyRaw, err := ioutil.ReadFile(filename) + if err != nil { + log.WithError(err).Warnf("Failed to read host key %v", filename) + continue + } + key, err := ssh.ParsePrivateKey(keyRaw) + if err != nil { + log.WithError(err).Warnf("Failed to parse host key %v", filename) + continue + } + + hostKeys = append(hostKeys, key) + } + if len(hostKeys) == 0 { + return nil, fmt.Errorf("No host keys could be loaded, aborting") + } + + return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil } func (s *Server) ListenAndServe(ctx context.Context) error { @@ -47,7 +77,9 @@ func (s *Server) ListenAndServe(ctx context.Context) error { } defer s.listener.Close() - return s.serve(ctx) + s.serve(ctx) + + return nil } func (s *Server) Shutdown() error { @@ -100,12 +132,7 @@ func (s *Server) listen() error { return nil } -func (s *Server) serve(ctx context.Context) error { - sshCfg, err := s.initConfig(ctx) - if err != nil { - return err - } - +func (s *Server) serve(ctx context.Context) { s.changeStatus(StatusReady) for { @@ -120,14 +147,12 @@ func (s *Server) serve(ctx context.Context) error { } s.wg.Add(1) - go s.handleConn(ctx, sshCfg, nconn) + go s.handleConn(ctx, nconn) } s.wg.Wait() s.changeStatus(StatusClosed) - - return nil } func (s *Server) changeStatus(st status) { @@ -143,12 +168,7 @@ func (s *Server) getStatus() status { return s.status } -func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { - authorizedKeysClient, err := authorizedkeys.NewClient(s.Config) - if err != nil { - return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) - } - +func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig { sshCfg := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { if conn.User() != s.Config.User { @@ -159,7 +179,7 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { } ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - res, err := authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) + res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) if err != nil { return nil, err } @@ -173,29 +193,14 @@ func (s *Server) initConfig(ctx context.Context) (*ssh.ServerConfig, error) { }, } - var loadedHostKeys uint - for _, filename := range s.Config.Server.HostKeyFiles { - keyRaw, err := ioutil.ReadFile(filename) - if err != nil { - log.WithError(err).Warnf("Failed to read host key %v", filename) - continue - } - key, err := ssh.ParsePrivateKey(keyRaw) - if err != nil { - log.WithError(err).Warnf("Failed to parse host key %v", filename) - continue - } - loadedHostKeys++ + for _, key := range s.hostKeys { sshCfg.AddHostKey(key) } - if loadedHostKeys == 0 { - return nil, fmt.Errorf("No host keys could be loaded, aborting") - } - return sshCfg, nil + return sshCfg } -func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn net.Conn) { +func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() defer s.wg.Done() @@ -211,7 +216,7 @@ func (s *Server) handleConn(ctx context.Context, sshCfg *ssh.ServerConfig, nconn ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) defer cancel() - sconn, chans, reqs, err := ssh.NewServerConn(nconn, sshCfg) + sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx)) if err != nil { log.WithError(err).Info("Failed to initialize SSH connection") return |