diff options
Diffstat (limited to 'libgo/go/database/sql/sql_test.go')
-rw-r--r-- | libgo/go/database/sql/sql_test.go | 733 |
1 files changed, 700 insertions, 33 deletions
diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go index 8ec70d99b0..450e5f1f8c 100644 --- a/libgo/go/database/sql/sql_test.go +++ b/libgo/go/database/sql/sql_test.go @@ -5,6 +5,7 @@ package sql import ( + "context" "database/sql/driver" "errors" "fmt" @@ -13,6 +14,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "testing" "time" ) @@ -23,6 +25,17 @@ func init() { c *driverConn } freedFrom := make(map[dbConn]string) + var mu sync.Mutex + getFreedFrom := func(c dbConn) string { + mu.Lock() + defer mu.Unlock() + return freedFrom[c] + } + setFreedFrom := func(c dbConn, s string) { + mu.Lock() + defer mu.Unlock() + freedFrom[c] = s + } putConnHook = func(db *DB, c *driverConn) { idx := -1 for i, v := range db.freeConn { @@ -35,10 +48,10 @@ func init() { // print before panic, as panic may get lost due to conflicting panic // (all goroutines asleep) elsewhere, since we might not unlock // the mutex in freeConn here. - println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack()) + println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack()) panic("double free of conn.") } - freedFrom[dbConn{db, c}] = stack() + setFreedFrom(dbConn{db, c}, stack()) } } @@ -140,11 +153,13 @@ func closeDB(t testing.TB, db *DB) { if err != nil { t.Fatalf("error closing DB: %v", err) } - db.mu.Lock() - count := db.numOpen - db.mu.Unlock() - if count != 0 { - t.Fatalf("%d connections still open after closing DB", db.numOpen) + + var numOpen int + if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool { + numOpen = db.numOpenConns() + return numOpen == 0 + }) { + t.Fatalf("%d connections still open after closing DB", numOpen) } } @@ -182,6 +197,12 @@ func (db *DB) numFreeConns() int { return len(db.freeConn) } +func (db *DB) numOpenConns() int { + db.mu.Lock() + defer db.mu.Unlock() + return db.numOpen +} + // clearAllConns closes all connections in db. func (db *DB) clearAllConns(t *testing.T) { db.SetMaxIdleConns(0) @@ -260,6 +281,338 @@ func TestQuery(t *testing.T) { } } +// TestQueryContext tests canceling the context while scanning the rows. +func TestQueryContext(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := db.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + index := 0 + for rows.Next() { + if index == 2 { + cancel() + waitForRowsClose(t, rows, 5*time.Second) + } + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + if index == 2 { + break + } + t.Fatalf("Scan: %v", err) + } + if index == 2 && err == nil { + t.Fatal("expected an error on last scan") + } + got = append(got, r) + index++ + } + select { + case <-ctx.Done(): + if err := ctx.Err(); err != context.Canceled { + t.Fatalf("context err = %v; want context.Canceled") + } + default: + t.Fatalf("context err = nil; want context.Canceled") + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + waitForRowsClose(t, rows, 5*time.Second) + waitForFree(t, db, 5*time.Second, 1) + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { + deadline := time.Now().Add(waitFor) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(checkEvery) + } + return false +} + +// waitForFree checks db.numFreeConns until either it equals want or +// the maxWait time elapses. +func waitForFree(t *testing.T, db *DB, maxWait time.Duration, want int) { + var numFree int + if !waitCondition(maxWait, 5*time.Millisecond, func() bool { + numFree = db.numFreeConns() + return numFree == want + }) { + t.Fatalf("free conns after hitting EOF = %d; want %d", numFree, want) + } +} + +func waitForRowsClose(t *testing.T, rows *Rows, maxWait time.Duration) { + if !waitCondition(maxWait, 5*time.Millisecond, func() bool { + rows.closemu.RLock() + defer rows.closemu.RUnlock() + return rows.closed + }) { + t.Fatal("failed to close rows") + } +} + +// TestQueryContextWait ensures that rows and all internal statements are closed when +// a query context is closed during execution. +func TestQueryContextWait(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + // TODO(kardianos): convert this from using a timeout to using an explicit + // cancel when the query signals that is is "executing" the query. + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err := db.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + // Verify closed rows connection after error condition. + waitForFree(t, db, 5*time.Second, 1) + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + // TODO(kardianos): if the context timeouts before the db.QueryContext + // executes this check may fail. After adjusting how the context + // is canceled above revert this back to a Fatal error. + t.Logf("executed %d Prepare statements; want 1", prepares) + } +} + +// TestTxContextWait tests the transaction behavior when the tx context is canceled +// during execution of the query. +func TestTxContextWait(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + // Guard against the context being canceled before BeginTx completes. + if err == context.DeadlineExceeded { + t.Skip("tx context canceled prior to first use") + } + t.Fatal(err) + } + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err = tx.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + waitForFree(t, db, 5*time.Second, 0) +} + +func TestMultiResultSetQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row1 struct { + age int + name string + } + type row2 struct { + name string + } + got1 := []row1{} + for rows.Next() { + var r row1 + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got1 = append(got1, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want1 := []row1{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got1, want1) { + t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1) + } + + if !rows.NextResultSet() { + t.Errorf("expected another result set") + } + + got2 := []row2{} + for rows.Next() { + var r row2 + err = rows.Scan(&r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got2 = append(got2, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want2 := []row2{ + {name: "Alice"}, + {name: "Bob"}, + {name: "Chris"}, + } + if !reflect.DeepEqual(got2, want2) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2) + } + if rows.NextResultSet() { + t.Errorf("expected no more result sets") + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + waitForFree(t, db, 5*time.Second, 1) + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestQueryNamedArg(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query( + // Ensure the name and age parameters only match on placeholder name, not position. + "SELECT|people|age,name|name=?name,age=?age", + Named("age", 2), + Named("name", "Bob"), + ) + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestPoolExhaustOnCancel(t *testing.T) { + if testing.Short() { + t.Skip("long test") + } + db := newTestDB(t, "people") + defer closeDB(t, db) + + max := 3 + + db.SetMaxOpenConns(max) + + // First saturate the connection pool. + // Then start new requests for a connection that is cancelled after it is requested. + + var saturate, saturateDone sync.WaitGroup + saturate.Add(max) + saturateDone.Add(max) + + for i := 0; i < max; i++ { + go func() { + saturate.Done() + rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|") + if err != nil { + t.Fatalf("Query: %v", err) + } + rows.Close() + saturateDone.Done() + }() + } + + saturate.Wait() + + // Now cancel the request while it is waiting. + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + + for i := 0; i < max; i++ { + ctxReq, cancelReq := context.WithCancel(ctx) + go func() { + time.Sleep(time.Millisecond * 100) + cancelReq() + }() + err := db.PingContext(ctxReq) + if err != context.Canceled { + t.Fatalf("PingContext (Exhaust): %v", err) + } + } + + saturateDone.Wait() + + // Now try to open a normal connection. + err := db.PingContext(ctx) + if err != nil { + t.Fatalf("PingContext (Normal): %v", err) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -317,6 +670,56 @@ func TestRowsColumns(t *testing.T) { } } +func TestRowsColumnTypes(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + types[i] = st + } + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + ct := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + ct++ + if ct == 0 { + if values[0].(string) != "Bob" { + t.Errorf("Expected Bob, got %v", values[0]) + } + if values[1].(int) != 2 { + t.Errorf("Expected 2, got %v", values[1]) + } + } + } + if ct != 3 { + t.Errorf("expected 3 rows, got %d", ct) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } +} + func TestQueryRow(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -367,6 +770,37 @@ func TestQueryRow(t *testing.T) { } } +func TestTxRollbackCommitErr(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + err = tx.Rollback() + if err != nil { + t.Errorf("expected nil error from Rollback; got %v", err) + } + err = tx.Commit() + if err != ErrTxDone { + t.Errorf("expected %q from Commit; got %q", ErrTxDone, err) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + err = tx.Commit() + if err != nil { + t.Errorf("expected nil error from Commit; got %v", err) + } + err = tx.Rollback() + if err != ErrTxDone { + t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err) + } +} + func TestStatementErrorAfterClose(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -439,7 +873,7 @@ func TestStatementClose(t *testing.T) { msg string }{ {&Stmt{stickyErr: want}, "stickyErr not propagated"}, - {&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, + {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, } for _, test := range tests { if err := test.stmt.Close(); err != want { @@ -513,8 +947,8 @@ func TestExec(t *testing.T) { {[]interface{}{7, 9}, ""}, // Invalid conversions: - {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument #1's type: sql/driver: value 4294967295 overflows int32"}, - {[]interface{}{"Brad", "strconv fail"}, "sql: converting argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"}, + {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"}, + {[]interface{}{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`}, // Wrong number of args: {[]interface{}{}, "sql: expected 2 arguments, got 0"}, @@ -788,6 +1222,24 @@ func TestQueryRowClosingStmt(t *testing.T) { } } +var atomicRowsCloseHook atomic.Value // of func(*Rows, *error) + +func init() { + rowsCloseHook = func() func(*Rows, *error) { + fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error)) + return fn + } +} + +func setRowsCloseHook(fn func(*Rows, *error)) { + if fn == nil { + // Can't change an atomic.Value back to nil, so set it to this + // no-op func instead. + fn = func(*Rows, *error) {} + } + atomicRowsCloseHook.Store(fn) +} + // Test issue 6651 func TestIssue6651(t *testing.T) { db := newTestDB(t, "people") @@ -800,6 +1252,7 @@ func TestIssue6651(t *testing.T) { return fmt.Errorf(want) } defer func() { rowsCursorNextHook = nil }() + err := db.QueryRow("SELECT|people|name|").Scan(&v) if err == nil || err.Error() != want { t.Errorf("error = %q; want %q", err, want) @@ -807,10 +1260,10 @@ func TestIssue6651(t *testing.T) { rowsCursorNextHook = nil want = "error in rows.Close" - rowsCloseHook = func(rows *Rows, err *error) { + setRowsCloseHook(func(rows *Rows, err *error) { *err = fmt.Errorf(want) - } - defer func() { rowsCloseHook = nil }() + }) + defer setRowsCloseHook(nil) err = db.QueryRow("SELECT|people|name|").Scan(&v) if err == nil || err.Error() != want { t.Errorf("error = %q; want %q", err, want) @@ -911,7 +1364,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) { if err == nil { // TODO: this test fails, but it's just because // fakeConn implements the optional Execer interface, - // so arguably this is the correct behavior. But + // so arguably this is the correct behavior. But // maybe I should flesh out the fakeConn.Exec // implementation so this properly fails. // t.Errorf("expected error inserting nil name with Exec") @@ -1159,17 +1612,19 @@ func TestMaxOpenConnsOnBusy(t *testing.T) { db.SetMaxOpenConns(3) - conn0, err := db.conn(cachedOrNewConn) + ctx := context.Background() + + conn0, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } - conn1, err := db.conn(cachedOrNewConn) + conn1, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } - conn2, err := db.conn(cachedOrNewConn) + conn2, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } @@ -1203,7 +1658,11 @@ func TestPendingConnsAfterErr(t *testing.T) { tryOpen = maxOpen*2 + 2 ) - db := newTestDB(t, "people") + // No queries will be run. + db, err := Open("test", fakeDBName) + if err != nil { + t.Fatalf("Open: %v", err) + } defer closeDB(t, db) defer func() { for k, v := range db.lastPut { @@ -1215,31 +1674,31 @@ func TestPendingConnsAfterErr(t *testing.T) { db.SetMaxIdleConns(0) errOffline := errors.New("db offline") + defer func() { setHookOpenErr(nil) }() errs := make(chan error, tryOpen) - unblock := make(chan struct{}) + var opening sync.WaitGroup + opening.Add(tryOpen) + setHookOpenErr(func() error { - <-unblock // block until all connections are in flight + // Wait for all connections to enqueue. + opening.Wait() return errOffline }) - var opening sync.WaitGroup - opening.Add(tryOpen) for i := 0; i < tryOpen; i++ { go func() { opening.Done() // signal one connection is in flight - _, err := db.Exec("INSERT|people|name=Julia,age=19") + _, err := db.Exec("will never run") errs <- err }() } - opening.Wait() // wait for all workers to begin running - time.Sleep(10 * time.Millisecond) // make extra sure all workers are blocked - close(unblock) // let all workers proceed + opening.Wait() // wait for all workers to begin running - const timeout = 100 * time.Millisecond + const timeout = 5 * time.Second to := time.NewTimer(timeout) defer to.Stop() @@ -1254,6 +1713,24 @@ func TestPendingConnsAfterErr(t *testing.T) { t.Fatalf("orphaned connection request(s), still waiting after %v", timeout) } } + + // Wait a reasonable time for the database to close all connections. + tick := time.NewTicker(3 * time.Millisecond) + defer tick.Stop() + for { + select { + case <-tick.C: + db.mu.Lock() + if db.numOpen == 0 { + db.mu.Unlock() + return + } + db.mu.Unlock() + case <-to.C: + // Closing the database will check for numOpen and fail the test. + return + } + } } func TestSingleOpenConn(t *testing.T) { @@ -1459,7 +1936,9 @@ func TestStmtCloseDeps(t *testing.T) { db.dumpDeps(t) } - if len(stmt.css) > nquery { + if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool { + return len(stmt.css) <= nquery + }) { t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery) } @@ -1591,7 +2070,7 @@ func TestStmtCloseOrder(t *testing.T) { _, err := db.Query("SELECT|non_existent|name|") if err == nil { - t.Fatal("Quering non-existent table should fail") + t.Fatal("Querying non-existent table should fail") } } @@ -1615,6 +2094,8 @@ func TestManyErrBadConn(t *testing.T) { } }() + db.mu.Lock() + defer db.mu.Unlock() if db.numOpen != nconn { t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn) } else if len(db.freeConn) != nconn { @@ -2203,10 +2684,10 @@ func TestIssue6081(t *testing.T) { if err != nil { t.Fatal(err) } - rowsCloseHook = func(rows *Rows, err *error) { + setRowsCloseHook(func(rows *Rows, err *error) { *err = driver.ErrBadConn - } - defer func() { rowsCloseHook = nil }() + }) + defer setRowsCloseHook(nil) for i := 0; i < 10; i++ { rows, err := stmt.Query() if err != nil { @@ -2234,6 +2715,100 @@ func TestIssue6081(t *testing.T) { } } +// TestIssue18429 attempts to stress rolling back the transaction from a +// context cancel while simultaneously calling Tx.Rollback. Rolling back from a +// context happens concurrently so tx.rollback and tx.Commit must guard against +// double entry. +// +// In the test, a context is canceled while the query is in process so +// the internal rollback will run concurrently with the explicitly called +// Tx.Rollback. +func TestIssue18429(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx := context.Background() + sem := make(chan bool, 20) + var wg sync.WaitGroup + + const milliWait = 30 + + for i := 0; i < 100; i++ { + sem <- true + wg.Add(1) + go func() { + defer func() { + <-sem + wg.Done() + }() + qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String() + + ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return + } + // This is expected to give a cancel error many, but not all the time. + // Test failure will happen with a panic or other race condition being + // reported. + rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") + if rows != nil { + rows.Close() + } + // This call will race with the context cancel rollback to complete + // if the rollback itself isn't guarded. + tx.Rollback() + }() + } + wg.Wait() +} + +// TestIssue18719 closes the context right before use. The sql.driverConn +// will nil out the ci on close in a lock, but if another process uses it right after +// it will panic with on the nil ref. +// +// See https://golang.org/cl/35550 . +func TestIssue18719(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + hookTxGrabConn = func() { + cancel() + + // Wait for the context to cancel and tx to rollback. + for tx.isDone() == false { + time.Sleep(time.Millisecond * 3) + } + } + defer func() { hookTxGrabConn = nil }() + + // This call will grab the connection and cancel the context + // after it has done so. Code after must deal with the canceled state. + rows, err := tx.QueryContext(ctx, "SELECT|people|name|") + if err != nil { + rows.Close() + t.Fatalf("expected error %v but got %v", nil, err) + } + + // Rows may be ignored because it will be closed when the context is canceled. + + // Do not explicitly rollback. The rollback will happen from the + // canceled context. + + cancel() + waitForRowsClose(t, rows, 5*time.Second) +} + func TestConcurrency(t *testing.T) { doConcurrentTest(t, new(concurrentDBQueryTest)) doConcurrentTest(t, new(concurrentDBExecTest)) @@ -2277,7 +2852,8 @@ func TestConnectionLeak(t *testing.T) { go func() { r, err := db.Query("SELECT|people|name|") if err != nil { - t.Fatal(err) + t.Error(err) + return } r.Close() wg.Done() @@ -2297,6 +2873,97 @@ func TestConnectionLeak(t *testing.T) { wg.Wait() } +// badConn implements a bad driver.Conn, for TestBadDriver. +// The Exec method panics. +type badConn struct{} + +func (bc badConn) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("badConn Prepare") +} + +func (bc badConn) Close() error { + return nil +} + +func (bc badConn) Begin() (driver.Tx, error) { + return nil, errors.New("badConn Begin") +} + +func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) { + panic("badConn.Exec") +} + +// badDriver is a driver.Driver that uses badConn. +type badDriver struct{} + +func (bd badDriver) Open(name string) (driver.Conn, error) { + return badConn{}, nil +} + +// Issue 15901. +func TestBadDriver(t *testing.T) { + Register("bad", badDriver{}) + db, err := Open("bad", "ignored") + if err != nil { + t.Fatal(err) + } + defer func() { + if r := recover(); r == nil { + t.Error("expected panic") + } else { + if want := "badConn.Exec"; r.(string) != want { + t.Errorf("panic was %v, expected %v", r, want) + } + } + }() + defer db.Close() + db.Exec("ignored") +} + +type pingDriver struct { + fails bool +} + +type pingConn struct { + badConn + driver *pingDriver +} + +var pingError = errors.New("Ping failed") + +func (pc pingConn) Ping(ctx context.Context) error { + if pc.driver.fails { + return pingError + } + return nil +} + +var _ driver.Pinger = pingConn{} + +func (pd *pingDriver) Open(name string) (driver.Conn, error) { + return pingConn{driver: pd}, nil +} + +func TestPing(t *testing.T) { + driver := &pingDriver{} + Register("ping", driver) + + db, err := Open("ping", "ignored") + if err != nil { + t.Fatal(err) + } + + if err := db.Ping(); err != nil { + t.Errorf("err was %#v, expected nil", err) + return + } + + driver.fails = true + if err := db.Ping(); err != pingError { + t.Errorf("err was %#v, expected pingError", err) + } +} + func BenchmarkConcurrentDBExec(b *testing.B) { b.ReportAllocs() ct := new(concurrentDBExecTest) |