summaryrefslogtreecommitdiff
path: root/Modules/_asynciomodule.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_asynciomodule.c')
-rw-r--r--Modules/_asynciomodule.c29
1 files changed, 28 insertions, 1 deletions
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
index 01c38b80b9..9ac1c44d48 100644
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -26,6 +26,7 @@ static PyObject *all_tasks;
static PyObject *current_tasks;
static PyObject *traceback_extract_stack;
static PyObject *asyncio_get_event_loop_policy;
+static PyObject *asyncio_iscoroutine_func;
static PyObject *asyncio_future_repr_info_func;
static PyObject *asyncio_task_repr_info_func;
static PyObject *asyncio_task_get_stack_func;
@@ -1461,16 +1462,38 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop)
/*[clinic end generated code: output=9f24774c2287fc2f input=8d132974b049593e]*/
{
PyObject *res;
+ int tmp;
_Py_IDENTIFIER(add);
if (future_init((FutureObj*)self, loop)) {
return -1;
}
+ if (!PyCoro_CheckExact(coro)) {
+ // fastpath failed, perfom slow check
+ // raise after Future.__init__(), attrs are required for __del__
+ res = PyObject_CallFunctionObjArgs(asyncio_iscoroutine_func,
+ coro, NULL);
+ if (res == NULL) {
+ return -1;
+ }
+ tmp = PyObject_Not(res);
+ Py_DECREF(res);
+ if (tmp < 0) {
+ return -1;
+ }
+ if (tmp) {
+ self->task_log_destroy_pending = 0;
+ PyErr_Format(PyExc_TypeError,
+ "a coroutine was expected, got %R",
+ coro, NULL);
+ return -1;
+ }
+ }
+
self->task_fut_waiter = NULL;
self->task_must_cancel = 0;
self->task_log_destroy_pending = 1;
-
Py_INCREF(coro);
self->task_coro = coro;
@@ -2604,6 +2627,7 @@ module_free(void *m)
Py_CLEAR(traceback_extract_stack);
Py_CLEAR(asyncio_get_event_loop_policy);
Py_CLEAR(asyncio_future_repr_info_func);
+ Py_CLEAR(asyncio_iscoroutine_func);
Py_CLEAR(asyncio_task_repr_info_func);
Py_CLEAR(asyncio_task_get_stack_func);
Py_CLEAR(asyncio_task_print_stack_func);
@@ -2645,6 +2669,9 @@ module_init(void)
GET_MOD_ATTR(asyncio_task_get_stack_func, "_task_get_stack")
GET_MOD_ATTR(asyncio_task_print_stack_func, "_task_print_stack")
+ WITH_MOD("asyncio.coroutines")
+ GET_MOD_ATTR(asyncio_iscoroutine_func, "iscoroutine")
+
WITH_MOD("inspect")
GET_MOD_ATTR(inspect_isgenerator, "isgenerator")