diff options
author | Nick Thomas <nick@gitlab.com> | 2021-04-23 15:45:26 +0100 |
---|---|---|
committer | Nick Thomas <nick@gitlab.com> | 2021-04-23 15:53:59 +0100 |
commit | 8cd3599f820b3b877626c5802471bfb85218ab16 (patch) | |
tree | 90aeb2d92d5aadcbfaa0f1e2cb812df5cda859e3 | |
parent | 39792693a2a2d06669103714e7fa9da83b0e9b12 (diff) | |
download | gitlab-shell-511-be-safe-against-panics.tar.gz |
sshd: Recover from per-session and per-connection panics511-be-safe-against-panics
Without this, a failure in a single session could take out a whole
connection, or a failure in a single connection could take out the
whole server.
-rw-r--r-- | internal/sshd/connection.go | 12 | ||||
-rw-r--r-- | internal/sshd/connection_test.go | 49 | ||||
-rw-r--r-- | internal/sshd/sshd.go | 13 |
3 files changed, 71 insertions, 3 deletions
diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index c8d1456..a9b9e97 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -50,14 +50,16 @@ var ( type connection struct { begin time.Time concurrentSessions *semaphore.Weighted + remoteAddr string } type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) -func newConnection(maxSessions int64) *connection { +func newConnection(maxSessions int64, remoteAddr string) *connection { return &connection{ begin: time.Now(), concurrentSessions: semaphore.NewWeighted(maxSessions), + remoteAddr: remoteAddr, } } @@ -83,6 +85,14 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha go func() { defer c.concurrentSessions.Release(1) + + // Prevent a panic in a single session from taking out the whole server + defer func() { + if err := recover(); err != nil { + log.Warnf("panic handling session from %s: recovered: %#+v", c.remoteAddr, err) + } + }() + handler(ctx, channel, requests) }() } diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go new file mode 100644 index 0000000..03e9209 --- /dev/null +++ b/internal/sshd/connection_test.go @@ -0,0 +1,49 @@ +package sshd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +type fakeNewChannel struct { + channelType string + extraData []byte +} + +func (f *fakeNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) { + return nil, nil, nil +} + +func (f *fakeNewChannel) Reject(reason ssh.RejectionReason, message string) error { + return nil +} + +func (f *fakeNewChannel) ChannelType() string { + return f.channelType +} + +func (f *fakeNewChannel) ExtraData() []byte { + return f.extraData +} + +func TestPanicDuringSessionIsRecovered(t *testing.T) { + numSessions := 0 + conn := newConnection(1, "127.0.0.1:50000") + + newChannel := &fakeNewChannel{channelType: "session"} + chans := make(chan ssh.NewChannel, 1) + chans <- newChannel + + require.NotPanics(t, func() { + conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) { + numSessions += 1 + close(chans) + panic("This is a panic") + }) + }) + + require.Equal(t, numSessions, 1) +} diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index f046d60..a9e797b 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -93,8 +93,17 @@ func Run(cfg *config.Config) error { } func handleConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) { + remoteAddr := nconn.RemoteAddr().String() + defer nconn.Close() + // Prevent a panic in a single connection from taking out the whole server + defer func() { + if err := recover(); err != nil { + log.Warnf("panic handling connection from %s: recovered: %#+v", remoteAddr, err) + } + }() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -106,13 +115,13 @@ func handleConn(cfg *config.Config, sshCfg *ssh.ServerConfig, nconn net.Conn) { go ssh.DiscardRequests(reqs) - conn := newConnection(cfg.Server.ConcurrentSessionsLimit) + conn := newConnection(cfg.Server.ConcurrentSessionsLimit, remoteAddr) conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) { session := &session{ cfg: cfg, channel: channel, gitlabKeyId: sconn.Permissions.Extensions["key-id"], - remoteAddr: nconn.RemoteAddr().(*net.TCPAddr).String(), + remoteAddr: remoteAddr, } session.handle(ctx, requests) |