diff options
Diffstat (limited to 'libgo/go/database/sql/sql.go')
-rw-r--r-- | libgo/go/database/sql/sql.go | 1083 |
1 files changed, 827 insertions, 256 deletions
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go index d8e7cb77af..c016681fca 100644 --- a/libgo/go/database/sql/sql.go +++ b/libgo/go/database/sql/sql.go @@ -8,15 +8,20 @@ // The sql package must be used in conjunction with a database driver. // See https://golang.org/s/sqldrivers for a list of drivers. // -// For more usage examples, see the wiki page at +// Drivers that do not support context cancelation will not return until +// after the query is completed. +// +// For usage examples, see the wiki page at // https://golang.org/s/sqlwiki. package sql import ( + "context" "database/sql/driver" "errors" "fmt" "io" + "reflect" "runtime" "sort" "sync" @@ -66,6 +71,75 @@ func Drivers() []string { return list } +// A NamedArg is a named argument. NamedArg values may be used as +// arguments to Query or Exec and bind to the corresponding named +// parameter in the SQL statement. +// +// For a more concise way to create NamedArg values, see +// the Named function. +type NamedArg struct { + _Named_Fields_Required struct{} + + // Name is the name of the parameter placeholder. + // + // If empty, the ordinal position in the argument list will be + // used. + // + // Name must omit any symbol prefix. + Name string + + // Value is the value of the parameter. + // It may be assigned the same value types as the query + // arguments. + Value interface{} +} + +// Named provides a more concise way to create NamedArg values. +// +// Example usage: +// +// db.ExecContext(ctx, ` +// delete from Invoice +// where +// TimeCreated < @end +// and TimeCreated >= @start;`, +// sql.Named("start", startTime), +// sql.Named("end", endTime), +// ) +func Named(name string, value interface{}) NamedArg { + // This method exists because the go1compat promise + // doesn't guarantee that structs don't grow more fields, + // so unkeyed struct literals are a vet error. Thus, we don't + // want to allow sql.NamedArg{name, value}. + return NamedArg{Name: name, Value: value} +} + +// IsolationLevel is the transaction isolation level used in TxOptions. +type IsolationLevel int + +// Various isolation levels that drivers may support in BeginTx. +// If a driver does not support a given isolation level an error may be returned. +// +// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels. +const ( + LevelDefault IsolationLevel = iota + LevelReadUncommitted + LevelReadCommitted + LevelWriteCommitted + LevelRepeatableRead + LevelSnapshot + LevelSerializable + LevelLinearizable +) + +// TxOptions holds the transaction options to be used in DB.BeginTx. +type TxOptions struct { + // Isolation is the transaction isolation level. + // If zero, the driver or database's default level is used. + Isolation IsolationLevel + ReadOnly bool +} + // RawBytes is a byte slice that holds a reference to memory owned by // the database itself. After a Scan into a RawBytes, the slice is only // valid until the next call to Next, Scan, or Close. @@ -199,7 +273,7 @@ type Scanner interface { // time.Time // nil - for NULL values // - // An error should be returned if the value can not be stored + // An error should be returned if the value cannot be stored // without loss of information. Scan(src interface{}) error } @@ -231,8 +305,9 @@ type DB struct { mu sync.Mutex // protects following fields freeConn []*driverConn - connRequests []chan connRequest - numOpen int // number of opened and pending open connections + connRequests map[uint64]chan connRequest + nextRequest uint64 // Next key to use in connRequests. + numOpen int // number of opened and pending open connections // Used to signal the need for new connections // a goroutine running connectionOpener() reads on this chan and // maybeOpenNewConnections sends on the chan (one send per needed connection) @@ -272,7 +347,7 @@ type driverConn struct { ci driver.Conn closed bool finalClosed bool // ci.Close has been called - openStmt map[driver.Stmt]bool + openStmt map[*driverStmt]bool // guarded by db.mu inUse bool @@ -284,10 +359,10 @@ func (dc *driverConn) releaseConn(err error) { dc.db.putConn(dc, err) } -func (dc *driverConn) removeOpenStmt(si driver.Stmt) { +func (dc *driverConn) removeOpenStmt(ds *driverStmt) { dc.Lock() defer dc.Unlock() - delete(dc.openStmt, si) + delete(dc.openStmt, ds) } func (dc *driverConn) expired(timeout time.Duration) bool { @@ -297,28 +372,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { - si, err := dc.ci.Prepare(query) - if err == nil { - // Track each driverConn's open statements, so we can close them - // before closing the conn. - // - // TODO(bradfitz): let drivers opt out of caring about - // stmt closes if the conn is about to close anyway? For now - // do the safe thing, in case stmts need to be closed. - // - // TODO(bradfitz): after Go 1.2, closing driver.Stmts - // should be moved to driverStmt, using unique - // *driverStmts everywhere (including from - // *Stmt.connStmt, instead of returning a - // driver.Stmt), using driverStmt as a pointer - // everywhere, and making it a finalCloser. - if dc.openStmt == nil { - dc.openStmt = make(map[driver.Stmt]bool) - } - dc.openStmt[si] = true +func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) { + si, err := ctxDriverPrepare(ctx, dc.ci, query) + if err != nil { + return nil, err } - return si, err + + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once. + if dc.openStmt == nil { + dc.openStmt = make(map[*driverStmt]bool) + } + ds := &driverStmt{Locker: dc, si: si} + dc.openStmt[ds] = true + + return ds, nil } // the dc.db's Mutex is held. @@ -350,17 +420,26 @@ func (dc *driverConn) Close() error { } func (dc *driverConn) finalClose() error { - dc.Lock() + var err error - for si := range dc.openStmt { - si.Close() + // Each *driverStmt has a lock to the dc. Copy the list out of the dc + // before calling close on each stmt. + var openStmt []*driverStmt + withLock(dc, func() { + openStmt = make([]*driverStmt, 0, len(dc.openStmt)) + for ds := range dc.openStmt { + openStmt = append(openStmt, ds) + } + dc.openStmt = nil + }) + for _, ds := range openStmt { + ds.Close() } - dc.openStmt = nil - - err := dc.ci.Close() - dc.ci = nil - dc.finalClosed = true - dc.Unlock() + withLock(dc, func() { + dc.finalClosed = true + err = dc.ci.Close() + dc.ci = nil + }) dc.db.mu.Lock() dc.db.numOpen-- @@ -377,12 +456,21 @@ func (dc *driverConn) finalClose() error { type driverStmt struct { sync.Locker // the *driverConn si driver.Stmt + closed bool + closeErr error // return value of previous Close call } +// Close ensures dirver.Stmt is only closed once any always returns the same +// result. func (ds *driverStmt) Close() error { ds.Lock() defer ds.Unlock() - return ds.si.Close() + if ds.closed { + return ds.closeErr + } + ds.closed = true + ds.closeErr = ds.si.Close() + return ds.closeErr } // depSet is a finalCloser's outstanding dependencies @@ -485,27 +573,46 @@ func Open(driverName, dataSourceName string) (*DB, error) { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } db := &DB{ - driver: driveri, - dsn: dataSourceName, - openerCh: make(chan struct{}, connectionRequestQueueSize), - lastPut: make(map[*driverConn]string), + driver: driveri, + dsn: dataSourceName, + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*driverConn]string), + connRequests: make(map[uint64]chan connRequest), } go db.connectionOpener() return db, nil } -// Ping verifies a connection to the database is still alive, +// PingContext verifies a connection to the database is still alive, // establishing a connection if necessary. -func (db *DB) Ping() error { - // TODO(bradfitz): give drivers an optional hook to implement - // this in a more efficient or more reliable way, if they - // have one. - dc, err := db.conn(cachedOrNewConn) +func (db *DB) PingContext(ctx context.Context) error { + var dc *driverConn + var err error + + for i := 0; i < maxBadConnRetries; i++ { + dc, err = db.conn(ctx, cachedOrNewConn) + if err != driver.ErrBadConn { + break + } + } + if err == driver.ErrBadConn { + dc, err = db.conn(ctx, alwaysNewConn) + } if err != nil { return err } - db.putConn(dc, nil) - return nil + + if pinger, ok := dc.ci.(driver.Pinger); ok { + err = pinger.Ping(ctx) + } + db.putConn(dc, err) + return err +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) Ping() error { + return db.PingContext(context.Background()) } // Close closes the database, releasing any open resources. @@ -718,6 +825,9 @@ func (db *DB) maybeOpenNewConnections() { for numRequests > 0 { db.numOpen++ // optimistically numRequests-- + if db.closed { + return + } db.openerCh <- struct{}{} } } @@ -773,13 +883,28 @@ type connRequest struct { var errDBClosed = errors.New("sql: database is closed") +// nextRequestKeyLocked returns the next connection request key. +// It is assumed that nextRequest will not overflow. +func (db *DB) nextRequestKeyLocked() uint64 { + next := db.nextRequest + db.nextRequest++ + return next +} + // conn returns a newly-opened or cached *driverConn. -func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { +func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) { db.mu.Lock() if db.closed { db.mu.Unlock() return nil, errDBClosed } + // Check if the context is expired. + select { + default: + case <-ctx.Done(): + db.mu.Unlock() + return nil, ctx.Err() + } lifetime := db.maxLifetime // Prefer a free connection, if possible. @@ -797,23 +922,42 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { return conn, nil } - // Out of free connections or we were asked not to use one. If we're not + // Out of free connections or we were asked not to use one. If we're not // allowed to open any more connections, make a request and wait. if db.maxOpen > 0 && db.numOpen >= db.maxOpen { // Make the connRequest channel. It's buffered so that the // connectionOpener doesn't block while waiting for the req to be read. req := make(chan connRequest, 1) - db.connRequests = append(db.connRequests, req) + reqKey := db.nextRequestKeyLocked() + db.connRequests[reqKey] = req db.mu.Unlock() - ret, ok := <-req - if !ok { - return nil, errDBClosed - } - if ret.err == nil && ret.conn.expired(lifetime) { - ret.conn.Close() - return nil, driver.ErrBadConn + + // Timeout the connection request with the context. + select { + case <-ctx.Done(): + // Remove the connection request and ensure no value has been sent + // on it after removing. + db.mu.Lock() + delete(db.connRequests, reqKey) + db.mu.Unlock() + select { + default: + case ret, ok := <-req: + if ok { + db.putConn(ret.conn, ret.err) + } + } + return nil, ctx.Err() + case ret, ok := <-req: + if !ok { + return nil, errDBClosed + } + if ret.err == nil && ret.conn.expired(lifetime) { + ret.conn.Close() + return nil, driver.ErrBadConn + } + return ret.conn, ret.err } - return ret.conn, ret.err } db.numOpen++ // optimistically @@ -838,29 +982,25 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { return dc, nil } -var ( - errConnClosed = errors.New("database/sql: internal sentinel error: conn is closed") - errConnBusy = errors.New("database/sql: internal sentinel error: conn is busy") -) - // putConnHook is a hook for testing. var putConnHook func(*DB, *driverConn) -// noteUnusedDriverStatement notes that si is no longer used and should +// noteUnusedDriverStatement notes that ds is no longer used and should // be closed whenever possible (when c is next not in use), unless c is // already closed. -func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { +func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) { db.mu.Lock() defer db.mu.Unlock() if c.inUse { c.onPut = append(c.onPut, func() { - si.Close() + ds.Close() }) } else { c.Lock() - defer c.Unlock() - if !c.finalClosed { - si.Close() + fc := c.finalClosed + c.Unlock() + if !fc { + ds.Close() } } } @@ -920,16 +1060,19 @@ func (db *DB) putConn(dc *driverConn, err error) { // If a connRequest was fulfilled or the *driverConn was placed in the // freeConn list, then true is returned, otherwise false is returned. func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { + if db.closed { + return false + } if db.maxOpen > 0 && db.numOpen > db.maxOpen { return false } if c := len(db.connRequests); c > 0 { - req := db.connRequests[0] - // This copy is O(n) but in practice faster than a linked list. - // TODO: consider compacting it down less often and - // moving the base instead? - copy(db.connRequests, db.connRequests[1:]) - db.connRequests = db.connRequests[:c-1] + var req chan connRequest + var reqKey uint64 + for reqKey, req = range db.connRequests { + break + } + delete(db.connRequests, reqKey) // Remove from pending requests. if err == nil { dc.inUse = true } @@ -951,40 +1094,53 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { // connection to be opened. const maxBadConnRetries = 2 -// Prepare creates a prepared statement for later queries or executions. +// PrepareContext creates a prepared statement for later queries or executions. // Multiple queries or executions may be run concurrently from the // returned statement. // The caller must call the statement's Close method // when the statement is no longer needed. -func (db *DB) Prepare(query string) (*Stmt, error) { +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var stmt *Stmt var err error for i := 0; i < maxBadConnRetries; i++ { - stmt, err = db.prepare(query, cachedOrNewConn) + stmt, err = db.prepare(ctx, query, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.prepare(query, alwaysNewConn) + return db.prepare(ctx, query, alwaysNewConn) } return stmt, err } -func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) { // TODO: check if db.driver supports an optional // driver.Preparer interface and call that instead, if so, // otherwise we make a prepared statement that's bound // to a connection, and to execute this prepared statement // we either need to use this connection (if it's free), else // get a new connection + re-prepare + execute on that one. - dc, err := db.conn(strategy) + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } - dc.Lock() - si, err := dc.prepareLocked(query) - dc.Unlock() + var ds *driverStmt + withLock(dc, func() { + ds, err = dc.prepareLocked(ctx, query) + }) if err != nil { db.putConn(dc, err) return nil, err @@ -992,7 +1148,7 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { stmt := &Stmt{ db: db, query: query, - css: []connStmt{{dc, si}}, + css: []connStmt{{dc, ds}}, lastNumClosed: atomic.LoadUint64(&db.numClosed), } db.addDep(stmt, stmt) @@ -1000,25 +1156,31 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { return stmt, nil } -// Exec executes a query without returning any rows. +// ExecContext executes a query without returning any rows. // The args are for any placeholder parameters in the query. -func (db *DB) Exec(query string, args ...interface{}) (Result, error) { +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { var res Result var err error for i := 0; i < maxBadConnRetries; i++ { - res, err = db.exec(query, args, cachedOrNewConn) + res, err = db.exec(ctx, query, args, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.exec(query, args, alwaysNewConn) + return db.exec(ctx, query, args, alwaysNewConn) } return res, err } -func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { - dc, err := db.conn(strategy) +// Exec executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (db *DB) Exec(query string, args ...interface{}) (Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } @@ -1027,13 +1189,15 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) }() if execer, ok := dc.ci.(driver.Execer); ok { - dargs, err := driverArgs(nil, args) + var dargs []driver.NamedValue + dargs, err = driverArgs(nil, args) if err != nil { return nil, err } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() + var resi driver.Result + withLock(dc, func() { + resi, err = ctxDriverExec(ctx, execer, query, dargs) + }) if err != driver.ErrSkip { if err != nil { return nil, err @@ -1042,54 +1206,63 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) - return resultFromStatement(driverStmt{dc, si}, args...) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() + return resultFromStatement(ctx, ds, args...) } -// Query executes a query that returns rows, typically a SELECT. +// QueryContext executes a query that returns rows, typically a SELECT. // The args are for any placeholder parameters in the query. -func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { var rows *Rows var err error for i := 0; i < maxBadConnRetries; i++ { - rows, err = db.query(query, args, cachedOrNewConn) + rows, err = db.query(ctx, query, args, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.query(query, args, alwaysNewConn) + return db.query(ctx, query, args, alwaysNewConn) } return rows, err } -func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) { - ci, err := db.conn(strategy) +// Query executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) { + ci, err := db.conn(ctx, strategy) if err != nil { return nil, err } - return db.queryConn(ci, ci.releaseConn, query, args) + return db.queryConn(ctx, ci, ci.releaseConn, query, args) } // queryConn executes a query on the given connection. // The connection gets released by the releaseConn function. -func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { +func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { if queryer, ok := dc.ci.(driver.Queryer); ok { dargs, err := driverArgs(nil, args) if err != nil { releaseConn(err) return nil, err } - dc.Lock() - rowsi, err := queryer.Query(query, dargs) - dc.Unlock() + var rowsi driver.Rows + withLock(dc, func() { + rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs) + }) if err != driver.ErrSkip { if err != nil { releaseConn(err) @@ -1102,24 +1275,25 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a releaseConn: releaseConn, rowsi: rowsi, } + rows.initContextClose(ctx) return rows, nil } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + var err error + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { releaseConn(err) return nil, err } - ds := driverStmt{dc, si} - rowsi, err := rowsiFromStatement(ds, args...) + ds := &driverStmt{Locker: dc, si: si} + rowsi, err := rowsiFromStatement(ctx, ds, args...) if err != nil { - dc.Lock() - si.Close() - dc.Unlock() + ds.Close() releaseConn(err) return nil, err } @@ -1130,53 +1304,84 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a dc: dc, releaseConn: releaseConn, rowsi: rowsi, - closeStmt: si, + closeStmt: ds, } + rows.initContextClose(ctx) return rows, nil } +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. func (db *DB) QueryRow(query string, args ...interface{}) *Row { - rows, err := db.Query(query, args...) - return &Row{rows: rows, err: err} + return db.QueryRowContext(context.Background(), query, args...) } -// Begin starts a transaction. The isolation level is dependent on -// the driver. -func (db *DB) Begin() (*Tx, error) { +// BeginTx starts a transaction. +// +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginTx is canceled. +// +// The provided TxOptions is optional and may be nil if defaults should be used. +// If a non-default isolation level is used that the driver doesn't support, +// an error will be returned. +func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { var tx *Tx var err error for i := 0; i < maxBadConnRetries; i++ { - tx, err = db.begin(cachedOrNewConn) + tx, err = db.begin(ctx, opts, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.begin(alwaysNewConn) + return db.begin(ctx, opts, alwaysNewConn) } return tx, err } -func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) { - dc, err := db.conn(strategy) +// Begin starts a transaction. The default isolation level is dependent on +// the driver. +func (db *DB) Begin() (*Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) { + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } - dc.Lock() - txi, err := dc.ci.Begin() - dc.Unlock() + var txi driver.Tx + withLock(dc, func() { + txi, err = ctxDriverBegin(ctx, opts, dc.ci) + }) if err != nil { db.putConn(dc, err) return nil, err } - return &Tx{ - db: db, - dc: dc, - txi: txi, - }, nil + + // Schedule the transaction to rollback when the context is cancelled. + // The cancel function in Tx will be called after done is set to true. + ctx, cancel := context.WithCancel(ctx) + tx = &Tx{ + db: db, + dc: dc, + txi: txi, + cancel: cancel, + ctx: ctx, + } + go tx.awaitDone() + return tx, nil } // Driver returns the database's underlying driver. @@ -1197,60 +1402,112 @@ func (db *DB) Driver() driver.Driver { type Tx struct { db *DB + // closemu prevents the transaction from closing while there + // is an active query. It is held for read during queries + // and exclusively during close. + closemu sync.RWMutex + // dc is owned exclusively until Commit or Rollback, at which point // it's returned with putConn. dc *driverConn txi driver.Tx - // done transitions from false to true exactly once, on Commit + // done transitions from 0 to 1 exactly once, on Commit // or Rollback. once done, all operations fail with // ErrTxDone. - done bool + // Use atomic operations on value when checking value. + done int32 - // All Stmts prepared for this transaction. These will be closed after the + // All Stmts prepared for this transaction. These will be closed after the // transaction has been committed or rolled back. stmts struct { sync.Mutex v []*Stmt } + + // cancel is called after done transitions from false to true. + cancel func() + + // ctx lives for the life of the transaction. + ctx context.Context +} + +// awaitDone blocks until the context in Tx is canceled and rolls back +// the transaction if it's not already done. +func (tx *Tx) awaitDone() { + // Wait for either the transaction to be committed or rolled + // back, or for the associated context to be closed. + <-tx.ctx.Done() + + // Discard and close the connection used to ensure the + // transaction is closed and the resources are released. This + // rollback does nothing if the transaction has already been + // committed or rolled back. + tx.rollback(true) +} + +func (tx *Tx) isDone() bool { + return atomic.LoadInt32(&tx.done) != 0 } +// ErrTxDone is returned by any operation that is performed on a transaction +// that has already been committed or rolled back. var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") +// close returns the connection to the pool and +// must only be called by Tx.rollback or Tx.Commit. func (tx *Tx) close(err error) { - if tx.done { - panic("double close") // internal error - } - tx.done = true + tx.closemu.Lock() + defer tx.closemu.Unlock() + tx.db.putConn(tx.dc, err) + tx.cancel() tx.dc = nil tx.txi = nil } -func (tx *Tx) grabConn() (*driverConn, error) { - if tx.done { +// hookTxGrabConn specifies an optional hook to be called on +// a successful call to (*Tx).grabConn. For tests. +var hookTxGrabConn func() + +func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + if tx.isDone() { return nil, ErrTxDone } + if hookTxGrabConn != nil { // test hook + hookTxGrabConn() + } return tx.dc, nil } // Closes all Stmts prepared for this transaction. func (tx *Tx) closePrepared() { tx.stmts.Lock() + defer tx.stmts.Unlock() for _, stmt := range tx.stmts.v { stmt.Close() } - tx.stmts.Unlock() } // Commit commits the transaction. func (tx *Tx) Commit() error { - if tx.done { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } - tx.dc.Lock() - err := tx.txi.Commit() - tx.dc.Unlock() + select { + default: + case <-tx.ctx.Done(): + return tx.ctx.Err() + } + var err error + withLock(tx.dc, func() { + err = tx.txi.Commit() + }) if err != driver.ErrBadConn { tx.closePrepared() } @@ -1258,49 +1515,67 @@ func (tx *Tx) Commit() error { return err } -// Rollback aborts the transaction. -func (tx *Tx) Rollback() error { - if tx.done { +// rollback aborts the transaction and optionally forces the pool to discard +// the connection. +func (tx *Tx) rollback(discardConn bool) error { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } - tx.dc.Lock() - err := tx.txi.Rollback() - tx.dc.Unlock() + var err error + withLock(tx.dc, func() { + err = tx.txi.Rollback() + }) if err != driver.ErrBadConn { tx.closePrepared() } + if discardConn { + err = driver.ErrBadConn + } tx.close(err) return err } +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + return tx.rollback(false) +} + // Prepare creates a prepared statement for use within a transaction. // -// The returned statement operates within the transaction and can no longer -// be used once the transaction has been committed or rolled back. +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. // // To use an existing prepared statement on this transaction, see Tx.Stmt. -func (tx *Tx) Prepare(query string) (*Stmt, error) { +// +// The provided context will be used for the preparation of the context, not +// for the execution of the returned statement. The returned statement +// will run in the transaction context. +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + tx.closemu.RLock() + defer tx.closemu.RUnlock() + // TODO(bradfitz): We could be more efficient here and either // provide a method to take an existing Stmt (created on // perhaps a different Conn), and re-create it on this Conn if // necessary. Or, better: keep a map in DB of query string to // Stmts, and have Stmt.Execute do the right thing and // re-prepare if the Conn in use doesn't have that prepared - // statement. But we'll want to avoid caching the statement + // statement. But we'll want to avoid caching the statement // in the case where we only call conn.Prepare implicitly // (such as in db.Exec or tx.Exec), but the caller package // can't be holding a reference to the returned statement. // Perhaps just looking at the reference count (by noting // Stmt.Close) would be enough. We might also want a finalizer // on Stmt to drop the reference count. - dc, err := tx.grabConn() + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { return nil, err } @@ -1308,7 +1583,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { stmt := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1320,7 +1595,17 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { return stmt, nil } -// Stmt returns a transaction-specific prepared statement from +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +// StmtContext returns a transaction-specific prepared statement from // an existing statement. // // Example: @@ -1328,30 +1613,34 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // ... // tx, err := db.Begin() // ... -// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203) // -// The returned statement operates within the transaction and can no longer -// be used once the transaction has been committed or rolled back. -func (tx *Tx) Stmt(stmt *Stmt) *Stmt { +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { + tx.closemu.RLock() + defer tx.closemu.RUnlock() + // TODO(bradfitz): optimize this. Currently this re-prepares - // each time. This is fine for now to illustrate the API but + // each time. This is fine for now to illustrate the API but // we should really cache already-prepared statements // per-Conn. See also the big comment in Tx.Prepare. if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - dc, err := tx.grabConn() + dc, err := tx.grabConn(ctx) if err != nil { return &Stmt{stickyErr: err} } - dc.Lock() - si, err := dc.ci.Prepare(stmt.query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) + }) txs := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1364,10 +1653,29 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { return txs } -// Exec executes a query that doesn't return rows. +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +// ExecContext executes a query that doesn't return rows. // For example: an INSERT and UPDATE. -func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { - dc, err := tx.grabConn() +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + tx.closemu.RLock() + defer tx.closemu.RUnlock() + + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } @@ -1377,9 +1685,10 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { if err != nil { return nil, err } - dc.Lock() - resi, err := execer.Exec(query, dargs) - dc.Unlock() + var resi driver.Result + withLock(dc, func() { + resi, err = ctxDriverExec(ctx, execer, query, dargs) + }) if err == nil { return driverResult{dc, resi}, nil } @@ -1388,39 +1697,62 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { } } - dc.Lock() - si, err := dc.ci.Prepare(query) - dc.Unlock() + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() - return resultFromStatement(driverStmt{dc, si}, args...) + return resultFromStatement(ctx, ds, args...) } -// Query executes a query that returns rows, typically a SELECT. -func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { - dc, err := tx.grabConn() +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { + return tx.ExecContext(context.Background(), query, args...) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + tx.closemu.RLock() + defer tx.closemu.RUnlock() + + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } releaseConn := func(error) {} - return tx.db.queryConn(dc, releaseConn, query, args) + return tx.db.queryConn(ctx, dc, releaseConn, query, args) +} + +// Query executes a query that returns rows, typically a SELECT. +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} } // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { - rows, err := tx.Query(query, args...) - return &Row{rows: rows, err: err} + return tx.QueryRowContext(context.Background(), query, args...) } // connStmt is a prepared statement on a particular connection. type connStmt struct { dc *driverConn - si driver.Stmt + ds *driverStmt } // Stmt is a prepared statement. @@ -1435,15 +1767,15 @@ type Stmt struct { // If in a transaction, else both nil: tx *Tx - txsi *driverStmt + txds *driverStmt mu sync.Mutex // protects the rest of the fields closed bool // css is a list of underlying driver statement interfaces - // that are valid on particular connections. This is only + // that are valid on particular connections. This is only // used if tx == nil and one is found that has idle - // connections. If tx != nil, txsi is always used. + // connections. If tx != nil, txsi is always used. css []connStmt // lastNumClosed is copied from db.numClosed when Stmt is created @@ -1451,15 +1783,15 @@ type Stmt struct { lastNumClosed uint64 } -// Exec executes a prepared statement with the given arguments and +// ExecContext executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. -func (s *Stmt) Exec(args ...interface{}) (Result, error) { +func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) { s.closemu.RLock() defer s.closemu.RUnlock() var res Result for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt() + _, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1467,7 +1799,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return nil, err } - res, err = resultFromStatement(driverStmt{dc, si}, args...) + res, err = resultFromStatement(ctx, ds, args...) releaseConn(err) if err != driver.ErrBadConn { return res, err @@ -1476,13 +1808,19 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return nil, driver.ErrBadConn } -func driverNumInput(ds driverStmt) int { +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) Exec(args ...interface{}) (Result, error) { + return s.ExecContext(context.Background(), args...) +} + +func driverNumInput(ds *driverStmt) int { ds.Lock() defer ds.Unlock() // in case NumInput panics return ds.si.NumInput() } -func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { +func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) { want := driverNumInput(ds) // -1 means the driver doesn't know how to count the number of @@ -1492,14 +1830,15 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } ds.Lock() defer ds.Unlock() - resi, err := ds.si.Exec(dargs) + + resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) if err != nil { return nil, err } @@ -1535,7 +1874,7 @@ func (s *Stmt) removeClosedStmtLocked() { // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) { if err = s.stickyErr; err != nil { return } @@ -1550,19 +1889,18 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St // transaction was created on. if s.tx != nil { s.mu.Unlock() - ci, err = s.tx.grabConn() // blocks, waiting for the connection. + ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection. if err != nil { return } releaseConn = func(error) {} - return ci, releaseConn, s.txsi.si, nil + return ci, releaseConn, s.txds, nil } s.removeClosedStmtLocked() s.mu.Unlock() - // TODO(bradfitz): or always wait for one? make configurable later? - dc, err := s.db.conn(cachedOrNewConn) + dc, err := s.db.conn(ctx, cachedOrNewConn) if err != nil { return nil, nil, nil, err } @@ -1571,36 +1909,36 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St for _, v := range s.css { if v.dc == dc { s.mu.Unlock() - return dc, dc.releaseConn, v.si, nil + return dc, dc.releaseConn, v.ds, nil } } s.mu.Unlock() // No luck; we need to prepare the statement on this connection - dc.Lock() - si, err = dc.prepareLocked(s.query) - dc.Unlock() + withLock(dc, func() { + ds, err = dc.prepareLocked(ctx, s.query) + }) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err } s.mu.Lock() - cs := connStmt{dc, si} + cs := connStmt{dc, ds} s.css = append(s.css, cs) s.mu.Unlock() - return dc, dc.releaseConn, si, nil + return dc, dc.releaseConn, ds, nil } -// Query executes a prepared query statement with the given arguments +// QueryContext executes a prepared query statement with the given arguments // and returns the query results as a *Rows. -func (s *Stmt) Query(args ...interface{}) (*Rows, error) { +func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { s.closemu.RLock() defer s.closemu.RUnlock() var rowsi driver.Rows for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt() + dc, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1608,7 +1946,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return nil, err } - rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...) + rowsi, err = rowsiFromStatement(ctx, ds, args...) if err == nil { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. @@ -1617,6 +1955,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { rowsi: rowsi, // releaseConn set below } + rows.initContextClose(ctx) s.db.addDep(s, rows) rows.releaseConn = func(err error) { releaseConn(err) @@ -1633,10 +1972,17 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return nil, driver.ErrBadConn } -func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { - ds.Lock() - want := ds.si.NumInput() - ds.Unlock() +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) { + var want int + withLock(ds, func() { + want = ds.si.NumInput() + }) // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the @@ -1645,21 +1991,22 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } ds.Lock() - rowsi, err := ds.si.Query(dargs) - ds.Unlock() + defer ds.Unlock() + + rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs) if err != nil { return nil, err } return rowsi, nil } -// QueryRow executes a prepared query statement with the given arguments. +// QueryRowContext executes a prepared query statement with the given arguments. // If an error occurs during the execution of the statement, that error will // be returned by a call to Scan on the returned *Row, which is always non-nil. // If the query selects no rows, the *Row's Scan will return ErrNoRows. @@ -1669,15 +2016,30 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) // Example usage: // // var name string -// err := nameByUseridStmt.QueryRow(id).Scan(&name) -func (s *Stmt) QueryRow(args ...interface{}) *Row { - rows, err := s.Query(args...) +// err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name) +func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { + rows, err := s.QueryContext(ctx, args...) if err != nil { return &Row{err: err} } return &Row{rows: rows} } +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&name) +func (s *Stmt) QueryRow(args ...interface{}) *Row { + return s.QueryRowContext(context.Background(), args...) +} + // Close closes the statement. func (s *Stmt) Close() error { s.closemu.Lock() @@ -1692,13 +2054,11 @@ func (s *Stmt) Close() error { return nil } s.closed = true + s.mu.Unlock() if s.tx != nil { - err := s.txsi.Close() - s.mu.Unlock() - return err + return s.txds.Close() } - s.mu.Unlock() return s.db.removeDep(s, s) } @@ -1708,8 +2068,8 @@ func (s *Stmt) finalClose() error { defer s.mu.Unlock() if s.css != nil { for _, v := range s.css { - s.db.noteUnusedDriverStatement(v.dc, v.si) - v.dc.removeOpenStmt(v.si) + s.db.noteUnusedDriverStatement(v.dc, v.ds) + v.dc.removeOpenStmt(v.ds) } s.css = nil } @@ -1734,29 +2094,110 @@ type Rows struct { dc *driverConn // owned; must call releaseConn when closed to release releaseConn func(error) rowsi driver.Rows + cancel func() // called when Rows is closed, may be nil. + closeStmt *driverStmt // if non-nil, statement to Close on close + + // closemu prevents Rows from closing while there + // is an active streaming result. It is held for read during non-close operations + // and exclusively during close. + // + // closemu guards lasterr and closed. + closemu sync.RWMutex + closed bool + lasterr error // non-nil only if closed is true - closed bool - lastcols []driver.Value - lasterr error // non-nil only if closed is true - closeStmt driver.Stmt // if non-nil, statement to Close on close + // lastcols is only used in Scan, Next, and NextResultSet which are expected + // not not be called concurrently. + lastcols []driver.Value } -// Next prepares the next result row for reading with the Scan method. It +func (rs *Rows) initContextClose(ctx context.Context) { + ctx, rs.cancel = context.WithCancel(ctx) + go rs.awaitDone(ctx) +} + +// awaitDone blocks until the rows are closed or the context canceled. +func (rs *Rows) awaitDone(ctx context.Context) { + <-ctx.Done() + rs.close(ctx.Err()) +} + +// Next prepares the next result row for reading with the Scan method. It // returns true on success, or false if there is no next result row or an error -// happened while preparing it. Err should be consulted to distinguish between +// happened while preparing it. Err should be consulted to distinguish between // the two cases. // // Every call to Scan, even the first one, must be preceded by a call to Next. func (rs *Rows) Next() bool { + var doClose, ok bool + withLock(rs.closemu.RLocker(), func() { + doClose, ok = rs.nextLocked() + }) + if doClose { + rs.Close() + } + return ok +} + +func (rs *Rows) nextLocked() (doClose, ok bool) { if rs.closed { - return false + return false, false } if rs.lastcols == nil { rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) } rs.lasterr = rs.rowsi.Next(rs.lastcols) if rs.lasterr != nil { - rs.Close() + // Close the connection if there is a driver error. + if rs.lasterr != io.EOF { + return true, false + } + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + return true, false + } + // The driver is at the end of the current result set. + // Test to see if there is another result set after the current one. + // Only close Rows if there is no further result sets to read. + if !nextResultSet.HasNextResultSet() { + doClose = true + } + return doClose, false + } + return false, true +} + +// NextResultSet prepares the next result set for reading. It returns true if +// there is further result sets, or false if there is no further result set +// or if there is an error advancing to it. The Err method should be consulted +// to distinguish between the two cases. +// +// After calling NextResultSet, the Next method should always be called before +// scanning. If there are further result sets they may not have rows in the result +// set. +func (rs *Rows) NextResultSet() bool { + var doClose bool + defer func() { + if doClose { + rs.Close() + } + }() + rs.closemu.RLock() + defer rs.closemu.RUnlock() + + if rs.closed { + return false + } + + rs.lastcols = nil + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + doClose = true + return false + } + rs.lasterr = nextResultSet.NextResultSet() + if rs.lasterr != nil { + doClose = true return false } return true @@ -1765,6 +2206,8 @@ func (rs *Rows) Next() bool { // Err returns the error, if any, that was encountered during iteration. // Err may be called after an explicit or implicit Close. func (rs *Rows) Err() error { + rs.closemu.RLock() + defer rs.closemu.RUnlock() if rs.lasterr == io.EOF { return nil } @@ -1775,6 +2218,8 @@ func (rs *Rows) Err() error { // Columns returns an error if the rows are closed, or if the rows // are from QueryRow and there was a deferred error. func (rs *Rows) Columns() ([]string, error) { + rs.closemu.RLock() + defer rs.closemu.RUnlock() if rs.closed { return nil, errors.New("sql: Rows are closed") } @@ -1784,6 +2229,109 @@ func (rs *Rows) Columns() ([]string, error) { return rs.rowsi.Columns(), nil } +// ColumnTypes returns column information such as column type, length, +// and nullable. Some information may not be available from some drivers. +func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + if rs.closed { + return nil, errors.New("sql: Rows are closed") + } + if rs.rowsi == nil { + return nil, errors.New("sql: no Rows available") + } + return rowsColumnInfoSetup(rs.rowsi), nil +} + +// ColumnType contains the name and type of a column. +type ColumnType struct { + name string + + hasNullable bool + hasLength bool + hasPrecisionScale bool + + nullable bool + length int64 + databaseType string + precision int64 + scale int64 + scanType reflect.Type +} + +// Name returns the name or alias of the column. +func (ci *ColumnType) Name() string { + return ci.name +} + +// Length returns the column type length for variable length column types such +// as text and binary field types. If the type length is unbounded the value will +// be math.MaxInt64 (any database limits will still apply). +// If the column type is not variable length, such as an int, or if not supported +// by the driver ok is false. +func (ci *ColumnType) Length() (length int64, ok bool) { + return ci.length, ci.hasLength +} + +// DecimalSize returns the scale and precision of a decimal type. +// If not applicable or if not supported ok is false. +func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) { + return ci.precision, ci.scale, ci.hasPrecisionScale +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +// If a driver does not support this property ScanType will return +// the type of an empty interface. +func (ci *ColumnType) ScanType() reflect.Type { + return ci.scanType +} + +// Nullable returns whether the column may be null. +// If a driver does not support this property ok will be false. +func (ci *ColumnType) Nullable() (nullable, ok bool) { + return ci.nullable, ci.hasNullable +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT". +func (ci *ColumnType) DatabaseTypeName() string { + return ci.databaseType +} + +func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { + names := rowsi.Columns() + + list := make([]*ColumnType, len(names)) + for i := range list { + ci := &ColumnType{ + name: names[i], + } + list[i] = ci + + if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok { + ci.scanType = prop.ColumnTypeScanType(i) + } else { + ci.scanType = reflect.TypeOf(new(interface{})).Elem() + } + if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok { + ci.databaseType = prop.ColumnTypeDatabaseTypeName(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok { + ci.length, ci.hasLength = prop.ColumnTypeLength(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok { + ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok { + ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i) + } + } + return list +} + // Scan copies the columns in the current row into the values pointed // at by dest. The number of values in dest must be the same as the // number of columns in Rows. @@ -1836,9 +2384,13 @@ func (rs *Rows) Columns() ([]string, error) { // For scanning into *bool, the source may be true, false, 1, 0, or // string inputs parseable by strconv.ParseBool. func (rs *Rows) Scan(dest ...interface{}) error { + rs.closemu.RLock() if rs.closed { + rs.closemu.RUnlock() return errors.New("sql: Rows are closed") } + rs.closemu.RUnlock() + if rs.lastcols == nil { return errors.New("sql: Scan called without calling Next") } @@ -1854,20 +2406,39 @@ func (rs *Rows) Scan(dest ...interface{}) error { return nil } -var rowsCloseHook func(*Rows, *error) +// rowsCloseHook returns a function so tests may install the +// hook throug a test only mutex. +var rowsCloseHook = func() func(*Rows, *error) { return nil } -// Close closes the Rows, preventing further enumeration. If Next returns -// false, the Rows are closed automatically and it will suffice to check the +// Close closes the Rows, preventing further enumeration. If Next is called +// and returns false and there are no further result sets, +// the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { + return rs.close(nil) +} + +func (rs *Rows) close(err error) error { + rs.closemu.Lock() + defer rs.closemu.Unlock() + if rs.closed { return nil } rs.closed = true - err := rs.rowsi.Close() - if fn := rowsCloseHook; fn != nil { + + if rs.lasterr == nil { + rs.lasterr = err + } + + err = rs.rowsi.Close() + if fn := rowsCloseHook(); fn != nil { fn(rs, &err) } + if rs.cancel != nil { + rs.cancel() + } + if rs.closeStmt != nil { rs.closeStmt.Close() } @@ -1898,8 +2469,8 @@ func (r *Row) Scan(dest ...interface{}) error { // the Rows in our defer, when we return from this function. // the contract with the driver.Next(...) interface is that it // can return slices into read-only temporary memory that's - // only valid until the next Scan/Close. But the TODO is that - // for a lot of drivers, this copy will be unnecessary. We + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We // should provide an optional interface for drivers to // implement to say, "don't worry, the []bytes that I return // from Next will not be modified again." (for instance, if |