diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/code_generators/cversions.txt | 2 | ||||
| -rw-r--r-- | numpy/core/code_generators/numpy_api.py | 1 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/alloc.c | 30 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/alloc.h | 2 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 15 | ||||
| -rw-r--r-- | numpy/core/tests/test_mem_policy.py | 31 |
6 files changed, 44 insertions, 37 deletions
diff --git a/numpy/core/code_generators/cversions.txt b/numpy/core/code_generators/cversions.txt index 38ee4dac2..f0a128d3d 100644 --- a/numpy/core/code_generators/cversions.txt +++ b/numpy/core/code_generators/cversions.txt @@ -59,4 +59,4 @@ 0x0000000e = 17a0f366e55ec05e5c5c149123478452 # Version 15 (NumPy 1.22) Configurable memory allocations -0x0000000f = 0c420aed67010594eb81f23ddfb02a88 +0x0000000f = b8783365b873681cd204be50cdfb448d diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py index 3813c6ad7..d12d62d8f 100644 --- a/numpy/core/code_generators/numpy_api.py +++ b/numpy/core/code_generators/numpy_api.py @@ -19,6 +19,7 @@ from code_generators.genapi import StealRef, NonNull multiarray_global_vars = { 'NPY_NUMUSERTYPES': (7, 'int'), 'NPY_DEFAULT_ASSIGN_CASTING': (292, 'NPY_CASTING'), + 'PyDataMem_DefaultHandler': (306, 'PyObject*'), } multiarray_scalar_bool_values = { diff --git a/numpy/core/src/multiarray/alloc.c b/numpy/core/src/multiarray/alloc.c index d1173410d..0a694cf62 100644 --- a/numpy/core/src/multiarray/alloc.c +++ b/numpy/core/src/multiarray/alloc.c @@ -379,6 +379,8 @@ PyDataMem_Handler default_handler = { default_free /* free */ } }; +/* singleton capsule of the default handler */ +PyObject *PyDataMem_DefaultHandler; #if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600) PyObject *current_handler; @@ -519,16 +521,9 @@ PyDataMem_SetHandler(PyObject *handler) return NULL; } if (handler == NULL) { - handler = PyCapsule_New(&default_handler, "mem_handler", NULL); - if (handler == NULL) { - return NULL; - } - } - else { - Py_INCREF(handler); + handler = PyDataMem_DefaultHandler; } token = PyContextVar_Set(current_handler, handler); - Py_DECREF(handler); if (token == NULL) { Py_DECREF(old_handler); return NULL; @@ -543,26 +538,13 @@ PyDataMem_SetHandler(PyObject *handler) } old_handler = PyDict_GetItemString(p, "current_allocator"); if (old_handler == NULL) { - old_handler = PyCapsule_New(&default_handler, "mem_handler", NULL); - if (old_handler == NULL) { - return NULL; - } - } - else { - Py_INCREF(old_handler); + old_handler = PyDataMem_DefaultHandler } + Py_INCREF(old_handler); if (handler == NULL) { - handler = PyCapsule_New(&default_handler, "mem_handler", NULL); - if (handler == NULL) { - Py_DECREF(old_handler); - return NULL; - } - } - else { - Py_INCREF(handler); + handler = PyDataMem_DefaultHandler; } const int error = PyDict_SetItemString(p, "current_allocator", handler); - Py_DECREF(handler); if (error) { Py_DECREF(old_handler); return NULL; diff --git a/numpy/core/src/multiarray/alloc.h b/numpy/core/src/multiarray/alloc.h index f1ccf0bcd..13c828458 100644 --- a/numpy/core/src/multiarray/alloc.h +++ b/numpy/core/src/multiarray/alloc.h @@ -40,9 +40,9 @@ npy_free_cache_dim_array(PyArrayObject * arr) npy_free_cache_dim(PyArray_DIMS(arr), PyArray_NDIM(arr)); } +extern PyDataMem_Handler default_handler; #if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600) extern PyObject *current_handler; /* PyContextVar/PyCapsule */ -extern PyDataMem_Handler default_handler; #endif NPY_NO_EXPORT PyObject * diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index a854bcb3b..1520ff7ce 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -4919,16 +4919,19 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) { if (initumath(m) != 0) { goto err; } -#if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600) /* - * Initialize the context-local PyDataMem_Handler capsule. + * Initialize the default PyDataMem_Handler capsule singleton. */ - c_api = PyCapsule_New(&default_handler, "mem_handler", NULL); - if (c_api == NULL) { + PyDataMem_DefaultHandler = PyCapsule_New(&default_handler, "mem_handler", NULL); + if (PyDataMem_DefaultHandler == NULL) { goto err; } - current_handler = PyContextVar_New("current_allocator", c_api); - Py_DECREF(c_api); +#if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600) + /* + * Initialize the context-local current handler + * with the default PyDataMem_Handler capsule. + */ + current_handler = PyContextVar_New("current_allocator", PyDataMem_DefaultHandler); if (current_handler == NULL) { goto err; } diff --git a/numpy/core/tests/test_mem_policy.py b/numpy/core/tests/test_mem_policy.py index abf340062..3dae36d5a 100644 --- a/numpy/core/tests/test_mem_policy.py +++ b/numpy/core/tests/test_mem_policy.py @@ -19,6 +19,10 @@ def get_module(tmp_path): if sys.platform.startswith('cygwin'): pytest.skip('link fails on cygwin') functions = [ + ("get_default_policy", "METH_NOARGS", """ + Py_INCREF(PyDataMem_DefaultHandler); + return PyDataMem_DefaultHandler; + """), ("set_secret_data_policy", "METH_NOARGS", """ PyObject *secret_data = PyCapsule_New(&secret_data_handler, "mem_handler", NULL); @@ -37,11 +41,7 @@ def get_module(tmp_path): else { old = PyDataMem_SetHandler(NULL); } - if (old == NULL) { - return NULL; - } - Py_DECREF(old); - Py_RETURN_NONE; + return old; """), ("get_array", "METH_NOARGS", """ char *buf = (char *)malloc(20); @@ -238,6 +238,27 @@ def test_set_policy(get_module): assert get_handler_name() == orig_policy_name +def test_default_policy_singleton(get_module): + get_handler_name = np.core.multiarray.get_handler_name + + # set the policy to default + orig_policy = get_module.set_old_policy(None) + + assert get_handler_name() == 'default_allocator' + + # re-set the policy to default + def_policy_1 = get_module.set_old_policy(None) + + assert get_handler_name() == 'default_allocator' + + # set the policy to original + def_policy_2 = get_module.set_old_policy(orig_policy) + + # since default policy is a singleton, + # these should be the same object + assert def_policy_1 is def_policy_2 is get_module.get_default_policy() + + def test_policy_propagation(get_module): # The memory policy goes hand-in-hand with flags.owndata |
