summaryrefslogtreecommitdiff
path: root/src/database/sql/sql_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql/sql_test.go')
-rw-r--r--src/database/sql/sql_test.go125
1 files changed, 125 insertions, 0 deletions
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 2fd81f29a5..b7fdc8eb6c 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -3191,6 +3191,131 @@ func TestConnectionLeak(t *testing.T) {
wg.Wait()
}
+type nvcDriver struct {
+ fakeDriver
+ skipNamedValueCheck bool
+}
+
+func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
+ c, err := d.fakeDriver.Open(dsn)
+ fc := c.(*fakeConn)
+ fc.db.allowAny = true
+ return &nvcConn{fc, d.skipNamedValueCheck}, err
+}
+
+type nvcConn struct {
+ *fakeConn
+ skipNamedValueCheck bool
+}
+
+type decimal struct {
+ value int
+}
+
+type doNotInclude struct{}
+
+var _ driver.NamedValueChecker = &nvcConn{}
+
+func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
+ if c.skipNamedValueCheck {
+ return driver.ErrSkip
+ }
+ switch v := nv.Value.(type) {
+ default:
+ return driver.ErrSkip
+ case Out:
+ switch ov := v.Dest.(type) {
+ default:
+ return errors.New("unkown NameValueCheck OUTPUT type")
+ case *string:
+ *ov = "from-server"
+ nv.Value = "OUT:*string"
+ }
+ return nil
+ case decimal, []int64:
+ return nil
+ case doNotInclude:
+ return driver.ErrRemoveArgument
+ }
+}
+
+func TestNamedValueChecker(t *testing.T) {
+ Register("NamedValueCheck", &nvcDriver{})
+ db, err := Open("NamedValueCheck", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, err = db.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+
+ o1 := ""
+ _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimal{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{})
+ if err != nil {
+ t.Fatal("exec insert", err)
+ }
+ var (
+ str1 string
+ dec1 decimal
+ arr1 []int64
+ )
+ err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1)
+ if err != nil {
+ t.Fatal("select", err)
+ }
+
+ list := []struct{ got, want interface{} }{
+ {o1, "from-server"},
+ {dec1, decimal{123}},
+ {str1, "hello"},
+ {arr1, []int64{42, 128, 707}},
+ }
+
+ for index, item := range list {
+ if !reflect.DeepEqual(item.got, item.want) {
+ t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
+ }
+ }
+}
+
+func TestNamedValueCheckerSkip(t *testing.T) {
+ Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
+ db, err := Open("NamedValueCheckSkip", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, err = db.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+
+ _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimal{123}))
+ if err == nil {
+ t.Fatalf("expected error with bad argument, got %v", err)
+ }
+}
+
// badConn implements a bad driver.Conn, for TestBadDriver.
// The Exec method panics.
type badConn struct{}