diff options
Diffstat (limited to 'libgo/go/crypto/tls/tls_test.go')
-rw-r--r-- | libgo/go/crypto/tls/tls_test.go | 427 |
1 files changed, 408 insertions, 19 deletions
diff --git a/libgo/go/crypto/tls/tls_test.go b/libgo/go/crypto/tls/tls_test.go index 5cc14278a0..86812f0c97 100644 --- a/libgo/go/crypto/tls/tls_test.go +++ b/libgo/go/crypto/tls/tls_test.go @@ -6,11 +6,16 @@ package tls import ( "bytes" + "crypto/x509" "errors" "fmt" "internal/testenv" "io" + "io/ioutil" + "math" "net" + "os" + "reflect" "strings" "testing" "time" @@ -92,6 +97,7 @@ var keyPairTests = []struct { } func TestX509KeyPair(t *testing.T) { + t.Parallel() var pem []byte for _, test := range keyPairTests { pem = []byte(test.cert + test.key) @@ -146,7 +152,7 @@ func TestX509MixedKeyPair(t *testing.T) { } } -func newLocalListener(t *testing.T) net.Listener { +func newLocalListener(t testing.TB) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { ln, err = net.Listen("tcp6", "[::1]:0") @@ -188,18 +194,25 @@ func TestDialTimeout(t *testing.T) { t.Fatal("DialWithTimeout completed successfully") } - if !strings.Contains(err.Error(), "timed out") { - t.Errorf("resulting error not a timeout: %s", err) + if !isTimeoutError(err) { + t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) } } +func isTimeoutError(err error) bool { + if ne, ok := err.(net.Error); ok { + return ne.Timeout() + } + return false +} + // tests that Conn.Read returns (non-zero, io.EOF) instead of // (non-zero, nil) when a Close (alertCloseNotify) is sitting right // behind the application data in the buffer. func TestConnReadNonzeroAndEOF(t *testing.T) { // This test is racy: it assumes that after a write to a // localhost TCP connection, the peer TCP connection can - // immediately read it. Because it's racy, we skip this test + // immediately read it. Because it's racy, we skip this test // in short mode, and then retry it several times with an // increasing sleep in between our final write (via srv.Close // below) and the following read. @@ -228,8 +241,8 @@ func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error { srvCh <- nil return } - serverConfig := *testConfig - srv := Server(sconn, &serverConfig) + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) if err := srv.Handshake(); err != nil { serr = fmt.Errorf("handshake: %v", err) srvCh <- nil @@ -238,8 +251,8 @@ func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error { srvCh <- srv }() - clientConfig := *testConfig - conn, err := Dial("tcp", ln.Addr().String(), &clientConfig) + clientConfig := testConfig.Clone() + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) if err != nil { t.Fatal(err) } @@ -280,20 +293,22 @@ func TestTLSUniqueMatches(t *testing.T) { for i := 0; i < 2; i++ { sconn, err := ln.Accept() if err != nil { - t.Fatal(err) + t.Error(err) + return } - serverConfig := *testConfig - srv := Server(sconn, &serverConfig) + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) if err := srv.Handshake(); err != nil { - t.Fatal(err) + t.Error(err) + return } serverTLSUniques <- srv.ConnectionState().TLSUnique } }() - clientConfig := *testConfig + clientConfig := testConfig.Clone() clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) - conn, err := Dial("tcp", ln.Addr().String(), &clientConfig) + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) if err != nil { t.Fatal(err) } @@ -302,7 +317,7 @@ func TestTLSUniqueMatches(t *testing.T) { } conn.Close() - conn, err = Dial("tcp", ln.Addr().String(), &clientConfig) + conn, err = Dial("tcp", ln.Addr().String(), clientConfig) if err != nil { t.Fatal(err) } @@ -381,8 +396,8 @@ func TestConnCloseBreakingWrite(t *testing.T) { srvCh <- nil return } - serverConfig := *testConfig - srv := Server(sconn, &serverConfig) + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) if err := srv.Handshake(); err != nil { serr = fmt.Errorf("handshake: %v", err) srvCh <- nil @@ -401,8 +416,8 @@ func TestConnCloseBreakingWrite(t *testing.T) { Conn: cconn, } - clientConfig := *testConfig - tconn := Client(conn, &clientConfig) + clientConfig := testConfig.Clone() + tconn := Client(conn, clientConfig) if err := tconn.Handshake(); err != nil { t.Fatal(err) } @@ -445,6 +460,220 @@ func TestConnCloseBreakingWrite(t *testing.T) { } } +func TestConnCloseWrite(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + clientDoneChan := make(chan struct{}) + + serverCloseWrite := func() error { + sconn, err := ln.Accept() + if err != nil { + return fmt.Errorf("accept: %v", err) + } + defer sconn.Close() + + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + return fmt.Errorf("handshake: %v", err) + } + defer srv.Close() + + data, err := ioutil.ReadAll(srv) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + + if err := srv.CloseWrite(); err != nil { + return fmt.Errorf("server CloseWrite: %v", err) + } + + // Wait for clientCloseWrite to finish, so we know we + // tested the CloseWrite before we defer the + // sconn.Close above, which would also cause the + // client to unblock like CloseWrite. + <-clientDoneChan + return nil + } + + clientCloseWrite := func() error { + defer close(clientDoneChan) + + clientConfig := testConfig.Clone() + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + return err + } + if err := conn.Handshake(); err != nil { + return err + } + defer conn.Close() + + if err := conn.CloseWrite(); err != nil { + return fmt.Errorf("client CloseWrite: %v", err) + } + + if _, err := conn.Write([]byte{0}); err != errShutdown { + return fmt.Errorf("CloseWrite error = %v; want errShutdown", err) + } + + data, err := ioutil.ReadAll(conn) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + return nil + } + + errChan := make(chan error, 2) + + go func() { errChan <- serverCloseWrite() }() + go func() { errChan <- clientCloseWrite() }() + + for i := 0; i < 2; i++ { + select { + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + case <-time.After(10 * time.Second): + t.Fatal("deadlock") + } + } + + // Also test CloseWrite being called before the handshake is + // finished: + { + ln2 := newLocalListener(t) + defer ln2.Close() + + netConn, err := net.Dial("tcp", ln2.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer netConn.Close() + conn := Client(netConn, testConfig.Clone()) + + if err := conn.CloseWrite(); err != errEarlyCloseWrite { + t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err) + } + } +} + +func TestCloneFuncFields(t *testing.T) { + const expectedCount = 5 + called := 0 + + c1 := Config{ + Time: func() time.Time { + called |= 1 << 0 + return time.Time{} + }, + GetCertificate: func(*ClientHelloInfo) (*Certificate, error) { + called |= 1 << 1 + return nil, nil + }, + GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) { + called |= 1 << 2 + return nil, nil + }, + GetConfigForClient: func(*ClientHelloInfo) (*Config, error) { + called |= 1 << 3 + return nil, nil + }, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + called |= 1 << 4 + return nil + }, + } + + c2 := c1.Clone() + + c2.Time() + c2.GetCertificate(nil) + c2.GetClientCertificate(nil) + c2.GetConfigForClient(nil) + c2.VerifyPeerCertificate(nil, nil) + + if called != (1<<expectedCount)-1 { + t.Fatalf("expected %d calls but saw calls %b", expectedCount, called) + } +} + +func TestCloneNonFuncFields(t *testing.T) { + var c1 Config + v := reflect.ValueOf(&c1).Elem() + + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + f := v.Field(i) + if !f.CanSet() { + // unexported field; not cloned. + continue + } + + // testing/quick can't handle functions or interfaces and so + // isn't used here. + switch fn := typ.Field(i).Name; fn { + case "Rand": + f.Set(reflect.ValueOf(io.Reader(os.Stdin))) + case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate": + // DeepEqual can't compare functions. If you add a + // function field to this list, you must also change + // TestCloneFuncFields to ensure that the func field is + // cloned. + case "Certificates": + f.Set(reflect.ValueOf([]Certificate{ + {Certificate: [][]byte{{'b'}}}, + })) + case "NameToCertificate": + f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil})) + case "RootCAs", "ClientCAs": + f.Set(reflect.ValueOf(x509.NewCertPool())) + case "ClientSessionCache": + f.Set(reflect.ValueOf(NewLRUClientSessionCache(10))) + case "KeyLogWriter": + f.Set(reflect.ValueOf(io.Writer(os.Stdout))) + case "NextProtos": + f.Set(reflect.ValueOf([]string{"a", "b"})) + case "ServerName": + f.Set(reflect.ValueOf("b")) + case "ClientAuth": + f.Set(reflect.ValueOf(VerifyClientCertIfGiven)) + case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites": + f.Set(reflect.ValueOf(true)) + case "MinVersion", "MaxVersion": + f.Set(reflect.ValueOf(uint16(VersionTLS12))) + case "SessionTicketKey": + f.Set(reflect.ValueOf([32]byte{})) + case "CipherSuites": + f.Set(reflect.ValueOf([]uint16{1, 2})) + case "CurvePreferences": + f.Set(reflect.ValueOf([]CurveID{CurveP256})) + case "Renegotiation": + f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) + default: + t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) + } + } + + c2 := c1.Clone() + // DeepEqual also compares unexported fields, thus c2 needs to have run + // serverInit in order to be DeepEqual to c1. Cloning it and discarding + // the result is sufficient. + c2.Clone() + + if !reflect.DeepEqual(&c1, c2) { + t.Errorf("clone failed to copy a field") + } +} + // changeImplConn is a net.Conn which can change its Write and Close // methods. type changeImplConn struct { @@ -466,3 +695,163 @@ func (w *changeImplConn) Close() error { } return w.Conn.Close() } + +func throughput(b *testing.B, totalBytes int64, dynamicRecordSizingDisabled bool) { + ln := newLocalListener(b) + defer ln.Close() + + N := b.N + + // Less than 64KB because Windows appears to use a TCP rwin < 64KB. + // See Issue #15899. + const bufsize = 32 << 10 + + go func() { + buf := make([]byte, bufsize) + for i := 0; i < N; i++ { + sconn, err := ln.Accept() + if err != nil { + // panic rather than synchronize to avoid benchmark overhead + // (cannot call b.Fatal in goroutine) + panic(fmt.Errorf("accept: %v", err)) + } + serverConfig := testConfig.Clone() + serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers + serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + panic(fmt.Errorf("handshake: %v", err)) + } + if _, err := io.CopyBuffer(srv, srv, buf); err != nil { + panic(fmt.Errorf("copy buffer: %v", err)) + } + } + }() + + b.SetBytes(totalBytes) + clientConfig := testConfig.Clone() + clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers + clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled + + buf := make([]byte, bufsize) + chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf)))) + for i := 0; i < N; i++ { + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + b.Fatal(err) + } + for j := 0; j < chunks; j++ { + _, err := conn.Write(buf) + if err != nil { + b.Fatal(err) + } + _, err = io.ReadFull(conn, buf) + if err != nil { + b.Fatal(err) + } + } + conn.Close() + } +} + +func BenchmarkThroughput(b *testing.B) { + for _, mode := range []string{"Max", "Dynamic"} { + for size := 1; size <= 64; size <<= 1 { + name := fmt.Sprintf("%sPacket/%dMB", mode, size) + b.Run(name, func(b *testing.B) { + throughput(b, int64(size<<20), mode == "Max") + }) + } + } +} + +type slowConn struct { + net.Conn + bps int +} + +func (c *slowConn) Write(p []byte) (int, error) { + if c.bps == 0 { + panic("too slow") + } + t0 := time.Now() + wrote := 0 + for wrote < len(p) { + time.Sleep(100 * time.Microsecond) + allowed := int(time.Since(t0).Seconds()*float64(c.bps)) / 8 + if allowed > len(p) { + allowed = len(p) + } + if wrote < allowed { + n, err := c.Conn.Write(p[wrote:allowed]) + wrote += n + if err != nil { + return wrote, err + } + } + } + return len(p), nil +} + +func latency(b *testing.B, bps int, dynamicRecordSizingDisabled bool) { + ln := newLocalListener(b) + defer ln.Close() + + N := b.N + + go func() { + for i := 0; i < N; i++ { + sconn, err := ln.Accept() + if err != nil { + // panic rather than synchronize to avoid benchmark overhead + // (cannot call b.Fatal in goroutine) + panic(fmt.Errorf("accept: %v", err)) + } + serverConfig := testConfig.Clone() + serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled + srv := Server(&slowConn{sconn, bps}, serverConfig) + if err := srv.Handshake(); err != nil { + panic(fmt.Errorf("handshake: %v", err)) + } + io.Copy(srv, srv) + } + }() + + clientConfig := testConfig.Clone() + clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled + + buf := make([]byte, 16384) + peek := make([]byte, 1) + + for i := 0; i < N; i++ { + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + b.Fatal(err) + } + // make sure we're connected and previous connection has stopped + if _, err := conn.Write(buf[:1]); err != nil { + b.Fatal(err) + } + if _, err := io.ReadFull(conn, peek); err != nil { + b.Fatal(err) + } + if _, err := conn.Write(buf); err != nil { + b.Fatal(err) + } + if _, err = io.ReadFull(conn, peek); err != nil { + b.Fatal(err) + } + conn.Close() + } +} + +func BenchmarkLatency(b *testing.B) { + for _, mode := range []string{"Max", "Dynamic"} { + for _, kbps := range []int{200, 500, 1000, 2000, 5000} { + name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps) + b.Run(name, func(b *testing.B) { + latency(b, kbps*1000, mode == "Max") + }) + } + } +} |