summaryrefslogtreecommitdiff
path: root/internal/sshd/sshd.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sshd/sshd.go')
-rw-r--r--internal/sshd/sshd.go75
1 files changed, 9 insertions, 66 deletions
diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go
index de5fbd4..ff9e765 100644
--- a/internal/sshd/sshd.go
+++ b/internal/sshd/sshd.go
@@ -2,13 +2,9 @@ package sshd
import (
"context"
- "encoding/base64"
- "errors"
"fmt"
"net"
"net/http"
- "os"
- "strconv"
"sync"
"time"
@@ -16,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
- "gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet/authorizedkeys"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
@@ -35,40 +30,20 @@ const (
type Server struct {
Config *config.Config
- status status
- statusMu sync.Mutex
- wg sync.WaitGroup
- listener net.Listener
- hostKeys []ssh.Signer
- authorizedKeysClient *authorizedkeys.Client
+ status status
+ statusMu sync.Mutex
+ wg sync.WaitGroup
+ listener net.Listener
+ serverConfig *serverConfig
}
func NewServer(cfg *config.Config) (*Server, error) {
- authorizedKeysClient, err := authorizedkeys.NewClient(cfg)
+ serverConfig, err := newServerConfig(cfg)
if err != nil {
- return nil, fmt.Errorf("failed to initialize GitLab client: %w", err)
+ return nil, err
}
- var hostKeys []ssh.Signer
- for _, filename := range cfg.Server.HostKeyFiles {
- keyRaw, err := os.ReadFile(filename)
- if err != nil {
- log.WithError(err).Warnf("Failed to read host key %v", filename)
- continue
- }
- key, err := ssh.ParsePrivateKey(keyRaw)
- if err != nil {
- log.WithError(err).Warnf("Failed to parse host key %v", filename)
- continue
- }
-
- hostKeys = append(hostKeys, key)
- }
- if len(hostKeys) == 0 {
- return nil, fmt.Errorf("No host keys could be loaded, aborting")
- }
-
- return &Server{Config: cfg, authorizedKeysClient: authorizedKeysClient, hostKeys: hostKeys}, nil
+ return &Server{Config: cfg, serverConfig: serverConfig}, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
@@ -168,38 +143,6 @@ func (s *Server) getStatus() status {
return s.status
}
-func (s *Server) serverConfig(ctx context.Context) *ssh.ServerConfig {
- sshCfg := &ssh.ServerConfig{
- PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
- if conn.User() != s.Config.User {
- return nil, errors.New("unknown user")
- }
- if key.Type() == ssh.KeyAlgoDSA {
- return nil, errors.New("DSA is prohibited")
- }
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
- res, err := s.authorizedKeysClient.GetByKey(ctx, base64.RawStdEncoding.EncodeToString(key.Marshal()))
- if err != nil {
- return nil, err
- }
-
- return &ssh.Permissions{
- // Record the public key used for authentication.
- Extensions: map[string]string{
- "key-id": strconv.FormatInt(res.Id, 10),
- },
- }, nil
- },
- }
-
- for _, key := range s.hostKeys {
- sshCfg.AddHostKey(key)
- }
-
- return sshCfg
-}
-
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr := nconn.RemoteAddr().String()
@@ -216,7 +159,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
- sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig(ctx))
+ sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
if err != nil {
log.WithError(err).Info("Failed to initialize SSH connection")
return