summaryrefslogtreecommitdiff
path: root/libgo/go/database/sql
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/database/sql')
-rw-r--r--libgo/go/database/sql/convert.go100
-rw-r--r--libgo/go/database/sql/convert_test.go95
-rw-r--r--libgo/go/database/sql/ctxutil.go149
-rw-r--r--libgo/go/database/sql/driver/driver.go205
-rw-r--r--libgo/go/database/sql/driver/types.go63
-rw-r--r--libgo/go/database/sql/driver/types_test.go16
-rw-r--r--libgo/go/database/sql/fakedb_test.go384
-rw-r--r--libgo/go/database/sql/sql.go1083
-rw-r--r--libgo/go/database/sql/sql_test.go733
9 files changed, 2388 insertions, 440 deletions
diff --git a/libgo/go/database/sql/convert.go b/libgo/go/database/sql/convert.go
index 740fd9d6e7..ea2f377810 100644
--- a/libgo/go/database/sql/convert.go
+++ b/libgo/go/database/sql/convert.go
@@ -13,16 +13,36 @@ import (
"reflect"
"strconv"
"time"
+ "unicode"
+ "unicode/utf8"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
+func describeNamedValue(nv *driver.NamedValue) string {
+ if len(nv.Name) == 0 {
+ return fmt.Sprintf("$%d", nv.Ordinal)
+ }
+ return fmt.Sprintf("with name %q", nv.Name)
+}
+
+func validateNamedValueName(name string) error {
+ if len(name) == 0 {
+ return nil
+ }
+ r, _ := utf8.DecodeRuneInString(name)
+ if unicode.IsLetter(r) {
+ return nil
+ }
+ return fmt.Errorf("name %q does not begin with a letter", name)
+}
+
// driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
-func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
- dargs := make([]driver.Value, len(args))
+func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+ nvargs := make([]driver.NamedValue, len(args))
var si driver.Stmt
if ds != nil {
si = ds.si
@@ -33,26 +53,45 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
if !ok {
for n, arg := range args {
var err error
- dargs[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
+ nv := &nvargs[n]
+ nv.Ordinal = n + 1
+ if np, ok := arg.(NamedArg); ok {
+ if err := validateNamedValueName(np.Name); err != nil {
+ return nil, err
+ }
+ arg = np.Value
+ nvargs[n].Name = np.Name
+ }
+ nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
+
if err != nil {
- return nil, fmt.Errorf("sql: converting Exec argument #%d's type: %v", n, err)
+ return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err)
}
}
- return dargs, nil
+ return nvargs, nil
}
// Let the Stmt convert its own arguments.
for n, arg := range args {
+ nv := &nvargs[n]
+ nv.Ordinal = n + 1
+ if np, ok := arg.(NamedArg); ok {
+ if err := validateNamedValueName(np.Name); err != nil {
+ return nil, err
+ }
+ arg = np.Value
+ nv.Name = np.Name
+ }
// First, see if the value itself knows how to convert
- // itself to a driver type. For example, a NullString
+ // itself to a driver type. For example, a NullString
// struct changing into a string or nil.
- if svi, ok := arg.(driver.Valuer); ok {
- sv, err := svi.Value()
+ if vr, ok := arg.(driver.Valuer); ok {
+ sv, err := callValuerValue(vr)
if err != nil {
- return nil, fmt.Errorf("sql: argument index %d from Value: %v", n, err)
+ return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err)
}
if !driver.IsValue(sv) {
- return nil, fmt.Errorf("sql: argument index %d: non-subset type %T returned from Value", n, sv)
+ return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv)
}
arg = sv
}
@@ -66,18 +105,18 @@ func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
// same error.
var err error
ds.Lock()
- dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
+ nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg)
ds.Unlock()
if err != nil {
- return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
+ return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
}
- if !driver.IsValue(dargs[n]) {
- return nil, fmt.Errorf("sql: driver ColumnConverter error converted %T to unsupported type %T",
- arg, dargs[n])
+ if !driver.IsValue(nv.Value) {
+ return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T",
+ describeNamedValue(nv), arg, nv.Value)
}
}
- return dargs, nil
+ return nvargs, nil
}
// convertAssign copies to dest the value in src, converting it if possible.
@@ -217,7 +256,12 @@ func convertAssign(dest, src interface{}) error {
dv := reflect.Indirect(dpv)
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
- dv.Set(sv)
+ switch b := src.(type) {
+ case []byte:
+ dv.Set(reflect.ValueOf(cloneBytes(b)))
+ default:
+ dv.Set(sv)
+ }
return nil
}
@@ -325,3 +369,25 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
}
return
}
+
+var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+// Issue 8415.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This function is mirrored in the database/sql/driver package.
+func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
+ if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
+ rv.IsNil() &&
+ rv.Type().Elem().Implements(valuerReflectType) {
+ return nil, nil
+ }
+ return vr.Value()
+}
diff --git a/libgo/go/database/sql/convert_test.go b/libgo/go/database/sql/convert_test.go
index 342875e190..4dfab1f6be 100644
--- a/libgo/go/database/sql/convert_test.go
+++ b/libgo/go/database/sql/convert_test.go
@@ -9,6 +9,7 @@ import (
"fmt"
"reflect"
"runtime"
+ "strings"
"testing"
"time"
)
@@ -377,3 +378,97 @@ func TestRawBytesAllocs(t *testing.T) {
t.Fatalf("allocs = %v; want max 1", n)
}
}
+
+// https://github.com/golang/go/issues/13905
+func TestUserDefinedBytes(t *testing.T) {
+ type userDefinedBytes []byte
+ var u userDefinedBytes
+ v := []byte("foo")
+
+ convertAssign(&u, v)
+ if &u[0] == &v[0] {
+ t.Fatal("userDefinedBytes got potentially dirty driver memory")
+ }
+}
+
+type Valuer_V string
+
+func (v Valuer_V) Value() (driver.Value, error) {
+ return strings.ToUpper(string(v)), nil
+}
+
+type Valuer_P string
+
+func (p *Valuer_P) Value() (driver.Value, error) {
+ if p == nil {
+ return "nil-to-str", nil
+ }
+ return strings.ToUpper(string(*p)), nil
+}
+
+func TestDriverArgs(t *testing.T) {
+ var nilValuerVPtr *Valuer_V
+ var nilValuerPPtr *Valuer_P
+ var nilStrPtr *string
+ tests := []struct {
+ args []interface{}
+ want []driver.NamedValue
+ }{
+ 0: {
+ args: []interface{}{Valuer_V("foo")},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: "FOO",
+ },
+ },
+ },
+ 1: {
+ args: []interface{}{nilValuerVPtr},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: nil,
+ },
+ },
+ },
+ 2: {
+ args: []interface{}{nilValuerPPtr},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: "nil-to-str",
+ },
+ },
+ },
+ 3: {
+ args: []interface{}{"plain-str"},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: "plain-str",
+ },
+ },
+ },
+ 4: {
+ args: []interface{}{nilStrPtr},
+ want: []driver.NamedValue{
+ driver.NamedValue{
+ Ordinal: 1,
+ Value: nil,
+ },
+ },
+ },
+ }
+ for i, tt := range tests {
+ ds := new(driverStmt)
+ got, err := driverArgs(ds, tt.args)
+ if err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ continue
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
+ }
+ }
+}
diff --git a/libgo/go/database/sql/ctxutil.go b/libgo/go/database/sql/ctxutil.go
new file mode 100644
index 0000000000..bd652b5462
--- /dev/null
+++ b/libgo/go/database/sql/ctxutil.go
@@ -0,0 +1,149 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+)
+
+func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
+ if ciCtx, is := ci.(driver.ConnPrepareContext); is {
+ return ciCtx.PrepareContext(ctx, query)
+ }
+ si, err := ci.Prepare(query)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ si.Close()
+ return nil, ctx.Err()
+ }
+ }
+ return si, err
+}
+
+func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
+ if execerCtx, is := execer.(driver.ExecerContext); is {
+ return execerCtx.ExecContext(ctx, query, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return execer.Exec(query, dargs)
+}
+
+func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
+ if queryerCtx, is := queryer.(driver.QueryerContext); is {
+ ret, err := queryerCtx.QueryContext(ctx, query, nvdargs)
+ return ret, err
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return queryer.Query(query, dargs)
+}
+
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
+ if siCtx, is := si.(driver.StmtExecContext); is {
+ return siCtx.ExecContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return si.Exec(dargs)
+}
+
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
+ if siCtx, is := si.(driver.StmtQueryContext); is {
+ return siCtx.QueryContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return si.Query(dargs)
+}
+
+var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
+
+func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) {
+ if ciCtx, is := ci.(driver.ConnBeginTx); is {
+ dopts := driver.TxOptions{}
+ if opts != nil {
+ dopts.Isolation = driver.IsolationLevel(opts.Isolation)
+ dopts.ReadOnly = opts.ReadOnly
+ }
+ return ciCtx.BeginTx(ctx, dopts)
+ }
+
+ if ctx.Done() == context.Background().Done() {
+ return ci.Begin()
+ }
+
+ if opts != nil {
+ // Check the transaction level. If the transaction level is non-default
+ // then return an error here as the BeginTx driver value is not supported.
+ if opts.Isolation != LevelDefault {
+ return nil, errors.New("sql: driver does not support non-default isolation level")
+ }
+
+ // If a read-only transaction is requested return an error as the
+ // BeginTx driver value is not supported.
+ if opts.ReadOnly {
+ return nil, errors.New("sql: driver does not support read-only transactions")
+ }
+ }
+
+ txi, err := ci.Begin()
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ txi.Rollback()
+ return nil, ctx.Err()
+ }
+ }
+ return txi, err
+}
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+ dargs := make([]driver.Value, len(named))
+ for n, param := range named {
+ if len(param.Name) > 0 {
+ return nil, errors.New("sql: driver does not support the use of Named Parameters")
+ }
+ dargs[n] = param.Value
+ }
+ return dargs, nil
+}
diff --git a/libgo/go/database/sql/driver/driver.go b/libgo/go/database/sql/driver/driver.go
index eca25f29a0..d66196fd48 100644
--- a/libgo/go/database/sql/driver/driver.go
+++ b/libgo/go/database/sql/driver/driver.go
@@ -8,7 +8,11 @@
// Most code should use package sql.
package driver
-import "errors"
+import (
+ "context"
+ "errors"
+ "reflect"
+)
// Value is a value that drivers must be able to handle.
// It is either nil or an instance of one of these types:
@@ -17,10 +21,25 @@ import "errors"
// float64
// bool
// []byte
-// string [*] everywhere except from Rows.Next.
+// string
// time.Time
type Value interface{}
+// NamedValue holds both the value name and value.
+type NamedValue struct {
+ // If the Name is not empty it should be used for the parameter identifier and
+ // not the ordinal position.
+ //
+ // Name will not have a symbol prefix.
+ Name string
+
+ // Ordinal position of the parameter starting from one and is always set.
+ Ordinal int
+
+ // Value is the parameter value.
+ Value Value
+}
+
// Driver is the interface that must be implemented by a database
// driver.
type Driver interface {
@@ -54,6 +73,17 @@ var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// you shouldn't return ErrBadConn.
var ErrBadConn = errors.New("driver: bad connection")
+// Pinger is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement Pinger, the sql package's DB.Ping and
+// DB.PingContext will check if there is at least one Conn available.
+//
+// If Conn.Ping returns ErrBadConn, DB.Ping and DB.PingContext will remove
+// the Conn from pool.
+type Pinger interface {
+ Ping(ctx context.Context) error
+}
+
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the sql package's DB.Exec will
@@ -61,10 +91,25 @@ var ErrBadConn = errors.New("driver: bad connection")
// statement.
//
// Exec may return ErrSkip.
+//
+// Deprecated: Drivers should implement ExecerContext instead (or additionally).
type Execer interface {
Exec(query string, args []Value) (Result, error)
}
+// ExecerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement ExecerContext, the sql package's DB.Exec will
+// first prepare a query, execute the statement, and then close the
+// statement.
+//
+// ExecerContext may return ErrSkip.
+//
+// ExecerContext must honor the context timeout and return when the context is canceled.
+type ExecerContext interface {
+ ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
+}
+
// Queryer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Queryer, the sql package's DB.Query will
@@ -72,10 +117,25 @@ type Execer interface {
// statement.
//
// Query may return ErrSkip.
+//
+// Deprecated: Drivers should implement QueryerContext instead (or additionally).
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
+// QueryerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement QueryerContext, the sql package's DB.Query will
+// first prepare a query, execute the statement, and then close the
+// statement.
+//
+// QueryerContext may return ErrSkip.
+//
+// QueryerContext must honor the context timeout and return when the context is canceled.
+type QueryerContext interface {
+ QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
+}
+
// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
@@ -95,9 +155,50 @@ type Conn interface {
Close() error
// Begin starts and returns a new transaction.
+ //
+ // Deprecated: Drivers should implement ConnBeginTx instead (or additionally).
Begin() (Tx, error)
}
+// ConnPrepareContext enhances the Conn interface with context.
+type ConnPrepareContext interface {
+ // PrepareContext returns a prepared statement, bound to this connection.
+ // context is for the preparation of the statement,
+ // it must not store the context within the statement itself.
+ PrepareContext(ctx context.Context, query string) (Stmt, error)
+}
+
+// IsolationLevel is the transaction isolation level stored in TxOptions.
+//
+// This type should be considered identical to sql.IsolationLevel along
+// with any values defined on it.
+type IsolationLevel int
+
+// TxOptions holds the transaction options.
+//
+// This type should be considered identical to sql.TxOptions.
+type TxOptions struct {
+ Isolation IsolationLevel
+ ReadOnly bool
+}
+
+// ConnBeginTx enhances the Conn interface with context and TxOptions.
+type ConnBeginTx interface {
+ // BeginTx starts and returns a new transaction.
+ // If the context is canceled by the user the sql package will
+ // call Tx.Rollback before discarding and closing the connection.
+ //
+ // This must check opts.Isolation to determine if there is a set
+ // isolation level. If the driver does not support a non-default
+ // level and one is set or if there is a non-default isolation level
+ // that is not supported, an error must be returned.
+ //
+ // This must also check opts.ReadOnly to determine if the read-only
+ // value is true to either set the read-only transaction property if supported
+ // or return an error if it is not supported.
+ BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
+}
+
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
@@ -132,19 +233,41 @@ type Stmt interface {
// Exec executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
+ //
+ // Deprecated: Drivers should implement StmtExecContext instead (or additionally).
Exec(args []Value) (Result, error)
// Query executes a query that may return rows, such as a
// SELECT.
+ //
+ // Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
Query(args []Value) (Rows, error)
}
+// StmtExecContext enhances the Stmt interface by providing Exec with context.
+type StmtExecContext interface {
+ // ExecContext executes a query that doesn't return rows, such
+ // as an INSERT or UPDATE.
+ //
+ // ExecContext must honor the context timeout and return when it is canceled.
+ ExecContext(ctx context.Context, args []NamedValue) (Result, error)
+}
+
+// StmtQueryContext enhances the Stmt interface by providing Query with context.
+type StmtQueryContext interface {
+ // QueryContext executes a query that may return rows, such as a
+ // SELECT.
+ //
+ // QueryContext must honor the context timeout and return when it is canceled.
+ QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
+}
+
// ColumnConverter may be optionally implemented by Stmt if the
// statement is aware of its own columns' types and can convert from
// any type to a driver Value.
type ColumnConverter interface {
// ColumnConverter returns a ValueConverter for the provided
- // column index. If the type of a specific column isn't known
+ // column index. If the type of a specific column isn't known
// or shouldn't be handled specially, DefaultValueConverter
// can be returned.
ColumnConverter(idx int) ValueConverter
@@ -154,7 +277,7 @@ type ColumnConverter interface {
type Rows interface {
// Columns returns the names of the columns. The number of
// columns of the result is inferred from the length of the
- // slice. If a particular column name isn't known, an empty
+ // slice. If a particular column name isn't known, an empty
// string should be returned for that entry.
Columns() []string
@@ -165,14 +288,80 @@ type Rows interface {
// the provided slice. The provided slice will be the same
// size as the Columns() are wide.
//
- // The dest slice may be populated only with
- // a driver Value type, but excluding string.
- // All string values must be converted to []byte.
- //
// Next should return io.EOF when there are no more rows.
Next(dest []Value) error
}
+// RowsNextResultSet extends the Rows interface by providing a way to signal
+// the driver to advance to the next result set.
+type RowsNextResultSet interface {
+ Rows
+
+ // HasNextResultSet is called at the end of the current result set and
+ // reports whether there is another result set after the current one.
+ HasNextResultSet() bool
+
+ // NextResultSet advances the driver to the next result set even
+ // if there are remaining rows in the current result set.
+ //
+ // NextResultSet should return io.EOF when there are no more result sets.
+ NextResultSet() error
+}
+
+// RowsColumnTypeScanType may be implemented by Rows. It should return
+// the value type that can be used to scan types into. For example, the database
+// column type "bigint" this should return "reflect.TypeOf(int64(0))".
+type RowsColumnTypeScanType interface {
+ Rows
+ ColumnTypeScanType(index int) reflect.Type
+}
+
+// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
+// database system type name without the length. Type names should be uppercase.
+// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
+// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
+// "TIMESTAMP".
+type RowsColumnTypeDatabaseTypeName interface {
+ Rows
+ ColumnTypeDatabaseTypeName(index int) string
+}
+
+// RowsColumnTypeLength may be implemented by Rows. It should return the length
+// of the column type if the column is a variable length type. If the column is
+// not a variable length type ok should return false.
+// If length is not limited other than system limits, it should return math.MaxInt64.
+// The following are examples of returned values for various types:
+// TEXT (math.MaxInt64, true)
+// varchar(10) (10, true)
+// nvarchar(10) (10, true)
+// decimal (0, false)
+// int (0, false)
+// bytea(30) (30, true)
+type RowsColumnTypeLength interface {
+ Rows
+ ColumnTypeLength(index int) (length int64, ok bool)
+}
+
+// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
+// be true if it is known the column may be null, or false if the column is known
+// to be not nullable.
+// If the column nullability is unknown, ok should be false.
+type RowsColumnTypeNullable interface {
+ Rows
+ ColumnTypeNullable(index int) (nullable, ok bool)
+}
+
+// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
+// the precision and scale for decimal types. If not applicable, ok should be false.
+// The following are examples of returned values for various types:
+// decimal(38, 4) (38, 4, true)
+// int (0, 0, false)
+// decimal (math.MaxInt64, math.MaxInt64, true)
+type RowsColumnTypePrecisionScale interface {
+ Rows
+ ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
+}
+
// Tx is a transaction.
type Tx interface {
Commit() error
diff --git a/libgo/go/database/sql/driver/types.go b/libgo/go/database/sql/driver/types.go
index bc54784989..8b3cb6c8f6 100644
--- a/libgo/go/database/sql/driver/types.go
+++ b/libgo/go/database/sql/driver/types.go
@@ -15,7 +15,7 @@ import (
//
// Various implementations of ValueConverter are provided by the
// driver package to provide consistent implementations of conversions
-// between drivers. The ValueConverters have several uses:
+// between drivers. The ValueConverters have several uses:
//
// * converting from the Value types as provided by the sql package
// into a database table's specific column type and making sure it
@@ -172,28 +172,21 @@ func (n NotNull) ConvertValue(v interface{}) (Value, error) {
}
// IsValue reports whether v is a valid Value parameter type.
-// Unlike IsScanValue, IsValue permits the string type.
func IsValue(v interface{}) bool {
- if IsScanValue(v) {
+ if v == nil {
return true
}
- if _, ok := v.(string); ok {
+ switch v.(type) {
+ case []byte, bool, float64, int64, string, time.Time:
return true
}
return false
}
-// IsScanValue reports whether v is a valid Value scan type.
-// Unlike IsValue, IsScanValue does not permit the string type.
+// IsScanValue is equivalent to IsValue.
+// It exists for compatibility.
func IsScanValue(v interface{}) bool {
- if v == nil {
- return true
- }
- switch v.(type) {
- case int64, float64, []byte, bool, time.Time:
- return true
- }
- return false
+ return IsValue(v)
}
// DefaultParameterConverter is the default implementation of
@@ -205,9 +198,9 @@ func IsScanValue(v interface{}) bool {
// Value method is used to return a Value. As a fallback, the provided
// argument's underlying type is used to convert it to a Value:
// underlying integer types are converted to int64, floats to float64,
-// and strings to []byte. If the argument is a nil pointer,
-// ConvertValue returns a nil Value. If the argument is a non-nil
-// pointer, it is dereferenced and ConvertValue is called
+// bool, string, and []byte to themselves. If the argument is a nil
+// pointer, ConvertValue returns a nil Value. If the argument is a
+// non-nil pointer, it is dereferenced and ConvertValue is called
// recursively. Other types are an error.
var DefaultParameterConverter defaultConverter
@@ -215,13 +208,35 @@ type defaultConverter struct{}
var _ ValueConverter = defaultConverter{}
+var valuerReflectType = reflect.TypeOf((*Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+// Issue 8415.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This function is mirrored in the database/sql package.
+func callValuerValue(vr Valuer) (v Value, err error) {
+ if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
+ rv.IsNil() &&
+ rv.Type().Elem().Implements(valuerReflectType) {
+ return nil, nil
+ }
+ return vr.Value()
+}
+
func (defaultConverter) ConvertValue(v interface{}) (Value, error) {
if IsValue(v) {
return v, nil
}
- if svi, ok := v.(Valuer); ok {
- sv, err := svi.Value()
+ if vr, ok := v.(Valuer); ok {
+ sv, err := callValuerValue(vr)
if err != nil {
return nil, err
}
@@ -252,6 +267,16 @@ func (defaultConverter) ConvertValue(v interface{}) (Value, error) {
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
+ case reflect.Bool:
+ return rv.Bool(), nil
+ case reflect.Slice:
+ ek := rv.Type().Elem().Kind()
+ if ek == reflect.Uint8 {
+ return rv.Bytes(), nil
+ }
+ return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
+ case reflect.String:
+ return rv.String(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
diff --git a/libgo/go/database/sql/driver/types_test.go b/libgo/go/database/sql/driver/types_test.go
index 1ce0ff0654..0379bf8892 100644
--- a/libgo/go/database/sql/driver/types_test.go
+++ b/libgo/go/database/sql/driver/types_test.go
@@ -20,6 +20,16 @@ type valueConverterTest struct {
var now = time.Now()
var answer int64 = 42
+type (
+ i int64
+ f float64
+ b bool
+ bs []byte
+ s string
+ t time.Time
+ is []int
+)
+
var valueConverterTests = []valueConverterTest{
{Bool, "true", true, ""},
{Bool, "True", true, ""},
@@ -41,6 +51,12 @@ var valueConverterTests = []valueConverterTest{
{DefaultParameterConverter, (*int64)(nil), nil, ""},
{DefaultParameterConverter, &answer, answer, ""},
{DefaultParameterConverter, &now, now, ""},
+ {DefaultParameterConverter, i(9), int64(9), ""},
+ {DefaultParameterConverter, f(0.1), float64(0.1), ""},
+ {DefaultParameterConverter, b(true), true, ""},
+ {DefaultParameterConverter, bs{1}, []byte{1}, ""},
+ {DefaultParameterConverter, s("a"), "a", ""},
+ {DefaultParameterConverter, is{1}, nil, "unsupported type driver.is, a slice of int"},
}
func TestValueConverters(t *testing.T) {
diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go
index b5ff121358..4b15f5bec7 100644
--- a/libgo/go/database/sql/fakedb_test.go
+++ b/libgo/go/database/sql/fakedb_test.go
@@ -5,11 +5,13 @@
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
+ "reflect"
"sort"
"strconv"
"strings"
@@ -32,12 +34,18 @@ var _ = log.Printf
// where types are: "string", [u]int{8,16,32,64}, "bool"
// INSERT|<tablename>|col=val,col2=val2,col3=?
// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
+// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
+// Any of these can be proceeded by WAIT|<duration>|, to cause the
+// named method on fakeStmt to sleep for the specified duration.
+//
+// Multiple of these can be combined when separated with a semicolon.
+//
// When opening a fakeDriver's database, it starts empty with no
-// tables. All tables and data are stored in memory only.
+// tables. All tables and data are stored in memory only.
type fakeDriver struct {
mu sync.Mutex // guards 3 following fields
openCount int // conn opens
@@ -51,7 +59,6 @@ type fakeDB struct {
name string
mu sync.Mutex
- free []*fakeConn
tables map[string]*table
badConn bool
}
@@ -76,12 +83,6 @@ type row struct {
cols []interface{} // must be same size as its table colname + coltype
}
-func (r *row) clone() *row {
- nrow := &row{cols: make([]interface{}, len(r.cols))}
- copy(nrow.cols, r.cols)
- return nrow
-}
-
type fakeConn struct {
db *fakeDB // where to return ourselves to
@@ -108,6 +109,12 @@ type fakeTx struct {
c *fakeConn
}
+type boundCol struct {
+ Column string
+ Placeholder string
+ Ordinal int
+}
+
type fakeStmt struct {
c *fakeConn
q string // just for debugging
@@ -115,6 +122,9 @@ type fakeStmt struct {
cmd string
table string
panic string
+ wait time.Duration
+
+ next *fakeStmt // used for returning multiple results.
closed bool
@@ -123,7 +133,7 @@ type fakeStmt struct {
colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
placeholders int // used by INSERT/SELECT: number of ? params
- whereCol []string // used by SELECT (all placeholders)
+ whereCol []boundCol // used by SELECT (all placeholders)
placeholderConverter []driver.ValueConverter // used by INSERT
}
@@ -342,18 +352,23 @@ func (c *fakeConn) Close() (err error) {
return nil
}
-func checkSubsetTypes(args []driver.Value) error {
- for n, arg := range args {
- switch arg.(type) {
+func checkSubsetTypes(args []driver.NamedValue) error {
+ for _, arg := range args {
+ switch arg.Value.(type) {
case int64, float64, bool, nil, []byte, string, time.Time:
default:
- return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+ return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
}
}
return nil
}
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ // Ensure that ExecContext is called if available.
+ panic("ExecContext was not called.")
+}
+
+func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
@@ -366,6 +381,11 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error
}
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ // Ensure that ExecContext is called if available.
+ panic("QueryContext was not called.")
+}
+
+func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
@@ -384,12 +404,13 @@ func errf(msg string, args ...interface{}) error {
// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
// (note that where columns must always contain ? marks,
// just a limitation for fakedb)
-func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 3 {
stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
}
stmt.table = parts[0]
+
stmt.colName = strings.Split(parts[1], ",")
for n, colspec := range strings.Split(parts[2], ",") {
if colspec == "" {
@@ -406,19 +427,19 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
}
- if value != "?" {
+ if !strings.HasPrefix(value, "?") {
stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column)
}
- stmt.whereCol = append(stmt.whereCol, column)
stmt.placeholders++
+ stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
}
return stmt, nil
}
// parts are table|col=type,col2=type2
-func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
@@ -437,7 +458,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
// parts are table|col=?,col2=val
-func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
@@ -457,7 +478,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
stmt.colName = append(stmt.colName, column)
- if value != "?" {
+ if !strings.HasPrefix(value, "?") {
var subsetVal interface{}
// Convert to driver subset type
switch ctype {
@@ -480,7 +501,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
} else {
stmt.placeholders++
stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
- stmt.colValue = append(stmt.colValue, "?")
+ stmt.colValue = append(stmt.colValue, value)
}
}
return stmt, nil
@@ -490,6 +511,10 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
var hookPrepareBadConn func() bool
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+ panic("use PrepareContext")
+}
+
+func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
c.numPrepare++
if c.db == nil {
panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
@@ -499,38 +524,72 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
return nil, driver.ErrBadConn
}
- parts := strings.Split(query, "|")
- if len(parts) < 1 {
- return nil, errf("empty query")
- }
- stmt := &fakeStmt{q: query, c: c}
- if len(parts) >= 3 && parts[0] == "PANIC" {
- stmt.panic = parts[1]
- parts = parts[2:]
- }
- cmd := parts[0]
- stmt.cmd = cmd
- parts = parts[1:]
+ var firstStmt, prev *fakeStmt
+ for _, query := range strings.Split(query, ";") {
+ parts := strings.Split(query, "|")
+ if len(parts) < 1 {
+ return nil, errf("empty query")
+ }
+ stmt := &fakeStmt{q: query, c: c}
+ if firstStmt == nil {
+ firstStmt = stmt
+ }
+ if len(parts) >= 3 {
+ switch parts[0] {
+ case "PANIC":
+ stmt.panic = parts[1]
+ parts = parts[2:]
+ case "WAIT":
+ wait, err := time.ParseDuration(parts[1])
+ if err != nil {
+ return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
+ }
+ parts = parts[2:]
+ stmt.wait = wait
+ }
+ }
+ cmd := parts[0]
+ stmt.cmd = cmd
+ parts = parts[1:]
+
+ if stmt.wait > 0 {
+ wait := time.NewTimer(stmt.wait)
+ select {
+ case <-wait.C:
+ case <-ctx.Done():
+ wait.Stop()
+ return nil, ctx.Err()
+ }
+ }
- c.incrStat(&c.stmtsMade)
- switch cmd {
- case "WIPE":
- // Nothing
- case "SELECT":
- return c.prepareSelect(stmt, parts)
- case "CREATE":
- return c.prepareCreate(stmt, parts)
- case "INSERT":
- return c.prepareInsert(stmt, parts)
- case "NOSERT":
- // Do all the prep-work like for an INSERT but don't actually insert the row.
- // Used for some of the concurrent tests.
- return c.prepareInsert(stmt, parts)
- default:
- stmt.Close()
- return nil, errf("unsupported command type %q", cmd)
+ c.incrStat(&c.stmtsMade)
+ var err error
+ switch cmd {
+ case "WIPE":
+ // Nothing
+ case "SELECT":
+ stmt, err = c.prepareSelect(stmt, parts)
+ case "CREATE":
+ stmt, err = c.prepareCreate(stmt, parts)
+ case "INSERT":
+ stmt, err = c.prepareInsert(stmt, parts)
+ case "NOSERT":
+ // Do all the prep-work like for an INSERT but don't actually insert the row.
+ // Used for some of the concurrent tests.
+ stmt, err = c.prepareInsert(stmt, parts)
+ default:
+ stmt.Close()
+ return nil, errf("unsupported command type %q", cmd)
+ }
+ if err != nil {
+ return nil, err
+ }
+ if prev != nil {
+ prev.next = stmt
+ }
+ prev = stmt
}
- return stmt, nil
+ return firstStmt, nil
}
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
@@ -557,6 +616,9 @@ func (s *fakeStmt) Close() error {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
}
+ if s.next != nil {
+ s.next.Close()
+ }
return nil
}
@@ -566,6 +628,9 @@ var errClosed = errors.New("fakedb: statement has been closed")
var hookExecBadConn func() bool
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
+ panic("Using ExecContext")
+}
+func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s.panic == "Exec" {
panic(s.panic)
}
@@ -582,6 +647,16 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, err
}
+ if s.wait > 0 {
+ time.Sleep(s.wait)
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+
db := s.c.db
switch s.cmd {
case "WIPE":
@@ -606,7 +681,7 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
// When doInsert is true, add the row to the table.
// When doInsert is false do prep-work and error checking, but don't
// actually add the row to the table.
-func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) {
+func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
db := s.c.db
if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct")
@@ -632,8 +707,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
}
var val interface{}
- if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
- val = args[argPos]
+ if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
+ if strvalue == "?" {
+ val = args[argPos].Value
+ } else {
+ // Assign value from argument placeholder name.
+ for _, a := range args {
+ if a.Name == strvalue[1:] {
+ val = a.Value
+ break
+ }
+ }
+ }
argPos++
} else {
val = s.colValue[n]
@@ -653,6 +738,10 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
var hookQueryBadConn func() bool
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
+ panic("Use QueryContext")
+}
+
+func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s.panic == "Query" {
panic(s.panic)
}
@@ -674,65 +763,101 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
panic("error in pkg db; should only get here if size is correct")
}
- db.mu.Lock()
- t, ok := db.table(s.table)
- db.mu.Unlock()
- if !ok {
- return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
- }
+ setMRows := make([][]*row, 0, 1)
+ setColumns := make([][]string, 0, 1)
+ setColType := make([][]string, 0, 1)
- if s.table == "magicquery" {
- if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
- if args[0] == "sleep" {
- time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
- }
+ for {
+ db.mu.Lock()
+ t, ok := db.table(s.table)
+ db.mu.Unlock()
+ if !ok {
+ return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
- }
-
- t.mu.Lock()
- defer t.mu.Unlock()
- colIdx := make(map[string]int) // select column name -> column index in table
- for _, name := range s.colName {
- idx := t.columnIndex(name)
- if idx == -1 {
- return nil, fmt.Errorf("fakedb: unknown column name %q", name)
+ if s.table == "magicquery" {
+ if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
+ if args[0].Value == "sleep" {
+ time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
+ }
+ }
}
- colIdx[name] = idx
- }
- mrows := []*row{}
-rows:
- for _, trow := range t.rows {
- // Process the where clause, skipping non-match rows. This is lazy
- // and just uses fmt.Sprintf("%v") to test equality. Good enough
- // for test code.
- for widx, wcol := range s.whereCol {
- idx := t.columnIndex(wcol)
+ t.mu.Lock()
+
+ colIdx := make(map[string]int) // select column name -> column index in table
+ for _, name := range s.colName {
+ idx := t.columnIndex(name)
if idx == -1 {
- return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+ t.mu.Unlock()
+ return nil, fmt.Errorf("fakedb: unknown column name %q", name)
}
- tcol := trow.cols[idx]
- if bs, ok := tcol.([]byte); ok {
- // lazy hack to avoid sprintf %v on a []byte
- tcol = string(bs)
+ colIdx[name] = idx
+ }
+
+ mrows := []*row{}
+ rows:
+ for _, trow := range t.rows {
+ // Process the where clause, skipping non-match rows. This is lazy
+ // and just uses fmt.Sprintf("%v") to test equality. Good enough
+ // for test code.
+ for _, wcol := range s.whereCol {
+ idx := t.columnIndex(wcol.Column)
+ if idx == -1 {
+ t.mu.Unlock()
+ return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+ }
+ tcol := trow.cols[idx]
+ if bs, ok := tcol.([]byte); ok {
+ // lazy hack to avoid sprintf %v on a []byte
+ tcol = string(bs)
+ }
+ var argValue interface{}
+ if wcol.Placeholder == "?" {
+ argValue = args[wcol.Ordinal-1].Value
+ } else {
+ // Assign arg value from placeholder name.
+ for _, a := range args {
+ if a.Name == wcol.Placeholder[1:] {
+ argValue = a.Value
+ break
+ }
+ }
+ }
+ if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
+ continue rows
+ }
}
- if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
- continue rows
+ mrow := &row{cols: make([]interface{}, len(s.colName))}
+ for seli, name := range s.colName {
+ mrow.cols[seli] = trow.cols[colIdx[name]]
}
+ mrows = append(mrows, mrow)
}
- mrow := &row{cols: make([]interface{}, len(s.colName))}
- for seli, name := range s.colName {
- mrow.cols[seli] = trow.cols[colIdx[name]]
+
+ var colType []string
+ for _, column := range s.colName {
+ colType = append(colType, t.coltype[t.columnIndex(column)])
+ }
+
+ t.mu.Unlock()
+
+ setMRows = append(setMRows, mrows)
+ setColumns = append(setColumns, s.colName)
+ setColType = append(setColType, colType)
+
+ if s.next == nil {
+ break
}
- mrows = append(mrows, mrow)
+ s = s.next
}
cursor := &rowsCursor{
- pos: -1,
- rows: mrows,
- cols: s.colName,
- errPos: -1,
+ posRow: -1,
+ rows: setMRows,
+ cols: setColumns,
+ colType: setColType,
+ errPos: -1,
}
return cursor, nil
}
@@ -767,10 +892,12 @@ func (tx *fakeTx) Rollback() error {
}
type rowsCursor struct {
- cols []string
- pos int
- rows []*row
- closed bool
+ cols [][]string
+ colType [][]string
+ posSet int
+ posRow int
+ rows [][]*row
+ closed bool
// errPos and err are for making Next return early with error.
errPos int
@@ -793,7 +920,11 @@ func (rc *rowsCursor) Close() error {
}
func (rc *rowsCursor) Columns() []string {
- return rc.cols
+ return rc.cols[rc.posSet]
+}
+
+func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
+ return colTypeToReflectType(rc.colType[rc.posSet][index])
}
var rowsCursorNextHook func(dest []driver.Value) error
@@ -806,14 +937,14 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed {
return errors.New("fakedb: cursor is closed")
}
- rc.pos++
- if rc.pos == rc.errPos {
+ rc.posRow++
+ if rc.posRow == rc.errPos {
return rc.err
}
- if rc.pos >= len(rc.rows) {
+ if rc.posRow >= len(rc.rows[rc.posSet]) {
return io.EOF // per interface spec
}
- for i, v := range rc.rows[rc.pos].cols {
+ for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
// TODO(bradfitz): convert to subset types? naah, I
// think the subset types should only be input to
// driver, but the sql package should be able to handle
@@ -838,6 +969,19 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
return nil
}
+func (rc *rowsCursor) HasNextResultSet() bool {
+ return rc.posSet < len(rc.rows)-1
+}
+
+func (rc *rowsCursor) NextResultSet() error {
+ if rc.HasNextResultSet() {
+ rc.posSet++
+ rc.posRow = -1
+ return nil
+ }
+ return io.EOF // Per interface spec.
+}
+
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
@@ -889,3 +1033,29 @@ func converterForType(typ string) driver.ValueConverter {
}
panic("invalid fakedb column type of " + typ)
}
+
+func colTypeToReflectType(typ string) reflect.Type {
+ switch typ {
+ case "bool":
+ return reflect.TypeOf(false)
+ case "nullbool":
+ return reflect.TypeOf(NullBool{})
+ case "int32":
+ return reflect.TypeOf(int32(0))
+ case "string":
+ return reflect.TypeOf("")
+ case "nullstring":
+ return reflect.TypeOf(NullString{})
+ case "int64":
+ return reflect.TypeOf(int64(0))
+ case "nullint64":
+ return reflect.TypeOf(NullInt64{})
+ case "float64":
+ return reflect.TypeOf(float64(0))
+ case "nullfloat64":
+ return reflect.TypeOf(NullFloat64{})
+ case "datetime":
+ return reflect.TypeOf(time.Time{})
+ }
+ panic("invalid fakedb column type of " + typ)
+}
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go
index d8e7cb77af..c016681fca 100644
--- a/libgo/go/database/sql/sql.go
+++ b/libgo/go/database/sql/sql.go
@@ -8,15 +8,20 @@
// The sql package must be used in conjunction with a database driver.
// See https://golang.org/s/sqldrivers for a list of drivers.
//
-// For more usage examples, see the wiki page at
+// Drivers that do not support context cancelation will not return until
+// after the query is completed.
+//
+// For usage examples, see the wiki page at
// https://golang.org/s/sqlwiki.
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
"io"
+ "reflect"
"runtime"
"sort"
"sync"
@@ -66,6 +71,75 @@ func Drivers() []string {
return list
}
+// A NamedArg is a named argument. NamedArg values may be used as
+// arguments to Query or Exec and bind to the corresponding named
+// parameter in the SQL statement.
+//
+// For a more concise way to create NamedArg values, see
+// the Named function.
+type NamedArg struct {
+ _Named_Fields_Required struct{}
+
+ // Name is the name of the parameter placeholder.
+ //
+ // If empty, the ordinal position in the argument list will be
+ // used.
+ //
+ // Name must omit any symbol prefix.
+ Name string
+
+ // Value is the value of the parameter.
+ // It may be assigned the same value types as the query
+ // arguments.
+ Value interface{}
+}
+
+// Named provides a more concise way to create NamedArg values.
+//
+// Example usage:
+//
+// db.ExecContext(ctx, `
+// delete from Invoice
+// where
+// TimeCreated < @end
+// and TimeCreated >= @start;`,
+// sql.Named("start", startTime),
+// sql.Named("end", endTime),
+// )
+func Named(name string, value interface{}) NamedArg {
+ // This method exists because the go1compat promise
+ // doesn't guarantee that structs don't grow more fields,
+ // so unkeyed struct literals are a vet error. Thus, we don't
+ // want to allow sql.NamedArg{name, value}.
+ return NamedArg{Name: name, Value: value}
+}
+
+// IsolationLevel is the transaction isolation level used in TxOptions.
+type IsolationLevel int
+
+// Various isolation levels that drivers may support in BeginTx.
+// If a driver does not support a given isolation level an error may be returned.
+//
+// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels.
+const (
+ LevelDefault IsolationLevel = iota
+ LevelReadUncommitted
+ LevelReadCommitted
+ LevelWriteCommitted
+ LevelRepeatableRead
+ LevelSnapshot
+ LevelSerializable
+ LevelLinearizable
+)
+
+// TxOptions holds the transaction options to be used in DB.BeginTx.
+type TxOptions struct {
+ // Isolation is the transaction isolation level.
+ // If zero, the driver or database's default level is used.
+ Isolation IsolationLevel
+ ReadOnly bool
+}
+
// RawBytes is a byte slice that holds a reference to memory owned by
// the database itself. After a Scan into a RawBytes, the slice is only
// valid until the next call to Next, Scan, or Close.
@@ -199,7 +273,7 @@ type Scanner interface {
// time.Time
// nil - for NULL values
//
- // An error should be returned if the value can not be stored
+ // An error should be returned if the value cannot be stored
// without loss of information.
Scan(src interface{}) error
}
@@ -231,8 +305,9 @@ type DB struct {
mu sync.Mutex // protects following fields
freeConn []*driverConn
- connRequests []chan connRequest
- numOpen int // number of opened and pending open connections
+ connRequests map[uint64]chan connRequest
+ nextRequest uint64 // Next key to use in connRequests.
+ numOpen int // number of opened and pending open connections
// Used to signal the need for new connections
// a goroutine running connectionOpener() reads on this chan and
// maybeOpenNewConnections sends on the chan (one send per needed connection)
@@ -272,7 +347,7 @@ type driverConn struct {
ci driver.Conn
closed bool
finalClosed bool // ci.Close has been called
- openStmt map[driver.Stmt]bool
+ openStmt map[*driverStmt]bool
// guarded by db.mu
inUse bool
@@ -284,10 +359,10 @@ func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err)
}
-func (dc *driverConn) removeOpenStmt(si driver.Stmt) {
+func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock()
defer dc.Unlock()
- delete(dc.openStmt, si)
+ delete(dc.openStmt, ds)
}
func (dc *driverConn) expired(timeout time.Duration) bool {
@@ -297,28 +372,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
- si, err := dc.ci.Prepare(query)
- if err == nil {
- // Track each driverConn's open statements, so we can close them
- // before closing the conn.
- //
- // TODO(bradfitz): let drivers opt out of caring about
- // stmt closes if the conn is about to close anyway? For now
- // do the safe thing, in case stmts need to be closed.
- //
- // TODO(bradfitz): after Go 1.2, closing driver.Stmts
- // should be moved to driverStmt, using unique
- // *driverStmts everywhere (including from
- // *Stmt.connStmt, instead of returning a
- // driver.Stmt), using driverStmt as a pointer
- // everywhere, and making it a finalCloser.
- if dc.openStmt == nil {
- dc.openStmt = make(map[driver.Stmt]bool)
- }
- dc.openStmt[si] = true
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
+ si, err := ctxDriverPrepare(ctx, dc.ci, query)
+ if err != nil {
+ return nil, err
}
- return si, err
+
+ // Track each driverConn's open statements, so we can close them
+ // before closing the conn.
+ //
+ // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
+ if dc.openStmt == nil {
+ dc.openStmt = make(map[*driverStmt]bool)
+ }
+ ds := &driverStmt{Locker: dc, si: si}
+ dc.openStmt[ds] = true
+
+ return ds, nil
}
// the dc.db's Mutex is held.
@@ -350,17 +420,26 @@ func (dc *driverConn) Close() error {
}
func (dc *driverConn) finalClose() error {
- dc.Lock()
+ var err error
- for si := range dc.openStmt {
- si.Close()
+ // Each *driverStmt has a lock to the dc. Copy the list out of the dc
+ // before calling close on each stmt.
+ var openStmt []*driverStmt
+ withLock(dc, func() {
+ openStmt = make([]*driverStmt, 0, len(dc.openStmt))
+ for ds := range dc.openStmt {
+ openStmt = append(openStmt, ds)
+ }
+ dc.openStmt = nil
+ })
+ for _, ds := range openStmt {
+ ds.Close()
}
- dc.openStmt = nil
-
- err := dc.ci.Close()
- dc.ci = nil
- dc.finalClosed = true
- dc.Unlock()
+ withLock(dc, func() {
+ dc.finalClosed = true
+ err = dc.ci.Close()
+ dc.ci = nil
+ })
dc.db.mu.Lock()
dc.db.numOpen--
@@ -377,12 +456,21 @@ func (dc *driverConn) finalClose() error {
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
+ closed bool
+ closeErr error // return value of previous Close call
}
+// Close ensures dirver.Stmt is only closed once any always returns the same
+// result.
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
- return ds.si.Close()
+ if ds.closed {
+ return ds.closeErr
+ }
+ ds.closed = true
+ ds.closeErr = ds.si.Close()
+ return ds.closeErr
}
// depSet is a finalCloser's outstanding dependencies
@@ -485,27 +573,46 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
db := &DB{
- driver: driveri,
- dsn: dataSourceName,
- openerCh: make(chan struct{}, connectionRequestQueueSize),
- lastPut: make(map[*driverConn]string),
+ driver: driveri,
+ dsn: dataSourceName,
+ openerCh: make(chan struct{}, connectionRequestQueueSize),
+ lastPut: make(map[*driverConn]string),
+ connRequests: make(map[uint64]chan connRequest),
}
go db.connectionOpener()
return db, nil
}
-// Ping verifies a connection to the database is still alive,
+// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
-func (db *DB) Ping() error {
- // TODO(bradfitz): give drivers an optional hook to implement
- // this in a more efficient or more reliable way, if they
- // have one.
- dc, err := db.conn(cachedOrNewConn)
+func (db *DB) PingContext(ctx context.Context) error {
+ var dc *driverConn
+ var err error
+
+ for i := 0; i < maxBadConnRetries; i++ {
+ dc, err = db.conn(ctx, cachedOrNewConn)
+ if err != driver.ErrBadConn {
+ break
+ }
+ }
+ if err == driver.ErrBadConn {
+ dc, err = db.conn(ctx, alwaysNewConn)
+ }
if err != nil {
return err
}
- db.putConn(dc, nil)
- return nil
+
+ if pinger, ok := dc.ci.(driver.Pinger); ok {
+ err = pinger.Ping(ctx)
+ }
+ db.putConn(dc, err)
+ return err
+}
+
+// Ping verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+func (db *DB) Ping() error {
+ return db.PingContext(context.Background())
}
// Close closes the database, releasing any open resources.
@@ -718,6 +825,9 @@ func (db *DB) maybeOpenNewConnections() {
for numRequests > 0 {
db.numOpen++ // optimistically
numRequests--
+ if db.closed {
+ return
+ }
db.openerCh <- struct{}{}
}
}
@@ -773,13 +883,28 @@ type connRequest struct {
var errDBClosed = errors.New("sql: database is closed")
+// nextRequestKeyLocked returns the next connection request key.
+// It is assumed that nextRequest will not overflow.
+func (db *DB) nextRequestKeyLocked() uint64 {
+ next := db.nextRequest
+ db.nextRequest++
+ return next
+}
+
// conn returns a newly-opened or cached *driverConn.
-func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
+func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
return nil, errDBClosed
}
+ // Check if the context is expired.
+ select {
+ default:
+ case <-ctx.Done():
+ db.mu.Unlock()
+ return nil, ctx.Err()
+ }
lifetime := db.maxLifetime
// Prefer a free connection, if possible.
@@ -797,23 +922,42 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
return conn, nil
}
- // Out of free connections or we were asked not to use one. If we're not
+ // Out of free connections or we were asked not to use one. If we're not
// allowed to open any more connections, make a request and wait.
if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
// Make the connRequest channel. It's buffered so that the
// connectionOpener doesn't block while waiting for the req to be read.
req := make(chan connRequest, 1)
- db.connRequests = append(db.connRequests, req)
+ reqKey := db.nextRequestKeyLocked()
+ db.connRequests[reqKey] = req
db.mu.Unlock()
- ret, ok := <-req
- if !ok {
- return nil, errDBClosed
- }
- if ret.err == nil && ret.conn.expired(lifetime) {
- ret.conn.Close()
- return nil, driver.ErrBadConn
+
+ // Timeout the connection request with the context.
+ select {
+ case <-ctx.Done():
+ // Remove the connection request and ensure no value has been sent
+ // on it after removing.
+ db.mu.Lock()
+ delete(db.connRequests, reqKey)
+ db.mu.Unlock()
+ select {
+ default:
+ case ret, ok := <-req:
+ if ok {
+ db.putConn(ret.conn, ret.err)
+ }
+ }
+ return nil, ctx.Err()
+ case ret, ok := <-req:
+ if !ok {
+ return nil, errDBClosed
+ }
+ if ret.err == nil && ret.conn.expired(lifetime) {
+ ret.conn.Close()
+ return nil, driver.ErrBadConn
+ }
+ return ret.conn, ret.err
}
- return ret.conn, ret.err
}
db.numOpen++ // optimistically
@@ -838,29 +982,25 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
return dc, nil
}
-var (
- errConnClosed = errors.New("database/sql: internal sentinel error: conn is closed")
- errConnBusy = errors.New("database/sql: internal sentinel error: conn is busy")
-)
-
// putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn)
-// noteUnusedDriverStatement notes that si is no longer used and should
+// noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
-func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
+func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock()
defer db.mu.Unlock()
if c.inUse {
c.onPut = append(c.onPut, func() {
- si.Close()
+ ds.Close()
})
} else {
c.Lock()
- defer c.Unlock()
- if !c.finalClosed {
- si.Close()
+ fc := c.finalClosed
+ c.Unlock()
+ if !fc {
+ ds.Close()
}
}
}
@@ -920,16 +1060,19 @@ func (db *DB) putConn(dc *driverConn, err error) {
// If a connRequest was fulfilled or the *driverConn was placed in the
// freeConn list, then true is returned, otherwise false is returned.
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
+ if db.closed {
+ return false
+ }
if db.maxOpen > 0 && db.numOpen > db.maxOpen {
return false
}
if c := len(db.connRequests); c > 0 {
- req := db.connRequests[0]
- // This copy is O(n) but in practice faster than a linked list.
- // TODO: consider compacting it down less often and
- // moving the base instead?
- copy(db.connRequests, db.connRequests[1:])
- db.connRequests = db.connRequests[:c-1]
+ var req chan connRequest
+ var reqKey uint64
+ for reqKey, req = range db.connRequests {
+ break
+ }
+ delete(db.connRequests, reqKey) // Remove from pending requests.
if err == nil {
dc.inUse = true
}
@@ -951,40 +1094,53 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
// connection to be opened.
const maxBadConnRetries = 2
-// Prepare creates a prepared statement for later queries or executions.
+// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
-func (db *DB) Prepare(query string) (*Stmt, error) {
+//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
var err error
for i := 0; i < maxBadConnRetries; i++ {
- stmt, err = db.prepare(query, cachedOrNewConn)
+ stmt, err = db.prepare(ctx, query, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.prepare(query, alwaysNewConn)
+ return db.prepare(ctx, query, alwaysNewConn)
}
return stmt, err
}
-func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
+// Prepare creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+func (db *DB) Prepare(query string) (*Stmt, error) {
+ return db.PrepareContext(context.Background(), query)
+}
+
+func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
// TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
- dc, err := db.conn(strategy)
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- dc.Lock()
- si, err := dc.prepareLocked(query)
- dc.Unlock()
+ var ds *driverStmt
+ withLock(dc, func() {
+ ds, err = dc.prepareLocked(ctx, query)
+ })
if err != nil {
db.putConn(dc, err)
return nil, err
@@ -992,7 +1148,7 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
stmt := &Stmt{
db: db,
query: query,
- css: []connStmt{{dc, si}},
+ css: []connStmt{{dc, ds}},
lastNumClosed: atomic.LoadUint64(&db.numClosed),
}
db.addDep(stmt, stmt)
@@ -1000,25 +1156,31 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
return stmt, nil
}
-// Exec executes a query without returning any rows.
+// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
-func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
var res Result
var err error
for i := 0; i < maxBadConnRetries; i++ {
- res, err = db.exec(query, args, cachedOrNewConn)
+ res, err = db.exec(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.exec(query, args, alwaysNewConn)
+ return db.exec(ctx, query, args, alwaysNewConn)
}
return res, err
}
-func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
- dc, err := db.conn(strategy)
+// Exec executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+ return db.ExecContext(context.Background(), query, args...)
+}
+
+func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
@@ -1027,13 +1189,15 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
}()
if execer, ok := dc.ci.(driver.Execer); ok {
- dargs, err := driverArgs(nil, args)
+ var dargs []driver.NamedValue
+ dargs, err = driverArgs(nil, args)
if err != nil {
return nil, err
}
- dc.Lock()
- resi, err := execer.Exec(query, dargs)
- dc.Unlock()
+ var resi driver.Result
+ withLock(dc, func() {
+ resi, err = ctxDriverExec(ctx, execer, query, dargs)
+ })
if err != driver.ErrSkip {
if err != nil {
return nil, err
@@ -1042,54 +1206,63 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
}
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
- return resultFromStatement(driverStmt{dc, si}, args...)
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
+ return resultFromStatement(ctx, ds, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
+// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
-func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
var rows *Rows
var err error
for i := 0; i < maxBadConnRetries; i++ {
- rows, err = db.query(query, args, cachedOrNewConn)
+ rows, err = db.query(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.query(query, args, alwaysNewConn)
+ return db.query(ctx, query, args, alwaysNewConn)
}
return rows, err
}
-func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
- ci, err := db.conn(strategy)
+// Query executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+ return db.QueryContext(context.Background(), query, args...)
+}
+
+func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
+ ci, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- return db.queryConn(ci, ci.releaseConn, query, args)
+ return db.queryConn(ctx, ci, ci.releaseConn, query, args)
}
// queryConn executes a query on the given connection.
// The connection gets released by the releaseConn function.
-func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
releaseConn(err)
return nil, err
}
- dc.Lock()
- rowsi, err := queryer.Query(query, dargs)
- dc.Unlock()
+ var rowsi driver.Rows
+ withLock(dc, func() {
+ rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
+ })
if err != driver.ErrSkip {
if err != nil {
releaseConn(err)
@@ -1102,24 +1275,25 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
releaseConn: releaseConn,
rowsi: rowsi,
}
+ rows.initContextClose(ctx)
return rows, nil
}
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ var err error
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
releaseConn(err)
return nil, err
}
- ds := driverStmt{dc, si}
- rowsi, err := rowsiFromStatement(ds, args...)
+ ds := &driverStmt{Locker: dc, si: si}
+ rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
- dc.Lock()
- si.Close()
- dc.Unlock()
+ ds.Close()
releaseConn(err)
return nil, err
}
@@ -1130,53 +1304,84 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
- closeStmt: si,
+ closeStmt: ds,
}
+ rows.initContextClose(ctx)
return rows, nil
}
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := db.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
+}
+
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
- rows, err := db.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return db.QueryRowContext(context.Background(), query, args...)
}
-// Begin starts a transaction. The isolation level is dependent on
-// the driver.
-func (db *DB) Begin() (*Tx, error) {
+// BeginTx starts a transaction.
+//
+// The provided context is used until the transaction is committed or rolled back.
+// If the context is canceled, the sql package will roll back
+// the transaction. Tx.Commit will return an error if the context provided to
+// BeginTx is canceled.
+//
+// The provided TxOptions is optional and may be nil if defaults should be used.
+// If a non-default isolation level is used that the driver doesn't support,
+// an error will be returned.
+func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
var tx *Tx
var err error
for i := 0; i < maxBadConnRetries; i++ {
- tx, err = db.begin(cachedOrNewConn)
+ tx, err = db.begin(ctx, opts, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.begin(alwaysNewConn)
+ return db.begin(ctx, opts, alwaysNewConn)
}
return tx, err
}
-func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
- dc, err := db.conn(strategy)
+// Begin starts a transaction. The default isolation level is dependent on
+// the driver.
+func (db *DB) Begin() (*Tx, error) {
+ return db.BeginTx(context.Background(), nil)
+}
+
+func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- dc.Lock()
- txi, err := dc.ci.Begin()
- dc.Unlock()
+ var txi driver.Tx
+ withLock(dc, func() {
+ txi, err = ctxDriverBegin(ctx, opts, dc.ci)
+ })
if err != nil {
db.putConn(dc, err)
return nil, err
}
- return &Tx{
- db: db,
- dc: dc,
- txi: txi,
- }, nil
+
+ // Schedule the transaction to rollback when the context is cancelled.
+ // The cancel function in Tx will be called after done is set to true.
+ ctx, cancel := context.WithCancel(ctx)
+ tx = &Tx{
+ db: db,
+ dc: dc,
+ txi: txi,
+ cancel: cancel,
+ ctx: ctx,
+ }
+ go tx.awaitDone()
+ return tx, nil
}
// Driver returns the database's underlying driver.
@@ -1197,60 +1402,112 @@ func (db *DB) Driver() driver.Driver {
type Tx struct {
db *DB
+ // closemu prevents the transaction from closing while there
+ // is an active query. It is held for read during queries
+ // and exclusively during close.
+ closemu sync.RWMutex
+
// dc is owned exclusively until Commit or Rollback, at which point
// it's returned with putConn.
dc *driverConn
txi driver.Tx
- // done transitions from false to true exactly once, on Commit
+ // done transitions from 0 to 1 exactly once, on Commit
// or Rollback. once done, all operations fail with
// ErrTxDone.
- done bool
+ // Use atomic operations on value when checking value.
+ done int32
- // All Stmts prepared for this transaction. These will be closed after the
+ // All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
stmts struct {
sync.Mutex
v []*Stmt
}
+
+ // cancel is called after done transitions from false to true.
+ cancel func()
+
+ // ctx lives for the life of the transaction.
+ ctx context.Context
+}
+
+// awaitDone blocks until the context in Tx is canceled and rolls back
+// the transaction if it's not already done.
+func (tx *Tx) awaitDone() {
+ // Wait for either the transaction to be committed or rolled
+ // back, or for the associated context to be closed.
+ <-tx.ctx.Done()
+
+ // Discard and close the connection used to ensure the
+ // transaction is closed and the resources are released. This
+ // rollback does nothing if the transaction has already been
+ // committed or rolled back.
+ tx.rollback(true)
+}
+
+func (tx *Tx) isDone() bool {
+ return atomic.LoadInt32(&tx.done) != 0
}
+// ErrTxDone is returned by any operation that is performed on a transaction
+// that has already been committed or rolled back.
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
+// close returns the connection to the pool and
+// must only be called by Tx.rollback or Tx.Commit.
func (tx *Tx) close(err error) {
- if tx.done {
- panic("double close") // internal error
- }
- tx.done = true
+ tx.closemu.Lock()
+ defer tx.closemu.Unlock()
+
tx.db.putConn(tx.dc, err)
+ tx.cancel()
tx.dc = nil
tx.txi = nil
}
-func (tx *Tx) grabConn() (*driverConn, error) {
- if tx.done {
+// hookTxGrabConn specifies an optional hook to be called on
+// a successful call to (*Tx).grabConn. For tests.
+var hookTxGrabConn func()
+
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ if tx.isDone() {
return nil, ErrTxDone
}
+ if hookTxGrabConn != nil { // test hook
+ hookTxGrabConn()
+ }
return tx.dc, nil
}
// Closes all Stmts prepared for this transaction.
func (tx *Tx) closePrepared() {
tx.stmts.Lock()
+ defer tx.stmts.Unlock()
for _, stmt := range tx.stmts.v {
stmt.Close()
}
- tx.stmts.Unlock()
}
// Commit commits the transaction.
func (tx *Tx) Commit() error {
- if tx.done {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
return ErrTxDone
}
- tx.dc.Lock()
- err := tx.txi.Commit()
- tx.dc.Unlock()
+ select {
+ default:
+ case <-tx.ctx.Done():
+ return tx.ctx.Err()
+ }
+ var err error
+ withLock(tx.dc, func() {
+ err = tx.txi.Commit()
+ })
if err != driver.ErrBadConn {
tx.closePrepared()
}
@@ -1258,49 +1515,67 @@ func (tx *Tx) Commit() error {
return err
}
-// Rollback aborts the transaction.
-func (tx *Tx) Rollback() error {
- if tx.done {
+// rollback aborts the transaction and optionally forces the pool to discard
+// the connection.
+func (tx *Tx) rollback(discardConn bool) error {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
return ErrTxDone
}
- tx.dc.Lock()
- err := tx.txi.Rollback()
- tx.dc.Unlock()
+ var err error
+ withLock(tx.dc, func() {
+ err = tx.txi.Rollback()
+ })
if err != driver.ErrBadConn {
tx.closePrepared()
}
+ if discardConn {
+ err = driver.ErrBadConn
+ }
tx.close(err)
return err
}
+// Rollback aborts the transaction.
+func (tx *Tx) Rollback() error {
+ return tx.rollback(false)
+}
+
// Prepare creates a prepared statement for use within a transaction.
//
-// The returned statement operates within the transaction and can no longer
-// be used once the transaction has been committed or rolled back.
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
-func (tx *Tx) Prepare(query string) (*Stmt, error) {
+//
+// The provided context will be used for the preparation of the context, not
+// for the execution of the returned statement. The returned statement
+// will run in the transaction context.
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+ tx.closemu.RLock()
+ defer tx.closemu.RUnlock()
+
// TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if
// necessary. Or, better: keep a map in DB of query string to
// Stmts, and have Stmt.Execute do the right thing and
// re-prepare if the Conn in use doesn't have that prepared
- // statement. But we'll want to avoid caching the statement
+ // statement. But we'll want to avoid caching the statement
// in the case where we only call conn.Prepare implicitly
// (such as in db.Exec or tx.Exec), but the caller package
// can't be holding a reference to the returned statement.
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
return nil, err
}
@@ -1308,7 +1583,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
stmt := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
@@ -1320,7 +1595,17 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
return stmt, nil
}
-// Stmt returns a transaction-specific prepared statement from
+// Prepare creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+func (tx *Tx) Prepare(query string) (*Stmt, error) {
+ return tx.PrepareContext(context.Background(), query)
+}
+
+// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
@@ -1328,30 +1613,34 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// ...
// tx, err := db.Begin()
// ...
-// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
-// The returned statement operates within the transaction and can no longer
-// be used once the transaction has been committed or rolled back.
-func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
+ tx.closemu.RLock()
+ defer tx.closemu.RUnlock()
+
// TODO(bradfitz): optimize this. Currently this re-prepares
- // each time. This is fine for now to illustrate the API but
+ // each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
// per-Conn. See also the big comment in Tx.Prepare.
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
- dc.Lock()
- si, err := dc.ci.Prepare(stmt.query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
+ })
txs := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
@@ -1364,10 +1653,29 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return txs
}
-// Exec executes a query that doesn't return rows.
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+// ...
+// tx, err := db.Begin()
+// ...
+// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+ return tx.StmtContext(context.Background(), stmt)
+}
+
+// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
-func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
- dc, err := tx.grabConn()
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+ tx.closemu.RLock()
+ defer tx.closemu.RUnlock()
+
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
@@ -1377,9 +1685,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
if err != nil {
return nil, err
}
- dc.Lock()
- resi, err := execer.Exec(query, dargs)
- dc.Unlock()
+ var resi driver.Result
+ withLock(dc, func() {
+ resi, err = ctxDriverExec(ctx, execer, query, dargs)
+ })
if err == nil {
return driverResult{dc, resi}, nil
}
@@ -1388,39 +1697,62 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
}
}
- dc.Lock()
- si, err := dc.ci.Prepare(query)
- dc.Unlock()
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
- return resultFromStatement(driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, ds, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
-func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
- dc, err := tx.grabConn()
+// Exec executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
+ return tx.ExecContext(context.Background(), query, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+ tx.closemu.RLock()
+ defer tx.closemu.RUnlock()
+
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
releaseConn := func(error) {}
- return tx.db.queryConn(dc, releaseConn, query, args)
+ return tx.db.queryConn(ctx, dc, releaseConn, query, args)
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
+ return tx.QueryContext(context.Background(), query, args...)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := tx.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
- rows, err := tx.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return tx.QueryRowContext(context.Background(), query, args...)
}
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
dc *driverConn
- si driver.Stmt
+ ds *driverStmt
}
// Stmt is a prepared statement.
@@ -1435,15 +1767,15 @@ type Stmt struct {
// If in a transaction, else both nil:
tx *Tx
- txsi *driverStmt
+ txds *driverStmt
mu sync.Mutex // protects the rest of the fields
closed bool
// css is a list of underlying driver statement interfaces
- // that are valid on particular connections. This is only
+ // that are valid on particular connections. This is only
// used if tx == nil and one is found that has idle
- // connections. If tx != nil, txsi is always used.
+ // connections. If tx != nil, txsi is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
@@ -1451,15 +1783,15 @@ type Stmt struct {
lastNumClosed uint64
}
-// Exec executes a prepared statement with the given arguments and
+// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
-func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ _, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1467,7 +1799,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, err
}
- res, err = resultFromStatement(driverStmt{dc, si}, args...)
+ res, err = resultFromStatement(ctx, ds, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
@@ -1476,13 +1808,19 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, driver.ErrBadConn
}
-func driverNumInput(ds driverStmt) int {
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+ return s.ExecContext(context.Background(), args...)
+}
+
+func driverNumInput(ds *driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
-func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
@@ -1492,14 +1830,15 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
ds.Lock()
defer ds.Unlock()
- resi, err := ds.si.Exec(dargs)
+
+ resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
@@ -1535,7 +1874,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
@@ -1550,19 +1889,18 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
// transaction was created on.
if s.tx != nil {
s.mu.Unlock()
- ci, err = s.tx.grabConn() // blocks, waiting for the connection.
+ ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
releaseConn = func(error) {}
- return ci, releaseConn, s.txsi.si, nil
+ return ci, releaseConn, s.txds, nil
}
s.removeClosedStmtLocked()
s.mu.Unlock()
- // TODO(bradfitz): or always wait for one? make configurable later?
- dc, err := s.db.conn(cachedOrNewConn)
+ dc, err := s.db.conn(ctx, cachedOrNewConn)
if err != nil {
return nil, nil, nil, err
}
@@ -1571,36 +1909,36 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
- return dc, dc.releaseConn, v.si, nil
+ return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// No luck; we need to prepare the statement on this connection
- dc.Lock()
- si, err = dc.prepareLocked(s.query)
- dc.Unlock()
+ withLock(dc, func() {
+ ds, err = dc.prepareLocked(ctx, s.query)
+ })
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
}
s.mu.Lock()
- cs := connStmt{dc, si}
+ cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
- return dc, dc.releaseConn, si, nil
+ return dc, dc.releaseConn, ds, nil
}
-// Query executes a prepared query statement with the given arguments
+// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
-func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1608,7 +1946,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return nil, err
}
- rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+ rowsi, err = rowsiFromStatement(ctx, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
@@ -1617,6 +1955,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
rowsi: rowsi,
// releaseConn set below
}
+ rows.initContextClose(ctx)
s.db.addDep(s, rows)
rows.releaseConn = func(err error) {
releaseConn(err)
@@ -1633,10 +1972,17 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return nil, driver.ErrBadConn
}
-func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
- ds.Lock()
- want := ds.si.NumInput()
- ds.Unlock()
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+ return s.QueryContext(context.Background(), args...)
+}
+
+func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
+ var want int
+ withLock(ds, func() {
+ want = ds.si.NumInput()
+ })
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
@@ -1645,21 +1991,22 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
ds.Lock()
- rowsi, err := ds.si.Query(dargs)
- ds.Unlock()
+ defer ds.Unlock()
+
+ rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
return rowsi, nil
}
-// QueryRow executes a prepared query statement with the given arguments.
+// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
@@ -1669,15 +2016,30 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
// Example usage:
//
// var name string
-// err := nameByUseridStmt.QueryRow(id).Scan(&name)
-func (s *Stmt) QueryRow(args ...interface{}) *Row {
- rows, err := s.Query(args...)
+// err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name)
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
+ rows, err := s.QueryContext(ctx, args...)
if err != nil {
return &Row{err: err}
}
return &Row{rows: rows}
}
+// QueryRow executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// Example usage:
+//
+// var name string
+// err := nameByUseridStmt.QueryRow(id).Scan(&name)
+func (s *Stmt) QueryRow(args ...interface{}) *Row {
+ return s.QueryRowContext(context.Background(), args...)
+}
+
// Close closes the statement.
func (s *Stmt) Close() error {
s.closemu.Lock()
@@ -1692,13 +2054,11 @@ func (s *Stmt) Close() error {
return nil
}
s.closed = true
+ s.mu.Unlock()
if s.tx != nil {
- err := s.txsi.Close()
- s.mu.Unlock()
- return err
+ return s.txds.Close()
}
- s.mu.Unlock()
return s.db.removeDep(s, s)
}
@@ -1708,8 +2068,8 @@ func (s *Stmt) finalClose() error {
defer s.mu.Unlock()
if s.css != nil {
for _, v := range s.css {
- s.db.noteUnusedDriverStatement(v.dc, v.si)
- v.dc.removeOpenStmt(v.si)
+ s.db.noteUnusedDriverStatement(v.dc, v.ds)
+ v.dc.removeOpenStmt(v.ds)
}
s.css = nil
}
@@ -1734,29 +2094,110 @@ type Rows struct {
dc *driverConn // owned; must call releaseConn when closed to release
releaseConn func(error)
rowsi driver.Rows
+ cancel func() // called when Rows is closed, may be nil.
+ closeStmt *driverStmt // if non-nil, statement to Close on close
+
+ // closemu prevents Rows from closing while there
+ // is an active streaming result. It is held for read during non-close operations
+ // and exclusively during close.
+ //
+ // closemu guards lasterr and closed.
+ closemu sync.RWMutex
+ closed bool
+ lasterr error // non-nil only if closed is true
- closed bool
- lastcols []driver.Value
- lasterr error // non-nil only if closed is true
- closeStmt driver.Stmt // if non-nil, statement to Close on close
+ // lastcols is only used in Scan, Next, and NextResultSet which are expected
+ // not not be called concurrently.
+ lastcols []driver.Value
}
-// Next prepares the next result row for reading with the Scan method. It
+func (rs *Rows) initContextClose(ctx context.Context) {
+ ctx, rs.cancel = context.WithCancel(ctx)
+ go rs.awaitDone(ctx)
+}
+
+// awaitDone blocks until the rows are closed or the context canceled.
+func (rs *Rows) awaitDone(ctx context.Context) {
+ <-ctx.Done()
+ rs.close(ctx.Err())
+}
+
+// Next prepares the next result row for reading with the Scan method. It
// returns true on success, or false if there is no next result row or an error
-// happened while preparing it. Err should be consulted to distinguish between
+// happened while preparing it. Err should be consulted to distinguish between
// the two cases.
//
// Every call to Scan, even the first one, must be preceded by a call to Next.
func (rs *Rows) Next() bool {
+ var doClose, ok bool
+ withLock(rs.closemu.RLocker(), func() {
+ doClose, ok = rs.nextLocked()
+ })
+ if doClose {
+ rs.Close()
+ }
+ return ok
+}
+
+func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
- return false
+ return false, false
}
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
- rs.Close()
+ // Close the connection if there is a driver error.
+ if rs.lasterr != io.EOF {
+ return true, false
+ }
+ nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+ if !ok {
+ return true, false
+ }
+ // The driver is at the end of the current result set.
+ // Test to see if there is another result set after the current one.
+ // Only close Rows if there is no further result sets to read.
+ if !nextResultSet.HasNextResultSet() {
+ doClose = true
+ }
+ return doClose, false
+ }
+ return false, true
+}
+
+// NextResultSet prepares the next result set for reading. It returns true if
+// there is further result sets, or false if there is no further result set
+// or if there is an error advancing to it. The Err method should be consulted
+// to distinguish between the two cases.
+//
+// After calling NextResultSet, the Next method should always be called before
+// scanning. If there are further result sets they may not have rows in the result
+// set.
+func (rs *Rows) NextResultSet() bool {
+ var doClose bool
+ defer func() {
+ if doClose {
+ rs.Close()
+ }
+ }()
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
+
+ if rs.closed {
+ return false
+ }
+
+ rs.lastcols = nil
+ nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+ if !ok {
+ doClose = true
+ return false
+ }
+ rs.lasterr = nextResultSet.NextResultSet()
+ if rs.lasterr != nil {
+ doClose = true
return false
}
return true
@@ -1765,6 +2206,8 @@ func (rs *Rows) Next() bool {
// Err returns the error, if any, that was encountered during iteration.
// Err may be called after an explicit or implicit Close.
func (rs *Rows) Err() error {
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
if rs.lasterr == io.EOF {
return nil
}
@@ -1775,6 +2218,8 @@ func (rs *Rows) Err() error {
// Columns returns an error if the rows are closed, or if the rows
// are from QueryRow and there was a deferred error.
func (rs *Rows) Columns() ([]string, error) {
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
if rs.closed {
return nil, errors.New("sql: Rows are closed")
}
@@ -1784,6 +2229,109 @@ func (rs *Rows) Columns() ([]string, error) {
return rs.rowsi.Columns(), nil
}
+// ColumnTypes returns column information such as column type, length,
+// and nullable. Some information may not be available from some drivers.
+func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
+ if rs.closed {
+ return nil, errors.New("sql: Rows are closed")
+ }
+ if rs.rowsi == nil {
+ return nil, errors.New("sql: no Rows available")
+ }
+ return rowsColumnInfoSetup(rs.rowsi), nil
+}
+
+// ColumnType contains the name and type of a column.
+type ColumnType struct {
+ name string
+
+ hasNullable bool
+ hasLength bool
+ hasPrecisionScale bool
+
+ nullable bool
+ length int64
+ databaseType string
+ precision int64
+ scale int64
+ scanType reflect.Type
+}
+
+// Name returns the name or alias of the column.
+func (ci *ColumnType) Name() string {
+ return ci.name
+}
+
+// Length returns the column type length for variable length column types such
+// as text and binary field types. If the type length is unbounded the value will
+// be math.MaxInt64 (any database limits will still apply).
+// If the column type is not variable length, such as an int, or if not supported
+// by the driver ok is false.
+func (ci *ColumnType) Length() (length int64, ok bool) {
+ return ci.length, ci.hasLength
+}
+
+// DecimalSize returns the scale and precision of a decimal type.
+// If not applicable or if not supported ok is false.
+func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
+ return ci.precision, ci.scale, ci.hasPrecisionScale
+}
+
+// ScanType returns a Go type suitable for scanning into using Rows.Scan.
+// If a driver does not support this property ScanType will return
+// the type of an empty interface.
+func (ci *ColumnType) ScanType() reflect.Type {
+ return ci.scanType
+}
+
+// Nullable returns whether the column may be null.
+// If a driver does not support this property ok will be false.
+func (ci *ColumnType) Nullable() (nullable, ok bool) {
+ return ci.nullable, ci.hasNullable
+}
+
+// DatabaseTypeName returns the database system name of the column type. If an empty
+// string is returned the driver type name is not supported.
+// Consult your driver documentation for a list of driver data types. Length specifiers
+// are not included.
+// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT".
+func (ci *ColumnType) DatabaseTypeName() string {
+ return ci.databaseType
+}
+
+func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
+ names := rowsi.Columns()
+
+ list := make([]*ColumnType, len(names))
+ for i := range list {
+ ci := &ColumnType{
+ name: names[i],
+ }
+ list[i] = ci
+
+ if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
+ ci.scanType = prop.ColumnTypeScanType(i)
+ } else {
+ ci.scanType = reflect.TypeOf(new(interface{})).Elem()
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
+ ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
+ ci.length, ci.hasLength = prop.ColumnTypeLength(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
+ ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
+ ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
+ }
+ }
+ return list
+}
+
// Scan copies the columns in the current row into the values pointed
// at by dest. The number of values in dest must be the same as the
// number of columns in Rows.
@@ -1836,9 +2384,13 @@ func (rs *Rows) Columns() ([]string, error) {
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
func (rs *Rows) Scan(dest ...interface{}) error {
+ rs.closemu.RLock()
if rs.closed {
+ rs.closemu.RUnlock()
return errors.New("sql: Rows are closed")
}
+ rs.closemu.RUnlock()
+
if rs.lastcols == nil {
return errors.New("sql: Scan called without calling Next")
}
@@ -1854,20 +2406,39 @@ func (rs *Rows) Scan(dest ...interface{}) error {
return nil
}
-var rowsCloseHook func(*Rows, *error)
+// rowsCloseHook returns a function so tests may install the
+// hook throug a test only mutex.
+var rowsCloseHook = func() func(*Rows, *error) { return nil }
-// Close closes the Rows, preventing further enumeration. If Next returns
-// false, the Rows are closed automatically and it will suffice to check the
+// Close closes the Rows, preventing further enumeration. If Next is called
+// and returns false and there are no further result sets,
+// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error {
+ return rs.close(nil)
+}
+
+func (rs *Rows) close(err error) error {
+ rs.closemu.Lock()
+ defer rs.closemu.Unlock()
+
if rs.closed {
return nil
}
rs.closed = true
- err := rs.rowsi.Close()
- if fn := rowsCloseHook; fn != nil {
+
+ if rs.lasterr == nil {
+ rs.lasterr = err
+ }
+
+ err = rs.rowsi.Close()
+ if fn := rowsCloseHook(); fn != nil {
fn(rs, &err)
}
+ if rs.cancel != nil {
+ rs.cancel()
+ }
+
if rs.closeStmt != nil {
rs.closeStmt.Close()
}
@@ -1898,8 +2469,8 @@ func (r *Row) Scan(dest ...interface{}) error {
// the Rows in our defer, when we return from this function.
// the contract with the driver.Next(...) interface is that it
// can return slices into read-only temporary memory that's
- // only valid until the next Scan/Close. But the TODO is that
- // for a lot of drivers, this copy will be unnecessary. We
+ // only valid until the next Scan/Close. But the TODO is that
+ // for a lot of drivers, this copy will be unnecessary. We
// should provide an optional interface for drivers to
// implement to say, "don't worry, the []bytes that I return
// from Next will not be modified again." (for instance, if
diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go
index 8ec70d99b0..450e5f1f8c 100644
--- a/libgo/go/database/sql/sql_test.go
+++ b/libgo/go/database/sql/sql_test.go
@@ -5,6 +5,7 @@
package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
@@ -13,6 +14,7 @@ import (
"runtime"
"strings"
"sync"
+ "sync/atomic"
"testing"
"time"
)
@@ -23,6 +25,17 @@ func init() {
c *driverConn
}
freedFrom := make(map[dbConn]string)
+ var mu sync.Mutex
+ getFreedFrom := func(c dbConn) string {
+ mu.Lock()
+ defer mu.Unlock()
+ return freedFrom[c]
+ }
+ setFreedFrom := func(c dbConn, s string) {
+ mu.Lock()
+ defer mu.Unlock()
+ freedFrom[c] = s
+ }
putConnHook = func(db *DB, c *driverConn) {
idx := -1
for i, v := range db.freeConn {
@@ -35,10 +48,10 @@ func init() {
// print before panic, as panic may get lost due to conflicting panic
// (all goroutines asleep) elsewhere, since we might not unlock
// the mutex in freeConn here.
- println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
+ println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack())
panic("double free of conn.")
}
- freedFrom[dbConn{db, c}] = stack()
+ setFreedFrom(dbConn{db, c}, stack())
}
}
@@ -140,11 +153,13 @@ func closeDB(t testing.TB, db *DB) {
if err != nil {
t.Fatalf("error closing DB: %v", err)
}
- db.mu.Lock()
- count := db.numOpen
- db.mu.Unlock()
- if count != 0 {
- t.Fatalf("%d connections still open after closing DB", db.numOpen)
+
+ var numOpen int
+ if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+ numOpen = db.numOpenConns()
+ return numOpen == 0
+ }) {
+ t.Fatalf("%d connections still open after closing DB", numOpen)
}
}
@@ -182,6 +197,12 @@ func (db *DB) numFreeConns() int {
return len(db.freeConn)
}
+func (db *DB) numOpenConns() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return db.numOpen
+}
+
// clearAllConns closes all connections in db.
func (db *DB) clearAllConns(t *testing.T) {
db.SetMaxIdleConns(0)
@@ -260,6 +281,338 @@ func TestQuery(t *testing.T) {
}
}
+// TestQueryContext tests canceling the context while scanning the rows.
+func TestQueryContext(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ index := 0
+ for rows.Next() {
+ if index == 2 {
+ cancel()
+ waitForRowsClose(t, rows, 5*time.Second)
+ }
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ if index == 2 {
+ break
+ }
+ t.Fatalf("Scan: %v", err)
+ }
+ if index == 2 && err == nil {
+ t.Fatal("expected an error on last scan")
+ }
+ got = append(got, r)
+ index++
+ }
+ select {
+ case <-ctx.Done():
+ if err := ctx.Err(); err != context.Canceled {
+ t.Fatalf("context err = %v; want context.Canceled")
+ }
+ default:
+ t.Fatalf("context err = nil; want context.Canceled")
+ }
+ want := []row{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ waitForRowsClose(t, rows, 5*time.Second)
+ waitForFree(t, db, 5*time.Second, 1)
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool {
+ deadline := time.Now().Add(waitFor)
+ for time.Now().Before(deadline) {
+ if fn() {
+ return true
+ }
+ time.Sleep(checkEvery)
+ }
+ return false
+}
+
+// waitForFree checks db.numFreeConns until either it equals want or
+// the maxWait time elapses.
+func waitForFree(t *testing.T, db *DB, maxWait time.Duration, want int) {
+ var numFree int
+ if !waitCondition(maxWait, 5*time.Millisecond, func() bool {
+ numFree = db.numFreeConns()
+ return numFree == want
+ }) {
+ t.Fatalf("free conns after hitting EOF = %d; want %d", numFree, want)
+ }
+}
+
+func waitForRowsClose(t *testing.T, rows *Rows, maxWait time.Duration) {
+ if !waitCondition(maxWait, 5*time.Millisecond, func() bool {
+ rows.closemu.RLock()
+ defer rows.closemu.RUnlock()
+ return rows.closed
+ }) {
+ t.Fatal("failed to close rows")
+ }
+}
+
+// TestQueryContextWait ensures that rows and all internal statements are closed when
+// a query context is closed during execution.
+func TestQueryContextWait(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ // TODO(kardianos): convert this from using a timeout to using an explicit
+ // cancel when the query signals that is is "executing" the query.
+ ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
+ defer cancel()
+
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ _, err := db.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|")
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
+ }
+
+ // Verify closed rows connection after error condition.
+ waitForFree(t, db, 5*time.Second, 1)
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ // TODO(kardianos): if the context timeouts before the db.QueryContext
+ // executes this check may fail. After adjusting how the context
+ // is canceled above revert this back to a Fatal error.
+ t.Logf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+// TestTxContextWait tests the transaction behavior when the tx context is canceled
+// during execution of the query.
+func TestTxContextWait(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ // Guard against the context being canceled before BeginTx completes.
+ if err == context.DeadlineExceeded {
+ t.Skip("tx context canceled prior to first use")
+ }
+ t.Fatal(err)
+ }
+
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ _, err = tx.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|")
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
+ }
+
+ waitForFree(t, db, 5*time.Second, 0)
+}
+
+func TestMultiResultSetQuery(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row1 struct {
+ age int
+ name string
+ }
+ type row2 struct {
+ name string
+ }
+ got1 := []row1{}
+ for rows.Next() {
+ var r row1
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got1 = append(got1, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want1 := []row1{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ {age: 3, name: "Chris"},
+ }
+ if !reflect.DeepEqual(got1, want1) {
+ t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
+ }
+
+ if !rows.NextResultSet() {
+ t.Errorf("expected another result set")
+ }
+
+ got2 := []row2{}
+ for rows.Next() {
+ var r row2
+ err = rows.Scan(&r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got2 = append(got2, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want2 := []row2{
+ {name: "Alice"},
+ {name: "Bob"},
+ {name: "Chris"},
+ }
+ if !reflect.DeepEqual(got2, want2) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
+ }
+ if rows.NextResultSet() {
+ t.Errorf("expected no more result sets")
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ waitForFree(t, db, 5*time.Second, 1)
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestQueryNamedArg(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query(
+ // Ensure the name and age parameters only match on placeholder name, not position.
+ "SELECT|people|age,name|name=?name,age=?age",
+ Named("age", 2),
+ Named("name", "Bob"),
+ )
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ for rows.Next() {
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got = append(got, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want := []row{
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestPoolExhaustOnCancel(t *testing.T) {
+ if testing.Short() {
+ t.Skip("long test")
+ }
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ max := 3
+
+ db.SetMaxOpenConns(max)
+
+ // First saturate the connection pool.
+ // Then start new requests for a connection that is cancelled after it is requested.
+
+ var saturate, saturateDone sync.WaitGroup
+ saturate.Add(max)
+ saturateDone.Add(max)
+
+ for i := 0; i < max; i++ {
+ go func() {
+ saturate.Done()
+ rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ rows.Close()
+ saturateDone.Done()
+ }()
+ }
+
+ saturate.Wait()
+
+ // Now cancel the request while it is waiting.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
+ defer cancel()
+
+ for i := 0; i < max; i++ {
+ ctxReq, cancelReq := context.WithCancel(ctx)
+ go func() {
+ time.Sleep(time.Millisecond * 100)
+ cancelReq()
+ }()
+ err := db.PingContext(ctxReq)
+ if err != context.Canceled {
+ t.Fatalf("PingContext (Exhaust): %v", err)
+ }
+ }
+
+ saturateDone.Wait()
+
+ // Now try to open a normal connection.
+ err := db.PingContext(ctx)
+ if err != nil {
+ t.Fatalf("PingContext (Normal): %v", err)
+ }
+}
+
func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -317,6 +670,56 @@ func TestRowsColumns(t *testing.T) {
}
}
+func TestRowsColumnTypes(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ tt, err := rows.ColumnTypes()
+ if err != nil {
+ t.Fatalf("ColumnTypes: %v", err)
+ }
+
+ types := make([]reflect.Type, len(tt))
+ for i, tp := range tt {
+ st := tp.ScanType()
+ if st == nil {
+ t.Errorf("scantype is null for column %q", tp.Name())
+ continue
+ }
+ types[i] = st
+ }
+ values := make([]interface{}, len(tt))
+ for i := range values {
+ values[i] = reflect.New(types[i]).Interface()
+ }
+ ct := 0
+ for rows.Next() {
+ err = rows.Scan(values...)
+ if err != nil {
+ t.Fatalf("failed to scan values in %v", err)
+ }
+ ct++
+ if ct == 0 {
+ if values[0].(string) != "Bob" {
+ t.Errorf("Expected Bob, got %v", values[0])
+ }
+ if values[1].(int) != 2 {
+ t.Errorf("Expected 2, got %v", values[1])
+ }
+ }
+ }
+ if ct != 3 {
+ t.Errorf("expected 3 rows, got %d", ct)
+ }
+
+ if err := rows.Close(); err != nil {
+ t.Errorf("error closing rows: %s", err)
+ }
+}
+
func TestQueryRow(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -367,6 +770,37 @@ func TestQueryRow(t *testing.T) {
}
}
+func TestTxRollbackCommitErr(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Errorf("expected nil error from Rollback; got %v", err)
+ }
+ err = tx.Commit()
+ if err != ErrTxDone {
+ t.Errorf("expected %q from Commit; got %q", ErrTxDone, err)
+ }
+
+ tx, err = db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Errorf("expected nil error from Commit; got %v", err)
+ }
+ err = tx.Rollback()
+ if err != ErrTxDone {
+ t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err)
+ }
+}
+
func TestStatementErrorAfterClose(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -439,7 +873,7 @@ func TestStatementClose(t *testing.T) {
msg string
}{
{&Stmt{stickyErr: want}, "stickyErr not propagated"},
- {&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
+ {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
}
for _, test := range tests {
if err := test.stmt.Close(); err != want {
@@ -513,8 +947,8 @@ func TestExec(t *testing.T) {
{[]interface{}{7, 9}, ""},
// Invalid conversions:
- {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument #1's type: sql/driver: value 4294967295 overflows int32"},
- {[]interface{}{"Brad", "strconv fail"}, "sql: converting argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"},
+ {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"},
+ {[]interface{}{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`},
// Wrong number of args:
{[]interface{}{}, "sql: expected 2 arguments, got 0"},
@@ -788,6 +1222,24 @@ func TestQueryRowClosingStmt(t *testing.T) {
}
}
+var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)
+
+func init() {
+ rowsCloseHook = func() func(*Rows, *error) {
+ fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
+ return fn
+ }
+}
+
+func setRowsCloseHook(fn func(*Rows, *error)) {
+ if fn == nil {
+ // Can't change an atomic.Value back to nil, so set it to this
+ // no-op func instead.
+ fn = func(*Rows, *error) {}
+ }
+ atomicRowsCloseHook.Store(fn)
+}
+
// Test issue 6651
func TestIssue6651(t *testing.T) {
db := newTestDB(t, "people")
@@ -800,6 +1252,7 @@ func TestIssue6651(t *testing.T) {
return fmt.Errorf(want)
}
defer func() { rowsCursorNextHook = nil }()
+
err := db.QueryRow("SELECT|people|name|").Scan(&v)
if err == nil || err.Error() != want {
t.Errorf("error = %q; want %q", err, want)
@@ -807,10 +1260,10 @@ func TestIssue6651(t *testing.T) {
rowsCursorNextHook = nil
want = "error in rows.Close"
- rowsCloseHook = func(rows *Rows, err *error) {
+ setRowsCloseHook(func(rows *Rows, err *error) {
*err = fmt.Errorf(want)
- }
- defer func() { rowsCloseHook = nil }()
+ })
+ defer setRowsCloseHook(nil)
err = db.QueryRow("SELECT|people|name|").Scan(&v)
if err == nil || err.Error() != want {
t.Errorf("error = %q; want %q", err, want)
@@ -911,7 +1364,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
if err == nil {
// TODO: this test fails, but it's just because
// fakeConn implements the optional Execer interface,
- // so arguably this is the correct behavior. But
+ // so arguably this is the correct behavior. But
// maybe I should flesh out the fakeConn.Exec
// implementation so this properly fails.
// t.Errorf("expected error inserting nil name with Exec")
@@ -1159,17 +1612,19 @@ func TestMaxOpenConnsOnBusy(t *testing.T) {
db.SetMaxOpenConns(3)
- conn0, err := db.conn(cachedOrNewConn)
+ ctx := context.Background()
+
+ conn0, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
- conn1, err := db.conn(cachedOrNewConn)
+ conn1, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
- conn2, err := db.conn(cachedOrNewConn)
+ conn2, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
t.Fatalf("db open conn fail: %v", err)
}
@@ -1203,7 +1658,11 @@ func TestPendingConnsAfterErr(t *testing.T) {
tryOpen = maxOpen*2 + 2
)
- db := newTestDB(t, "people")
+ // No queries will be run.
+ db, err := Open("test", fakeDBName)
+ if err != nil {
+ t.Fatalf("Open: %v", err)
+ }
defer closeDB(t, db)
defer func() {
for k, v := range db.lastPut {
@@ -1215,31 +1674,31 @@ func TestPendingConnsAfterErr(t *testing.T) {
db.SetMaxIdleConns(0)
errOffline := errors.New("db offline")
+
defer func() { setHookOpenErr(nil) }()
errs := make(chan error, tryOpen)
- unblock := make(chan struct{})
+ var opening sync.WaitGroup
+ opening.Add(tryOpen)
+
setHookOpenErr(func() error {
- <-unblock // block until all connections are in flight
+ // Wait for all connections to enqueue.
+ opening.Wait()
return errOffline
})
- var opening sync.WaitGroup
- opening.Add(tryOpen)
for i := 0; i < tryOpen; i++ {
go func() {
opening.Done() // signal one connection is in flight
- _, err := db.Exec("INSERT|people|name=Julia,age=19")
+ _, err := db.Exec("will never run")
errs <- err
}()
}
- opening.Wait() // wait for all workers to begin running
- time.Sleep(10 * time.Millisecond) // make extra sure all workers are blocked
- close(unblock) // let all workers proceed
+ opening.Wait() // wait for all workers to begin running
- const timeout = 100 * time.Millisecond
+ const timeout = 5 * time.Second
to := time.NewTimer(timeout)
defer to.Stop()
@@ -1254,6 +1713,24 @@ func TestPendingConnsAfterErr(t *testing.T) {
t.Fatalf("orphaned connection request(s), still waiting after %v", timeout)
}
}
+
+ // Wait a reasonable time for the database to close all connections.
+ tick := time.NewTicker(3 * time.Millisecond)
+ defer tick.Stop()
+ for {
+ select {
+ case <-tick.C:
+ db.mu.Lock()
+ if db.numOpen == 0 {
+ db.mu.Unlock()
+ return
+ }
+ db.mu.Unlock()
+ case <-to.C:
+ // Closing the database will check for numOpen and fail the test.
+ return
+ }
+ }
}
func TestSingleOpenConn(t *testing.T) {
@@ -1459,7 +1936,9 @@ func TestStmtCloseDeps(t *testing.T) {
db.dumpDeps(t)
}
- if len(stmt.css) > nquery {
+ if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+ return len(stmt.css) <= nquery
+ }) {
t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
}
@@ -1591,7 +2070,7 @@ func TestStmtCloseOrder(t *testing.T) {
_, err := db.Query("SELECT|non_existent|name|")
if err == nil {
- t.Fatal("Quering non-existent table should fail")
+ t.Fatal("Querying non-existent table should fail")
}
}
@@ -1615,6 +2094,8 @@ func TestManyErrBadConn(t *testing.T) {
}
}()
+ db.mu.Lock()
+ defer db.mu.Unlock()
if db.numOpen != nconn {
t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn)
} else if len(db.freeConn) != nconn {
@@ -2203,10 +2684,10 @@ func TestIssue6081(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- rowsCloseHook = func(rows *Rows, err *error) {
+ setRowsCloseHook(func(rows *Rows, err *error) {
*err = driver.ErrBadConn
- }
- defer func() { rowsCloseHook = nil }()
+ })
+ defer setRowsCloseHook(nil)
for i := 0; i < 10; i++ {
rows, err := stmt.Query()
if err != nil {
@@ -2234,6 +2715,100 @@ func TestIssue6081(t *testing.T) {
}
}
+// TestIssue18429 attempts to stress rolling back the transaction from a
+// context cancel while simultaneously calling Tx.Rollback. Rolling back from a
+// context happens concurrently so tx.rollback and tx.Commit must guard against
+// double entry.
+//
+// In the test, a context is canceled while the query is in process so
+// the internal rollback will run concurrently with the explicitly called
+// Tx.Rollback.
+func TestIssue18429(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx := context.Background()
+ sem := make(chan bool, 20)
+ var wg sync.WaitGroup
+
+ const milliWait = 30
+
+ for i := 0; i < 100; i++ {
+ sem <- true
+ wg.Add(1)
+ go func() {
+ defer func() {
+ <-sem
+ wg.Done()
+ }()
+ qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()
+
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ return
+ }
+ // This is expected to give a cancel error many, but not all the time.
+ // Test failure will happen with a panic or other race condition being
+ // reported.
+ rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
+ if rows != nil {
+ rows.Close()
+ }
+ // This call will race with the context cancel rollback to complete
+ // if the rollback itself isn't guarded.
+ tx.Rollback()
+ }()
+ }
+ wg.Wait()
+}
+
+// TestIssue18719 closes the context right before use. The sql.driverConn
+// will nil out the ci on close in a lock, but if another process uses it right after
+// it will panic with on the nil ref.
+//
+// See https://golang.org/cl/35550 .
+func TestIssue18719(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ hookTxGrabConn = func() {
+ cancel()
+
+ // Wait for the context to cancel and tx to rollback.
+ for tx.isDone() == false {
+ time.Sleep(time.Millisecond * 3)
+ }
+ }
+ defer func() { hookTxGrabConn = nil }()
+
+ // This call will grab the connection and cancel the context
+ // after it has done so. Code after must deal with the canceled state.
+ rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ rows.Close()
+ t.Fatalf("expected error %v but got %v", nil, err)
+ }
+
+ // Rows may be ignored because it will be closed when the context is canceled.
+
+ // Do not explicitly rollback. The rollback will happen from the
+ // canceled context.
+
+ cancel()
+ waitForRowsClose(t, rows, 5*time.Second)
+}
+
func TestConcurrency(t *testing.T) {
doConcurrentTest(t, new(concurrentDBQueryTest))
doConcurrentTest(t, new(concurrentDBExecTest))
@@ -2277,7 +2852,8 @@ func TestConnectionLeak(t *testing.T) {
go func() {
r, err := db.Query("SELECT|people|name|")
if err != nil {
- t.Fatal(err)
+ t.Error(err)
+ return
}
r.Close()
wg.Done()
@@ -2297,6 +2873,97 @@ func TestConnectionLeak(t *testing.T) {
wg.Wait()
}
+// badConn implements a bad driver.Conn, for TestBadDriver.
+// The Exec method panics.
+type badConn struct{}
+
+func (bc badConn) Prepare(query string) (driver.Stmt, error) {
+ return nil, errors.New("badConn Prepare")
+}
+
+func (bc badConn) Close() error {
+ return nil
+}
+
+func (bc badConn) Begin() (driver.Tx, error) {
+ return nil, errors.New("badConn Begin")
+}
+
+func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ panic("badConn.Exec")
+}
+
+// badDriver is a driver.Driver that uses badConn.
+type badDriver struct{}
+
+func (bd badDriver) Open(name string) (driver.Conn, error) {
+ return badConn{}, nil
+}
+
+// Issue 15901.
+func TestBadDriver(t *testing.T) {
+ Register("bad", badDriver{})
+ db, err := Open("bad", "ignored")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("expected panic")
+ } else {
+ if want := "badConn.Exec"; r.(string) != want {
+ t.Errorf("panic was %v, expected %v", r, want)
+ }
+ }
+ }()
+ defer db.Close()
+ db.Exec("ignored")
+}
+
+type pingDriver struct {
+ fails bool
+}
+
+type pingConn struct {
+ badConn
+ driver *pingDriver
+}
+
+var pingError = errors.New("Ping failed")
+
+func (pc pingConn) Ping(ctx context.Context) error {
+ if pc.driver.fails {
+ return pingError
+ }
+ return nil
+}
+
+var _ driver.Pinger = pingConn{}
+
+func (pd *pingDriver) Open(name string) (driver.Conn, error) {
+ return pingConn{driver: pd}, nil
+}
+
+func TestPing(t *testing.T) {
+ driver := &pingDriver{}
+ Register("ping", driver)
+
+ db, err := Open("ping", "ignored")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := db.Ping(); err != nil {
+ t.Errorf("err was %#v, expected nil", err)
+ return
+ }
+
+ driver.fails = true
+ if err := db.Ping(); err != pingError {
+ t.Errorf("err was %#v, expected pingError", err)
+ }
+}
+
func BenchmarkConcurrentDBExec(b *testing.B) {
b.ReportAllocs()
ct := new(concurrentDBExecTest)