diff options
Diffstat (limited to 'Modules/_sqlite/connection.c')
-rw-r--r-- | Modules/_sqlite/connection.c | 231 |
1 files changed, 197 insertions, 34 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index f65748aa8c..1ce275c2e2 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1,6 +1,6 @@ /* connection.c - the connection type * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -32,6 +32,9 @@ #include "pythread.h" +#define ACTION_FINALIZE 1 +#define ACTION_RESET 2 + static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level); @@ -51,7 +54,7 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject { static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL}; - char* database; + PyObject* database; int detect_types = 0; PyObject* isolation_level = NULL; PyObject* factory = NULL; @@ -59,11 +62,15 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject int cached_statements = 100; double timeout = 5.0; int rc; + PyObject* class_attr = NULL; + PyObject* class_attr_str = NULL; + int is_apsw_connection = 0; + PyObject* database_utf8; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist, &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements)) { - return -1; + return -1; } self->begin_statement = NULL; @@ -77,13 +84,53 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject Py_INCREF(&PyUnicode_Type); self->text_factory = (PyObject*)&PyUnicode_Type; - Py_BEGIN_ALLOW_THREADS - rc = sqlite3_open(database, &self->db); - Py_END_ALLOW_THREADS + if (PyString_Check(database) || PyUnicode_Check(database)) { + if (PyString_Check(database)) { + database_utf8 = database; + Py_INCREF(database_utf8); + } else { + database_utf8 = PyUnicode_AsUTF8String(database); + if (!database_utf8) { + return -1; + } + } - if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); - return -1; + Py_BEGIN_ALLOW_THREADS + rc = sqlite3_open(PyString_AsString(database_utf8), &self->db); + Py_END_ALLOW_THREADS + + Py_DECREF(database_utf8); + + if (rc != SQLITE_OK) { + _pysqlite_seterror(self->db, NULL); + return -1; + } + } else { + /* Create a pysqlite connection from a APSW connection */ + class_attr = PyObject_GetAttrString(database, "__class__"); + if (class_attr) { + class_attr_str = PyObject_Str(class_attr); + if (class_attr_str) { + if (strcmp(PyString_AsString(class_attr_str), "<type 'apsw.Connection'>") == 0) { + /* In the APSW Connection object, the first entry after + * PyObject_HEAD is the sqlite3* we want to get hold of. + * Luckily, this is the same layout as we have in our + * pysqlite_Connection */ + self->db = ((pysqlite_Connection*)database)->db; + + Py_INCREF(database); + self->apsw_connection = database; + is_apsw_connection = 1; + } + } + } + Py_XDECREF(class_attr_str); + Py_XDECREF(class_attr); + + if (!is_apsw_connection) { + PyErr_SetString(PyExc_ValueError, "database parameter must be string or APSW Connection object"); + return -1; + } } if (!isolation_level) { @@ -169,7 +216,8 @@ void pysqlite_flush_statement_cache(pysqlite_Connection* self) self->statement_cache->decref_factory = 0; } -void pysqlite_reset_all_statements(pysqlite_Connection* self) +/* action in (ACTION_RESET, ACTION_FINALIZE) */ +void pysqlite_do_all_statements(pysqlite_Connection* self, int action) { int i; PyObject* weakref; @@ -179,13 +227,19 @@ void pysqlite_reset_all_statements(pysqlite_Connection* self) weakref = PyList_GetItem(self->statements, i); statement = PyWeakref_GetObject(weakref); if (statement != Py_None) { - (void)pysqlite_statement_reset((pysqlite_Statement*)statement); + if (action == ACTION_RESET) { + (void)pysqlite_statement_reset((pysqlite_Statement*)statement); + } else { + (void)pysqlite_statement_finalize((pysqlite_Statement*)statement); + } } } } void pysqlite_connection_dealloc(pysqlite_Connection* self) { + PyObject* ret = NULL; + Py_XDECREF(self->statement_cache); /* Clean up if user has not called .close() explicitly. */ @@ -193,6 +247,10 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self) Py_BEGIN_ALLOW_THREADS sqlite3_close(self->db); Py_END_ALLOW_THREADS + } else if (self->apsw_connection) { + ret = PyObject_CallMethod(self->apsw_connection, "close", ""); + Py_XDECREF(ret); + Py_XDECREF(self->apsw_connection); } if (self->begin_statement) { @@ -205,7 +263,7 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self) Py_XDECREF(self->collations); Py_XDECREF(self->statements); - Py_TYPE(self)->tp_free((PyObject*)self); + self->ob_type->tp_free((PyObject*)self); } PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) @@ -241,24 +299,33 @@ PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args) { + PyObject* ret; int rc; if (!pysqlite_check_thread(self)) { return NULL; } - pysqlite_flush_statement_cache(self); + pysqlite_do_all_statements(self, ACTION_FINALIZE); if (self->db) { - Py_BEGIN_ALLOW_THREADS - rc = sqlite3_close(self->db); - Py_END_ALLOW_THREADS - - if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); - return NULL; - } else { + if (self->apsw_connection) { + ret = PyObject_CallMethod(self->apsw_connection, "close", ""); + Py_XDECREF(ret); + Py_XDECREF(self->apsw_connection); + self->apsw_connection = NULL; self->db = NULL; + } else { + Py_BEGIN_ALLOW_THREADS + rc = sqlite3_close(self->db); + Py_END_ALLOW_THREADS + + if (rc != SQLITE_OK) { + _pysqlite_seterror(self->db, NULL); + return NULL; + } else { + self->db = NULL; + } } } @@ -292,7 +359,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); goto error; } @@ -300,7 +367,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) if (rc == SQLITE_DONE) { self->inTransaction = 1; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS @@ -308,7 +375,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } error: @@ -335,7 +402,7 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args) rc = sqlite3_prepare(self->db, "COMMIT", -1, &statement, &tail); Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto error; } @@ -343,14 +410,14 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args) if (rc == SQLITE_DONE) { self->inTransaction = 0; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS rc = sqlite3_finalize(statement); Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } } @@ -375,13 +442,13 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args } if (self->inTransaction) { - pysqlite_reset_all_statements(self); + pysqlite_do_all_statements(self, ACTION_RESET); Py_BEGIN_ALLOW_THREADS rc = sqlite3_prepare(self->db, "ROLLBACK", -1, &statement, &tail); Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto error; } @@ -389,14 +456,14 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args if (rc == SQLITE_DONE) { self->inTransaction = 0; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS rc = sqlite3_finalize(statement); Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } } @@ -762,6 +829,33 @@ static int _authorizer_callback(void* user_arg, int action, const char* arg1, co return rc; } +static int _progress_handler(void* user_arg) +{ + int rc; + PyObject *ret; + PyGILState_STATE gilstate; + + gilstate = PyGILState_Ensure(); + ret = PyObject_CallFunction((PyObject*)user_arg, ""); + + if (!ret) { + if (_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + + /* abort query if error occured */ + rc = 1; + } else { + rc = (int)PyObject_IsTrue(ret); + } + + Py_DECREF(ret); + PyGILState_Release(gilstate); + return rc; +} + PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) { PyObject* authorizer_cb; @@ -787,6 +881,30 @@ PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject } } +PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) +{ + PyObject* progress_handler; + int n; + + static char *kwlist[] = { "progress_handler", "n", NULL }; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi:set_progress_handler", + kwlist, &progress_handler, &n)) { + return NULL; + } + + if (progress_handler == Py_None) { + /* None clears the progress handler previously set */ + sqlite3_progress_handler(self->db, 0, 0, (void*)0); + } else { + sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); + PyDict_SetItem(self->function_pinboard, progress_handler, Py_None); + } + + Py_INCREF(Py_None); + return Py_None; +} + int pysqlite_check_thread(pysqlite_Connection* self) { if (self->check_same_thread) { @@ -892,7 +1010,8 @@ PyObject* pysqlite_connection_call(pysqlite_Connection* self, PyObject* args, Py } else if (rc == PYSQLITE_SQL_WRONG_TYPE) { PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string or unicode."); } else { - _pysqlite_seterror(self->db); + (void)pysqlite_statement_reset(statement); + _pysqlite_seterror(self->db, NULL); } Py_DECREF(statement); @@ -1134,7 +1253,7 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) (callable != Py_None) ? pysqlite_collation_callback : NULL); if (rc != SQLITE_OK) { PyDict_DelItem(self->collations, uppercase_name); - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto finally; } @@ -1151,6 +1270,44 @@ finally: return retval; } +/* Called when the connection is used as a context manager. Returns itself as a + * convenience to the caller. */ +static PyObject * +pysqlite_connection_enter(pysqlite_Connection* self, PyObject* args) +{ + Py_INCREF(self); + return (PyObject*)self; +} + +/** Called when the connection is used as a context manager. If there was any + * exception, a rollback takes place; otherwise we commit. */ +static PyObject * +pysqlite_connection_exit(pysqlite_Connection* self, PyObject* args) +{ + PyObject* exc_type, *exc_value, *exc_tb; + char* method_name; + PyObject* result; + + if (!PyArg_ParseTuple(args, "OOO", &exc_type, &exc_value, &exc_tb)) { + return NULL; + } + + if (exc_type == Py_None && exc_value == Py_None && exc_tb == Py_None) { + method_name = "commit"; + } else { + method_name = "rollback"; + } + + result = PyObject_CallMethod((PyObject*)self, method_name, ""); + if (!result) { + return NULL; + } + Py_DECREF(result); + + Py_INCREF(Py_False); + return Py_False; +} + static char connection_doc[] = PyDoc_STR("SQLite database connection object."); @@ -1175,6 +1332,8 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Creates a new aggregate. Non-standard.")}, {"set_authorizer", (PyCFunction)pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Sets authorizer callback. Non-standard.")}, + {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS, + PyDoc_STR("Sets progress handler callback. Non-standard.")}, {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS, PyDoc_STR("Executes a SQL statement. Non-standard.")}, {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS, @@ -1185,6 +1344,10 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Creates a collation function. Non-standard.")}, {"interrupt", (PyCFunction)pysqlite_connection_interrupt, METH_NOARGS, PyDoc_STR("Abort any pending database operation. Non-standard.")}, + {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS, + PyDoc_STR("For context manager. Non-standard.")}, + {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS, + PyDoc_STR("For context manager. Non-standard.")}, {NULL, NULL} }; |