diff options
Diffstat (limited to 'Modules/_asynciomodule.c')
-rw-r--r-- | Modules/_asynciomodule.c | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 01c38b80b9..9ac1c44d48 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -26,6 +26,7 @@ static PyObject *all_tasks; static PyObject *current_tasks; static PyObject *traceback_extract_stack; static PyObject *asyncio_get_event_loop_policy; +static PyObject *asyncio_iscoroutine_func; static PyObject *asyncio_future_repr_info_func; static PyObject *asyncio_task_repr_info_func; static PyObject *asyncio_task_get_stack_func; @@ -1461,16 +1462,38 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop) /*[clinic end generated code: output=9f24774c2287fc2f input=8d132974b049593e]*/ { PyObject *res; + int tmp; _Py_IDENTIFIER(add); if (future_init((FutureObj*)self, loop)) { return -1; } + if (!PyCoro_CheckExact(coro)) { + // fastpath failed, perfom slow check + // raise after Future.__init__(), attrs are required for __del__ + res = PyObject_CallFunctionObjArgs(asyncio_iscoroutine_func, + coro, NULL); + if (res == NULL) { + return -1; + } + tmp = PyObject_Not(res); + Py_DECREF(res); + if (tmp < 0) { + return -1; + } + if (tmp) { + self->task_log_destroy_pending = 0; + PyErr_Format(PyExc_TypeError, + "a coroutine was expected, got %R", + coro, NULL); + return -1; + } + } + self->task_fut_waiter = NULL; self->task_must_cancel = 0; self->task_log_destroy_pending = 1; - Py_INCREF(coro); self->task_coro = coro; @@ -2604,6 +2627,7 @@ module_free(void *m) Py_CLEAR(traceback_extract_stack); Py_CLEAR(asyncio_get_event_loop_policy); Py_CLEAR(asyncio_future_repr_info_func); + Py_CLEAR(asyncio_iscoroutine_func); Py_CLEAR(asyncio_task_repr_info_func); Py_CLEAR(asyncio_task_get_stack_func); Py_CLEAR(asyncio_task_print_stack_func); @@ -2645,6 +2669,9 @@ module_init(void) GET_MOD_ATTR(asyncio_task_get_stack_func, "_task_get_stack") GET_MOD_ATTR(asyncio_task_print_stack_func, "_task_print_stack") + WITH_MOD("asyncio.coroutines") + GET_MOD_ATTR(asyncio_iscoroutine_func, "iscoroutine") + WITH_MOD("inspect") GET_MOD_ATTR(inspect_isgenerator, "isgenerator") |