diff options
Diffstat (limited to 'Lib/sqlite3')
-rw-r--r-- | Lib/sqlite3/__init__.py | 3 | ||||
-rw-r--r-- | Lib/sqlite3/dbapi2.py | 5 | ||||
-rw-r--r-- | Lib/sqlite3/dump.py | 49 | ||||
-rw-r--r-- | Lib/sqlite3/test/dbapi.py | 82 | ||||
-rw-r--r-- | Lib/sqlite3/test/dump.py | 29 | ||||
-rw-r--r-- | Lib/sqlite3/test/factory.py | 37 | ||||
-rw-r--r-- | Lib/sqlite3/test/hooks.py | 24 | ||||
-rw-r--r-- | Lib/sqlite3/test/regression.py | 163 | ||||
-rw-r--r-- | Lib/sqlite3/test/transactions.py | 20 | ||||
-rw-r--r-- | Lib/sqlite3/test/types.py | 8 | ||||
-rw-r--r-- | Lib/sqlite3/test/userfunctions.py | 78 |
11 files changed, 435 insertions, 63 deletions
diff --git a/Lib/sqlite3/__init__.py b/Lib/sqlite3/__init__.py index 4b64833e93..6c91df27cc 100644 --- a/Lib/sqlite3/__init__.py +++ b/Lib/sqlite3/__init__.py @@ -1,7 +1,6 @@ -#-*- coding: ISO-8859-1 -*- # pysqlite2/__init__.py: the pysqlite2 package. # -# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # diff --git a/Lib/sqlite3/dbapi2.py b/Lib/sqlite3/dbapi2.py index d051f0432f..9a0b76645e 100644 --- a/Lib/sqlite3/dbapi2.py +++ b/Lib/sqlite3/dbapi2.py @@ -1,7 +1,6 @@ -#-*- coding: ISO-8859-1 -*- # pysqlite2/dbapi2.py: the DB-API 2.0 interface # -# Copyright (C) 2004-2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2004-2005 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # @@ -68,7 +67,7 @@ def register_adapters_and_converters(): timepart_full = timepart.split(b".") hours, minutes, seconds = map(int, timepart_full[0].split(b":")) if len(timepart_full) == 2: - microseconds = int(timepart_full[1]) + microseconds = int('{:0<6.6}'.format(timepart_full[1].decode())) else: microseconds = 0 diff --git a/Lib/sqlite3/dump.py b/Lib/sqlite3/dump.py index 409a405cf8..de9c368be3 100644 --- a/Lib/sqlite3/dump.py +++ b/Lib/sqlite3/dump.py @@ -1,6 +1,12 @@ # Mimic the sqlite3 console shell's .dump command # Author: Paul Kippes <kippesp@gmail.com> +# Every identifier in sql is quoted based on a comment in sqlite +# documentation "SQLite adds new keywords from time to time when it +# takes on new features. So to prevent your code from being broken by +# future enhancements, you should normally quote any identifier that +# is an English language word, even if you do not have to." + def _iterdump(connection): """ Returns an iterator to the dump of the database in an SQL text format. @@ -15,49 +21,50 @@ def _iterdump(connection): # sqlite_master table contains the SQL CREATE statements for the database. q = """ - SELECT name, type, sql - FROM sqlite_master - WHERE sql NOT NULL AND - type == 'table' + SELECT "name", "type", "sql" + FROM "sqlite_master" + WHERE "sql" NOT NULL AND + "type" == 'table' + ORDER BY "name" """ schema_res = cu.execute(q) for table_name, type, sql in schema_res.fetchall(): if table_name == 'sqlite_sequence': - yield('DELETE FROM sqlite_sequence;') + yield('DELETE FROM "sqlite_sequence";') elif table_name == 'sqlite_stat1': - yield('ANALYZE sqlite_master;') + yield('ANALYZE "sqlite_master";') elif table_name.startswith('sqlite_'): continue # NOTE: Virtual table support not implemented #elif sql.startswith('CREATE VIRTUAL TABLE'): # qtable = table_name.replace("'", "''") # yield("INSERT INTO sqlite_master(type,name,tbl_name,rootpage,sql)"\ - # "VALUES('table','%s','%s',0,'%s');" % - # qtable, + # "VALUES('table','{0}','{0}',0,'{1}');".format( # qtable, - # sql.replace("''")) + # sql.replace("''"))) else: - yield('%s;' % sql) + yield('{0};'.format(sql)) # Build the insert statement for each row of the current table - res = cu.execute("PRAGMA table_info('%s')" % table_name) + table_name_ident = table_name.replace('"', '""') + res = cu.execute('PRAGMA table_info("{0}")'.format(table_name_ident)) column_names = [str(table_info[1]) for table_info in res.fetchall()] - q = "SELECT 'INSERT INTO \"%(tbl_name)s\" VALUES(" - q += ",".join(["'||quote(" + col + ")||'" for col in column_names]) - q += ")' FROM '%(tbl_name)s'" - query_res = cu.execute(q % {'tbl_name': table_name}) + q = """SELECT 'INSERT INTO "{0}" VALUES({1})' FROM "{0}";""".format( + table_name_ident, + ",".join("""'||quote("{0}")||'""".format(col.replace('"', '""')) for col in column_names)) + query_res = cu.execute(q) for row in query_res: - yield("%s;" % row[0]) + yield("{0};".format(row[0])) # Now when the type is 'index', 'trigger', or 'view' q = """ - SELECT name, type, sql - FROM sqlite_master - WHERE sql NOT NULL AND - type IN ('index', 'trigger', 'view') + SELECT "name", "type", "sql" + FROM "sqlite_master" + WHERE "sql" NOT NULL AND + "type" IN ('index', 'trigger', 'view') """ schema_res = cu.execute(q) for name, type, sql in schema_res.fetchall(): - yield('%s;' % sql) + yield('{0};'.format(sql)) yield('COMMIT;') diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index fbf3072341..202bd38876 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/dbapi.py: tests for DB-API compliance # -# Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2004-2010 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # @@ -22,8 +22,11 @@ # 3. This notice may not be removed or altered from any source distribution. import unittest -import threading import sqlite3 as sqlite +try: + import threading +except ImportError: + threading = None class ModuleTests(unittest.TestCase): def CheckAPILevel(self): @@ -81,6 +84,7 @@ class ModuleTests(unittest.TestCase): "NotSupportedError is not a subclass of DatabaseError") class ConnectionTests(unittest.TestCase): + def setUp(self): self.cx = sqlite.connect(":memory:") cu = self.cx.cursor() @@ -137,6 +141,28 @@ class ConnectionTests(unittest.TestCase): self.assertEqual(self.cx.ProgrammingError, sqlite.ProgrammingError) self.assertEqual(self.cx.NotSupportedError, sqlite.NotSupportedError) + def CheckInTransaction(self): + # Can't use db from setUp because we want to test initial state. + cx = sqlite.connect(":memory:") + cu = cx.cursor() + self.assertEqual(cx.in_transaction, False) + cu.execute("create table transactiontest(id integer primary key, name text)") + self.assertEqual(cx.in_transaction, False) + cu.execute("insert into transactiontest(name) values (?)", ("foo",)) + self.assertEqual(cx.in_transaction, True) + cu.execute("select name from transactiontest where name=?", ["foo"]) + row = cu.fetchone() + self.assertEqual(cx.in_transaction, True) + cx.commit() + self.assertEqual(cx.in_transaction, False) + cu.execute("select name from transactiontest where name=?", ["foo"]) + row = cu.fetchone() + self.assertEqual(cx.in_transaction, False) + + def CheckInTransactionRO(self): + with self.assertRaises(AttributeError): + self.cx.in_transaction = True + class CursorTests(unittest.TestCase): def setUp(self): self.cx = sqlite.connect(":memory:") @@ -199,6 +225,13 @@ class CursorTests(unittest.TestCase): def CheckExecuteArgString(self): self.cu.execute("insert into test(name) values (?)", ("Hugo",)) + def CheckExecuteArgStringWithZeroByte(self): + self.cu.execute("insert into test(name) values (?)", ("Hu\x00go",)) + + self.cu.execute("select name from test where id=?", (self.cu.lastrowid,)) + row = self.cu.fetchone() + self.assertEqual(row[0], "Hu\x00go") + def CheckExecuteWrongNoOfArgs1(self): # too many parameters try: @@ -460,6 +493,7 @@ class CursorTests(unittest.TestCase): except TypeError: pass +@unittest.skipUnless(threading, 'This test requires threading.') class ThreadTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") @@ -653,13 +687,13 @@ class ExtensionTests(unittest.TestCase): res = cur.fetchone()[0] self.assertEqual(res, 5) - def CheckScriptErrorIncomplete(self): + def CheckScriptSyntaxError(self): con = sqlite.connect(":memory:") cur = con.cursor() raised = False try: - cur.executescript("create table test(sadfsadfdsa") - except sqlite.ProgrammingError: + cur.executescript("create table test(x); asdf; create table test2(x)") + except sqlite.OperationalError: raised = True self.assertEqual(raised, True, "should have raised an exception") @@ -692,7 +726,7 @@ class ExtensionTests(unittest.TestCase): result = con.execute("select foo from test").fetchone()[0] self.assertEqual(result, 5, "Basic test of Connection.executescript") -class ClosedTests(unittest.TestCase): +class ClosedConTests(unittest.TestCase): def setUp(self): pass @@ -744,7 +778,6 @@ class ClosedTests(unittest.TestCase): except: self.fail("Should have raised a ProgrammingError") - def CheckClosedCreateFunction(self): con = sqlite.connect(":memory:") con.close() @@ -811,6 +844,36 @@ class ClosedTests(unittest.TestCase): except: self.fail("Should have raised a ProgrammingError") +class ClosedCurTests(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def CheckClosed(self): + con = sqlite.connect(":memory:") + cur = con.cursor() + cur.close() + + for method_name in ("execute", "executemany", "executescript", "fetchall", "fetchmany", "fetchone"): + if method_name in ("execute", "executescript"): + params = ("select 4 union select 5",) + elif method_name == "executemany": + params = ("insert into foo(bar) values (?)", [(3,), (4,)]) + else: + params = [] + + try: + method = getattr(cur, method_name) + + method(*params) + self.fail("Should have raised a ProgrammingError: method " + method_name) + except sqlite.ProgrammingError: + pass + except: + self.fail("Should have raised a ProgrammingError: " + method_name) + def suite(): module_suite = unittest.makeSuite(ModuleTests, "Check") connection_suite = unittest.makeSuite(ConnectionTests, "Check") @@ -818,8 +881,9 @@ def suite(): thread_suite = unittest.makeSuite(ThreadTests, "Check") constructor_suite = unittest.makeSuite(ConstructorTests, "Check") ext_suite = unittest.makeSuite(ExtensionTests, "Check") - closed_suite = unittest.makeSuite(ClosedTests, "Check") - return unittest.TestSuite((module_suite, connection_suite, cursor_suite, thread_suite, constructor_suite, ext_suite, closed_suite)) + closed_con_suite = unittest.makeSuite(ClosedConTests, "Check") + closed_cur_suite = unittest.makeSuite(ClosedCurTests, "Check") + return unittest.TestSuite((module_suite, connection_suite, cursor_suite, thread_suite, constructor_suite, ext_suite, closed_con_suite, closed_cur_suite)) def test(): runner = unittest.TextTestRunner() diff --git a/Lib/sqlite3/test/dump.py b/Lib/sqlite3/test/dump.py index f40876a8e7..a1f45a46dc 100644 --- a/Lib/sqlite3/test/dump.py +++ b/Lib/sqlite3/test/dump.py @@ -13,6 +13,14 @@ class DumpTests(unittest.TestCase): def CheckTableDump(self): expected_sqls = [ + """CREATE TABLE "index"("index" blob);""" + , + """INSERT INTO "index" VALUES(X'01');""" + , + """CREATE TABLE "quoted""table"("quoted""field" text);""" + , + """INSERT INTO "quoted""table" VALUES('quoted''value');""" + , "CREATE TABLE t1(id integer primary key, s1 text, " \ "t1_i1 integer not null, i2 integer, unique (s1), " \ "constraint t1_idx1 unique (i2));" @@ -41,6 +49,27 @@ class DumpTests(unittest.TestCase): [self.assertEqual(expected_sqls[i], actual_sqls[i]) for i in range(len(expected_sqls))] + def CheckUnorderableRow(self): + # iterdump() should be able to cope with unorderable row types (issue #15545) + class UnorderableRow: + def __init__(self, cursor, row): + self.row = row + def __getitem__(self, index): + return self.row[index] + self.cx.row_factory = UnorderableRow + CREATE_ALPHA = """CREATE TABLE "alpha" ("one");""" + CREATE_BETA = """CREATE TABLE "beta" ("two");""" + expected = [ + "BEGIN TRANSACTION;", + CREATE_ALPHA, + CREATE_BETA, + "COMMIT;" + ] + self.cu.execute(CREATE_BETA) + self.cu.execute(CREATE_ALPHA) + got = list(self.cx.iterdump()) + self.assertEqual(expected, got) + def suite(): return unittest.TestSuite(unittest.makeSuite(DumpTests, "Check")) diff --git a/Lib/sqlite3/test/factory.py b/Lib/sqlite3/test/factory.py index 1adab2fef3..7f6f3473f3 100644 --- a/Lib/sqlite3/test/factory.py +++ b/Lib/sqlite3/test/factory.py @@ -189,13 +189,48 @@ class TextFactoryTests(unittest.TestCase): def tearDown(self): self.con.close() +class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase): + def setUp(self): + self.con = sqlite.connect(":memory:") + self.con.execute("create table test (value text)") + self.con.execute("insert into test (value) values (?)", ("a\x00b",)) + + def CheckString(self): + # text_factory defaults to str + row = self.con.execute("select value from test").fetchone() + self.assertIs(type(row[0]), str) + self.assertEqual(row[0], "a\x00b") + + def CheckBytes(self): + self.con.text_factory = bytes + row = self.con.execute("select value from test").fetchone() + self.assertIs(type(row[0]), bytes) + self.assertEqual(row[0], b"a\x00b") + + def CheckBytearray(self): + self.con.text_factory = bytearray + row = self.con.execute("select value from test").fetchone() + self.assertIs(type(row[0]), bytearray) + self.assertEqual(row[0], b"a\x00b") + + def CheckCustom(self): + # A custom factory should receive a bytes argument + self.con.text_factory = lambda x: x + row = self.con.execute("select value from test").fetchone() + self.assertIs(type(row[0]), bytes) + self.assertEqual(row[0], b"a\x00b") + + def tearDown(self): + self.con.close() + def suite(): connection_suite = unittest.makeSuite(ConnectionFactoryTests, "Check") cursor_suite = unittest.makeSuite(CursorFactoryTests, "Check") row_suite_compat = unittest.makeSuite(RowFactoryTestsBackwardsCompat, "Check") row_suite = unittest.makeSuite(RowFactoryTests, "Check") text_suite = unittest.makeSuite(TextFactoryTests, "Check") - return unittest.TestSuite((connection_suite, cursor_suite, row_suite_compat, row_suite, text_suite)) + text_zero_bytes_suite = unittest.makeSuite(TextFactoryTestsWithEmbeddedZeroBytes, "Check") + return unittest.TestSuite((connection_suite, cursor_suite, row_suite_compat, row_suite, text_suite, text_zero_bytes_suite)) def test(): runner = unittest.TextTestRunner() diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index a6161fac85..9544149ff6 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -47,9 +47,9 @@ class CollationTests(unittest.TestCase): except sqlite.ProgrammingError as e: pass + @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1), + 'old SQLite versions crash on this test') def CheckCollationIsUsed(self): - if sqlite.version_info < (3, 2, 1): # old SQLite versions crash on this test - return def mycoll(x, y): # reverse order return -((x > y) - (x < y)) @@ -76,6 +76,25 @@ class CollationTests(unittest.TestCase): except sqlite.OperationalError as e: self.assertEqual(e.args[0].lower(), "no such collation sequence: mycoll") + def CheckCollationReturnsLargeInteger(self): + def mycoll(x, y): + # reverse order + return -((x > y) - (x < y)) * 2**32 + con = sqlite.connect(":memory:") + con.create_collation("mycoll", mycoll) + sql = """ + select x from ( + select 'a' as x + union + select 'b' as x + union + select 'c' as x + ) order by x collate mycoll + """ + result = con.execute(sql).fetchall() + self.assertEqual(result, [('c',), ('b',), ('a',)], + msg="the expected order was not returned") + def CheckCollationRegisterTwice(self): """ Register two different collation functions under the same name. @@ -168,6 +187,7 @@ class ProgressTests(unittest.TestCase): con = sqlite.connect(":memory:") action = 0 def progress(): + nonlocal action action = 1 return 0 con.set_progress_handler(progress, 1) diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 5e315fa0fb..5e2fbf9435 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -1,7 +1,7 @@ -#-*- coding: ISO-8859-1 -*- +#-*- coding: iso-8859-1 -*- # pysqlite2/test/regression.py: pysqlite regression tests # -# Copyright (C) 2006 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2006-2010 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # @@ -70,16 +70,6 @@ class RegressionTests(unittest.TestCase): cur.execute('select 1 as "foo baz"') self.assertEqual(cur.description[0][0], "foo baz") - def CheckStatementAvailable(self): - # pysqlite up to 2.3.2 crashed on this, because the active statement handle was not checked - # before trying to fetch data from it. close() destroys the active statement ... - con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES) - cur = con.cursor() - cur.execute("select 4 union select 5") - cur.close() - cur.fetchone() - cur.fetchone() - def CheckStatementFinalizationOnCloseDb(self): # pysqlite versions <= 2.3.3 only finalized statements in the statement # cache when closing the database. statements that were still @@ -169,6 +159,25 @@ class RegressionTests(unittest.TestCase): con = sqlite.connect(":memory:") setattr(con, "isolation_level", "\xe9") + def CheckCursorConstructorCallCheck(self): + """ + Verifies that cursor methods check wether base class __init__ was called. + """ + class Cursor(sqlite.Cursor): + def __init__(self, con): + pass + + con = sqlite.connect(":memory:") + cur = Cursor(con) + try: + cur.execute("select 4+5").fetchall() + self.fail("should have raised ProgrammingError") + except sqlite.ProgrammingError: + pass + except: + self.fail("should have raised ProgrammingError") + + def CheckStrSubclass(self): """ The Python 3.0 port of the module didn't cope with values of subclasses of str. @@ -176,6 +185,88 @@ class RegressionTests(unittest.TestCase): class MyStr(str): pass self.con.execute("select ?", (MyStr("abc"),)) + def CheckConnectionConstructorCallCheck(self): + """ + Verifies that connection methods check wether base class __init__ was called. + """ + class Connection(sqlite.Connection): + def __init__(self, name): + pass + + con = Connection(":memory:") + try: + cur = con.cursor() + self.fail("should have raised ProgrammingError") + except sqlite.ProgrammingError: + pass + except: + self.fail("should have raised ProgrammingError") + + def CheckCursorRegistration(self): + """ + Verifies that subclassed cursor classes are correctly registered with + the connection object, too. (fetch-across-rollback problem) + """ + class Connection(sqlite.Connection): + def cursor(self): + return Cursor(self) + + class Cursor(sqlite.Cursor): + def __init__(self, con): + sqlite.Cursor.__init__(self, con) + + con = Connection(":memory:") + cur = con.cursor() + cur.execute("create table foo(x)") + cur.executemany("insert into foo(x) values (?)", [(3,), (4,), (5,)]) + cur.execute("select x from foo") + con.rollback() + try: + cur.fetchall() + self.fail("should have raised InterfaceError") + except sqlite.InterfaceError: + pass + except: + self.fail("should have raised InterfaceError") + + def CheckAutoCommit(self): + """ + Verifies that creating a connection in autocommit mode works. + 2.5.3 introduced a regression so that these could no longer + be created. + """ + con = sqlite.connect(":memory:", isolation_level=None) + + def CheckPragmaAutocommit(self): + """ + Verifies that running a PRAGMA statement that does an autocommit does + work. This did not work in 2.5.3/2.5.4. + """ + cur = self.con.cursor() + cur.execute("create table foo(bar)") + cur.execute("insert into foo(bar) values (5)") + + cur.execute("pragma page_size") + row = cur.fetchone() + + def CheckSetDict(self): + """ + See http://bugs.python.org/issue7478 + + It was possible to successfully register callbacks that could not be + hashed. Return codes of PyDict_SetItem were not checked properly. + """ + class NotHashable: + def __call__(self, *args, **kw): + pass + def __hash__(self): + raise TypeError() + var = NotHashable() + self.assertRaises(TypeError, self.con.create_function, var) + self.assertRaises(TypeError, self.con.create_aggregate, var) + self.assertRaises(TypeError, self.con.set_authorizer, var) + self.assertRaises(TypeError, self.con.set_progress_handler, var) + def CheckConnectionCall(self): """ Call a connection with a non-string SQL request: check error handling @@ -190,6 +281,54 @@ class RegressionTests(unittest.TestCase): # Lone surrogate cannot be encoded to the default encoding (utf8) "\uDC80", collation_cb) + def CheckRecursiveCursorUse(self): + """ + http://bugs.python.org/issue10811 + + Recursively using a cursor, such as when reusing it from a generator led to segfaults. + Now we catch recursive cursor usage and raise a ProgrammingError. + """ + con = sqlite.connect(":memory:") + + cur = con.cursor() + cur.execute("create table a (bar)") + cur.execute("create table b (baz)") + + def foo(): + cur.execute("insert into a (bar) values (?)", (1,)) + yield 1 + + with self.assertRaises(sqlite.ProgrammingError): + cur.executemany("insert into b (baz) values (?)", + ((i,) for i in foo())) + + def CheckConvertTimestampMicrosecondPadding(self): + """ + http://bugs.python.org/issue14720 + + The microsecond parsing of convert_timestamp() should pad with zeros, + since the microsecond string "456" actually represents "456000". + """ + + con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES) + cur = con.cursor() + cur.execute("CREATE TABLE t (x TIMESTAMP)") + + # Microseconds should be 456000 + cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')") + + # Microseconds should be truncated to 123456 + cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')") + + cur.execute("SELECT * FROM t") + values = [x[0] for x in cur.fetchall()] + + self.assertEqual(values, [ + datetime.datetime(2012, 4, 4, 15, 6, 0, 456000), + datetime.datetime(2012, 4, 4, 15, 6, 0, 123456), + ]) + + def suite(): regression_suite = unittest.makeSuite(RegressionTests, "Check") return unittest.TestSuite((regression_suite,)) diff --git a/Lib/sqlite3/test/transactions.py b/Lib/sqlite3/test/transactions.py index c9f6125560..70e96a12ed 100644 --- a/Lib/sqlite3/test/transactions.py +++ b/Lib/sqlite3/test/transactions.py @@ -147,6 +147,26 @@ class TransactionTests(unittest.TestCase): # NO self.con2.rollback() HERE!!! self.con1.commit() + def CheckRollbackCursorConsistency(self): + """ + Checks if cursors on the connection are set into a "reset" state + when a rollback is done on the connection. + """ + con = sqlite.connect(":memory:") + cur = con.cursor() + cur.execute("create table test(x)") + cur.execute("insert into test(x) values (5)") + cur.execute("select 1 union select 2 union select 3") + + con.rollback() + try: + cur.fetchall() + self.fail("InterfaceError should have been raised") + except sqlite.InterfaceError as e: + pass + except: + self.fail("InterfaceError should have been raised") + class SpecialCommandTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 0940e9b28d..29413e14ec 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -21,9 +21,14 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. -import zlib, datetime +import datetime import unittest import sqlite3 as sqlite +try: + import zlib +except ImportError: + zlib = None + class SqliteTypeTests(unittest.TestCase): def setUp(self): @@ -312,6 +317,7 @@ class ObjectAdaptationTests(unittest.TestCase): val = self.cur.fetchone()[0] self.assertEqual(type(val), float) +@unittest.skipUnless(zlib, "requires zlib") class BinaryConverterTests(unittest.TestCase): def convert(s): return zlib.decompress(s) diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 8bfc591687..9a6a828d81 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -37,6 +37,8 @@ def func_returnnull(): return None def func_returnblob(): return b"blob" +def func_returnlonglong(): + return 1<<31 def func_raiseexception(): 5/0 @@ -50,6 +52,8 @@ def func_isnone(v): return type(v) is type(None) def func_isblob(v): return isinstance(v, (bytes, memoryview)) +def func_islonglong(v): + return isinstance(v, int) and v >= 1<<31 class AggrNoStep: def __init__(self): @@ -127,6 +131,7 @@ class FunctionTests(unittest.TestCase): self.con.create_function("returnfloat", 0, func_returnfloat) self.con.create_function("returnnull", 0, func_returnnull) self.con.create_function("returnblob", 0, func_returnblob) + self.con.create_function("returnlonglong", 0, func_returnlonglong) self.con.create_function("raiseexception", 0, func_raiseexception) self.con.create_function("isstring", 1, func_isstring) @@ -134,6 +139,7 @@ class FunctionTests(unittest.TestCase): self.con.create_function("isfloat", 1, func_isfloat) self.con.create_function("isnone", 1, func_isnone) self.con.create_function("isblob", 1, func_isblob) + self.con.create_function("islonglong", 1, func_islonglong) def tearDown(self): self.con.close() @@ -200,6 +206,12 @@ class FunctionTests(unittest.TestCase): self.assertEqual(type(val), bytes) self.assertEqual(val, b"blob") + def CheckFuncReturnLongLong(self): + cur = self.con.cursor() + cur.execute("select returnlonglong()") + val = cur.fetchone()[0] + self.assertEqual(val, 1<<31) + def CheckFuncException(self): cur = self.con.cursor() try: @@ -239,6 +251,12 @@ class FunctionTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(val, 1) + def CheckParamLongLong(self): + cur = self.con.cursor() + cur.execute("select islonglong(?)", (1<<42,)) + val = cur.fetchone()[0] + self.assertEqual(val, 1) + class AggregateTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") @@ -357,14 +375,15 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(val, 60) -def authorizer_cb(action, arg1, arg2, dbname, source): - if action != sqlite.SQLITE_SELECT: - return sqlite.SQLITE_DENY - if arg2 == 'c2' or arg1 == 't2': - return sqlite.SQLITE_DENY - return sqlite.SQLITE_OK - class AuthorizerTests(unittest.TestCase): + @staticmethod + def authorizer_cb(action, arg1, arg2, dbname, source): + if action != sqlite.SQLITE_SELECT: + return sqlite.SQLITE_DENY + if arg2 == 'c2' or arg1 == 't2': + return sqlite.SQLITE_DENY + return sqlite.SQLITE_OK + def setUp(self): self.con = sqlite.connect(":memory:") self.con.executescript(""" @@ -377,12 +396,12 @@ class AuthorizerTests(unittest.TestCase): # For our security test: self.con.execute("select c2 from t2") - self.con.set_authorizer(authorizer_cb) + self.con.set_authorizer(self.authorizer_cb) def tearDown(self): pass - def CheckTableAccess(self): + def test_table_access(self): try: self.con.execute("select * from t2") except sqlite.DatabaseError as e: @@ -391,7 +410,7 @@ class AuthorizerTests(unittest.TestCase): return self.fail("should have raised an exception due to missing privileges") - def CheckColumnAccess(self): + def test_column_access(self): try: self.con.execute("select c2 from t1") except sqlite.DatabaseError as e: @@ -400,11 +419,46 @@ class AuthorizerTests(unittest.TestCase): return self.fail("should have raised an exception due to missing privileges") +class AuthorizerRaiseExceptionTests(AuthorizerTests): + @staticmethod + def authorizer_cb(action, arg1, arg2, dbname, source): + if action != sqlite.SQLITE_SELECT: + raise ValueError + if arg2 == 'c2' or arg1 == 't2': + raise ValueError + return sqlite.SQLITE_OK + +class AuthorizerIllegalTypeTests(AuthorizerTests): + @staticmethod + def authorizer_cb(action, arg1, arg2, dbname, source): + if action != sqlite.SQLITE_SELECT: + return 0.0 + if arg2 == 'c2' or arg1 == 't2': + return 0.0 + return sqlite.SQLITE_OK + +class AuthorizerLargeIntegerTests(AuthorizerTests): + @staticmethod + def authorizer_cb(action, arg1, arg2, dbname, source): + if action != sqlite.SQLITE_SELECT: + return 2**32 + if arg2 == 'c2' or arg1 == 't2': + return 2**32 + return sqlite.SQLITE_OK + + def suite(): function_suite = unittest.makeSuite(FunctionTests, "Check") aggregate_suite = unittest.makeSuite(AggregateTests, "Check") - authorizer_suite = unittest.makeSuite(AuthorizerTests, "Check") - return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite)) + authorizer_suite = unittest.makeSuite(AuthorizerTests) + return unittest.TestSuite(( + function_suite, + aggregate_suite, + authorizer_suite, + unittest.makeSuite(AuthorizerRaiseExceptionTests), + unittest.makeSuite(AuthorizerIllegalTypeTests), + unittest.makeSuite(AuthorizerLargeIntegerTests), + )) def test(): runner = unittest.TextTestRunner() |