diff options
Diffstat (limited to 'internal/sshd/sshd.go')
-rw-r--r-- | internal/sshd/sshd.go | 75 |
1 files changed, 9 insertions, 66 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index de5fbd4..ff9e765 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -2,13 +2,9 @@ package sshd import ( "context" - "encoding/base64" - "errors" "fmt" "net" "net/http" - "os" - "strconv" "sync" "time" @@ -16,7 +12,6 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/internal/config" - "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" @@ -35,40 +30,20 @@ const ( type Server struct { Config *config.Config - status status - statusMu sync.Mutex - wg sync.WaitGroup - listener net.Listener - hostKeys []ssh.Signer - authorizedKeysClient *authorizedkeys.Client + status status + statusMu sync.Mutex + wg sync.WaitGroup + listener net.Listener + serverConfig *serverConfig } func NewServer(cfg *config.Config) (*Server, error) { - authorizedKeysClient, err := authorizedkeys.NewClient(cfg) + serverConfig, err := newServerConfig(cfg) if err != nil { - return nil, fmt.Errorf("failed to initialize GitLab client: %w", err) + return nil, err } - var hostKeys []ssh.Signer - for _, filename := range cfg.Server.HostKeyFiles { - keyRaw, err := os.ReadFile(filename) - if err != nil { - log.WithError(err).Warnf("Failed to read host key %v", filename) - continue - } - key, err := ssh.ParsePrivateKey(keyRaw) - if err != nil { - log.WithError(err).Warnf("Failed to parse host key %v", filename) - continue - } - - hostKeys = append(hostKeys, key) - } - if len(hostKeys) == 0 { - return nil, fmt.Errorf("No host keys could be loaded, aborting") - } - - return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil + return &Server{Config: cfg, serverConfig: serverConfig}, nil } func (s *Server) ListenAndServe(ctx context.Context) error { @@ -168,38 +143,6 @@ func (s *Server) getStatus() status { return s.status } -func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig { - sshCfg := &ssh.ServerConfig{ - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if conn.User() != s.Config.User { - return nil, errors.New("unknown user") - } - if key.Type() == ssh.KeyAlgoDSA { - return nil, errors.New("DSA is prohibited") - } - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal())) - if err != nil { - return nil, err - } - - return &ssh.Permissions{ - // Record the public key used for authentication. - Extensions: map[string]string{ - "key-id": strconv.FormatInt(res.Id, 10), - }, - }, nil - }, - } - - for _, key := range s.hostKeys { - sshCfg.AddHostKey(key) - } - - return sshCfg -} - func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { remoteAddr := nconn.RemoteAddr().String() @@ -216,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID())) defer cancel() - sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx)) + sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx)) if err != nil { log.WithError(err).Info("Failed to initialize SSH connection") return |