summaryrefslogtreecommitdiff
path: root/Modules/_sqlite/connection.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_sqlite/connection.c')
-rw-r--r--Modules/_sqlite/connection.c231
1 files changed, 197 insertions, 34 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index f65748aa8c..1ce275c2e2 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -1,6 +1,6 @@
/* connection.c - the connection type
*
- * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de>
+ * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
@@ -32,6 +32,9 @@
#include "pythread.h"
+#define ACTION_FINALIZE 1
+#define ACTION_RESET 2
+
static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level);
@@ -51,7 +54,7 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
{
static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL};
- char* database;
+ PyObject* database;
int detect_types = 0;
PyObject* isolation_level = NULL;
PyObject* factory = NULL;
@@ -59,11 +62,15 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
int cached_statements = 100;
double timeout = 5.0;
int rc;
+ PyObject* class_attr = NULL;
+ PyObject* class_attr_str = NULL;
+ int is_apsw_connection = 0;
+ PyObject* database_utf8;
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist,
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist,
&database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements))
{
- return -1;
+ return -1;
}
self->begin_statement = NULL;
@@ -77,13 +84,53 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
Py_INCREF(&PyUnicode_Type);
self->text_factory = (PyObject*)&PyUnicode_Type;
- Py_BEGIN_ALLOW_THREADS
- rc = sqlite3_open(database, &self->db);
- Py_END_ALLOW_THREADS
+ if (PyString_Check(database) || PyUnicode_Check(database)) {
+ if (PyString_Check(database)) {
+ database_utf8 = database;
+ Py_INCREF(database_utf8);
+ } else {
+ database_utf8 = PyUnicode_AsUTF8String(database);
+ if (!database_utf8) {
+ return -1;
+ }
+ }
- if (rc != SQLITE_OK) {
- _pysqlite_seterror(self->db);
- return -1;
+ Py_BEGIN_ALLOW_THREADS
+ rc = sqlite3_open(PyString_AsString(database_utf8), &self->db);
+ Py_END_ALLOW_THREADS
+
+ Py_DECREF(database_utf8);
+
+ if (rc != SQLITE_OK) {
+ _pysqlite_seterror(self->db, NULL);
+ return -1;
+ }
+ } else {
+ /* Create a pysqlite connection from a APSW connection */
+ class_attr = PyObject_GetAttrString(database, "__class__");
+ if (class_attr) {
+ class_attr_str = PyObject_Str(class_attr);
+ if (class_attr_str) {
+ if (strcmp(PyString_AsString(class_attr_str), "<type 'apsw.Connection'>") == 0) {
+ /* In the APSW Connection object, the first entry after
+ * PyObject_HEAD is the sqlite3* we want to get hold of.
+ * Luckily, this is the same layout as we have in our
+ * pysqlite_Connection */
+ self->db = ((pysqlite_Connection*)database)->db;
+
+ Py_INCREF(database);
+ self->apsw_connection = database;
+ is_apsw_connection = 1;
+ }
+ }
+ }
+ Py_XDECREF(class_attr_str);
+ Py_XDECREF(class_attr);
+
+ if (!is_apsw_connection) {
+ PyErr_SetString(PyExc_ValueError, "database parameter must be string or APSW Connection object");
+ return -1;
+ }
}
if (!isolation_level) {
@@ -169,7 +216,8 @@ void pysqlite_flush_statement_cache(pysqlite_Connection* self)
self->statement_cache->decref_factory = 0;
}
-void pysqlite_reset_all_statements(pysqlite_Connection* self)
+/* action in (ACTION_RESET, ACTION_FINALIZE) */
+void pysqlite_do_all_statements(pysqlite_Connection* self, int action)
{
int i;
PyObject* weakref;
@@ -179,13 +227,19 @@ void pysqlite_reset_all_statements(pysqlite_Connection* self)
weakref = PyList_GetItem(self->statements, i);
statement = PyWeakref_GetObject(weakref);
if (statement != Py_None) {
- (void)pysqlite_statement_reset((pysqlite_Statement*)statement);
+ if (action == ACTION_RESET) {
+ (void)pysqlite_statement_reset((pysqlite_Statement*)statement);
+ } else {
+ (void)pysqlite_statement_finalize((pysqlite_Statement*)statement);
+ }
}
}
}
void pysqlite_connection_dealloc(pysqlite_Connection* self)
{
+ PyObject* ret = NULL;
+
Py_XDECREF(self->statement_cache);
/* Clean up if user has not called .close() explicitly. */
@@ -193,6 +247,10 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)
Py_BEGIN_ALLOW_THREADS
sqlite3_close(self->db);
Py_END_ALLOW_THREADS
+ } else if (self->apsw_connection) {
+ ret = PyObject_CallMethod(self->apsw_connection, "close", "");
+ Py_XDECREF(ret);
+ Py_XDECREF(self->apsw_connection);
}
if (self->begin_statement) {
@@ -205,7 +263,7 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)
Py_XDECREF(self->collations);
Py_XDECREF(self->statements);
- Py_TYPE(self)->tp_free((PyObject*)self);
+ self->ob_type->tp_free((PyObject*)self);
}
PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
@@ -241,24 +299,33 @@ PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args,
PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)
{
+ PyObject* ret;
int rc;
if (!pysqlite_check_thread(self)) {
return NULL;
}
- pysqlite_flush_statement_cache(self);
+ pysqlite_do_all_statements(self, ACTION_FINALIZE);
if (self->db) {
- Py_BEGIN_ALLOW_THREADS
- rc = sqlite3_close(self->db);
- Py_END_ALLOW_THREADS
-
- if (rc != SQLITE_OK) {
- _pysqlite_seterror(self->db);
- return NULL;
- } else {
+ if (self->apsw_connection) {
+ ret = PyObject_CallMethod(self->apsw_connection, "close", "");
+ Py_XDECREF(ret);
+ Py_XDECREF(self->apsw_connection);
+ self->apsw_connection = NULL;
self->db = NULL;
+ } else {
+ Py_BEGIN_ALLOW_THREADS
+ rc = sqlite3_close(self->db);
+ Py_END_ALLOW_THREADS
+
+ if (rc != SQLITE_OK) {
+ _pysqlite_seterror(self->db, NULL);
+ return NULL;
+ } else {
+ self->db = NULL;
+ }
}
}
@@ -292,7 +359,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self)
Py_END_ALLOW_THREADS
if (rc != SQLITE_OK) {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, statement);
goto error;
}
@@ -300,7 +367,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self)
if (rc == SQLITE_DONE) {
self->inTransaction = 1;
} else {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, statement);
}
Py_BEGIN_ALLOW_THREADS
@@ -308,7 +375,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self)
Py_END_ALLOW_THREADS
if (rc != SQLITE_OK && !PyErr_Occurred()) {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, NULL);
}
error:
@@ -335,7 +402,7 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args)
rc = sqlite3_prepare(self->db, "COMMIT", -1, &statement, &tail);
Py_END_ALLOW_THREADS
if (rc != SQLITE_OK) {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, NULL);
goto error;
}
@@ -343,14 +410,14 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args)
if (rc == SQLITE_DONE) {
self->inTransaction = 0;
} else {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, statement);
}
Py_BEGIN_ALLOW_THREADS
rc = sqlite3_finalize(statement);
Py_END_ALLOW_THREADS
if (rc != SQLITE_OK && !PyErr_Occurred()) {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, NULL);
}
}
@@ -375,13 +442,13 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args
}
if (self->inTransaction) {
- pysqlite_reset_all_statements(self);
+ pysqlite_do_all_statements(self, ACTION_RESET);
Py_BEGIN_ALLOW_THREADS
rc = sqlite3_prepare(self->db, "ROLLBACK", -1, &statement, &tail);
Py_END_ALLOW_THREADS
if (rc != SQLITE_OK) {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, NULL);
goto error;
}
@@ -389,14 +456,14 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args
if (rc == SQLITE_DONE) {
self->inTransaction = 0;
} else {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, statement);
}
Py_BEGIN_ALLOW_THREADS
rc = sqlite3_finalize(statement);
Py_END_ALLOW_THREADS
if (rc != SQLITE_OK && !PyErr_Occurred()) {
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, NULL);
}
}
@@ -762,6 +829,33 @@ static int _authorizer_callback(void* user_arg, int action, const char* arg1, co
return rc;
}
+static int _progress_handler(void* user_arg)
+{
+ int rc;
+ PyObject *ret;
+ PyGILState_STATE gilstate;
+
+ gilstate = PyGILState_Ensure();
+ ret = PyObject_CallFunction((PyObject*)user_arg, "");
+
+ if (!ret) {
+ if (_enable_callback_tracebacks) {
+ PyErr_Print();
+ } else {
+ PyErr_Clear();
+ }
+
+ /* abort query if error occured */
+ rc = 1;
+ } else {
+ rc = (int)PyObject_IsTrue(ret);
+ }
+
+ Py_DECREF(ret);
+ PyGILState_Release(gilstate);
+ return rc;
+}
+
PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
PyObject* authorizer_cb;
@@ -787,6 +881,30 @@ PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject
}
}
+PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
+{
+ PyObject* progress_handler;
+ int n;
+
+ static char *kwlist[] = { "progress_handler", "n", NULL };
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi:set_progress_handler",
+ kwlist, &progress_handler, &n)) {
+ return NULL;
+ }
+
+ if (progress_handler == Py_None) {
+ /* None clears the progress handler previously set */
+ sqlite3_progress_handler(self->db, 0, 0, (void*)0);
+ } else {
+ sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler);
+ PyDict_SetItem(self->function_pinboard, progress_handler, Py_None);
+ }
+
+ Py_INCREF(Py_None);
+ return Py_None;
+}
+
int pysqlite_check_thread(pysqlite_Connection* self)
{
if (self->check_same_thread) {
@@ -892,7 +1010,8 @@ PyObject* pysqlite_connection_call(pysqlite_Connection* self, PyObject* args, Py
} else if (rc == PYSQLITE_SQL_WRONG_TYPE) {
PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string or unicode.");
} else {
- _pysqlite_seterror(self->db);
+ (void)pysqlite_statement_reset(statement);
+ _pysqlite_seterror(self->db, NULL);
}
Py_DECREF(statement);
@@ -1134,7 +1253,7 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)
(callable != Py_None) ? pysqlite_collation_callback : NULL);
if (rc != SQLITE_OK) {
PyDict_DelItem(self->collations, uppercase_name);
- _pysqlite_seterror(self->db);
+ _pysqlite_seterror(self->db, NULL);
goto finally;
}
@@ -1151,6 +1270,44 @@ finally:
return retval;
}
+/* Called when the connection is used as a context manager. Returns itself as a
+ * convenience to the caller. */
+static PyObject *
+pysqlite_connection_enter(pysqlite_Connection* self, PyObject* args)
+{
+ Py_INCREF(self);
+ return (PyObject*)self;
+}
+
+/** Called when the connection is used as a context manager. If there was any
+ * exception, a rollback takes place; otherwise we commit. */
+static PyObject *
+pysqlite_connection_exit(pysqlite_Connection* self, PyObject* args)
+{
+ PyObject* exc_type, *exc_value, *exc_tb;
+ char* method_name;
+ PyObject* result;
+
+ if (!PyArg_ParseTuple(args, "OOO", &exc_type, &exc_value, &exc_tb)) {
+ return NULL;
+ }
+
+ if (exc_type == Py_None && exc_value == Py_None && exc_tb == Py_None) {
+ method_name = "commit";
+ } else {
+ method_name = "rollback";
+ }
+
+ result = PyObject_CallMethod((PyObject*)self, method_name, "");
+ if (!result) {
+ return NULL;
+ }
+ Py_DECREF(result);
+
+ Py_INCREF(Py_False);
+ return Py_False;
+}
+
static char connection_doc[] =
PyDoc_STR("SQLite database connection object.");
@@ -1175,6 +1332,8 @@ static PyMethodDef connection_methods[] = {
PyDoc_STR("Creates a new aggregate. Non-standard.")},
{"set_authorizer", (PyCFunction)pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS,
PyDoc_STR("Sets authorizer callback. Non-standard.")},
+ {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS,
+ PyDoc_STR("Sets progress handler callback. Non-standard.")},
{"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS,
PyDoc_STR("Executes a SQL statement. Non-standard.")},
{"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS,
@@ -1185,6 +1344,10 @@ static PyMethodDef connection_methods[] = {
PyDoc_STR("Creates a collation function. Non-standard.")},
{"interrupt", (PyCFunction)pysqlite_connection_interrupt, METH_NOARGS,
PyDoc_STR("Abort any pending database operation. Non-standard.")},
+ {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS,
+ PyDoc_STR("For context manager. Non-standard.")},
+ {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS,
+ PyDoc_STR("For context manager. Non-standard.")},
{NULL, NULL}
};