summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/code_generators/cversions.txt2
-rw-r--r--numpy/core/code_generators/numpy_api.py1
-rw-r--r--numpy/core/src/multiarray/alloc.c30
-rw-r--r--numpy/core/src/multiarray/alloc.h2
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c15
-rw-r--r--numpy/core/tests/test_mem_policy.py31
6 files changed, 44 insertions, 37 deletions
diff --git a/numpy/core/code_generators/cversions.txt b/numpy/core/code_generators/cversions.txt
index 38ee4dac2..f0a128d3d 100644
--- a/numpy/core/code_generators/cversions.txt
+++ b/numpy/core/code_generators/cversions.txt
@@ -59,4 +59,4 @@
0x0000000e = 17a0f366e55ec05e5c5c149123478452
# Version 15 (NumPy 1.22) Configurable memory allocations
-0x0000000f = 0c420aed67010594eb81f23ddfb02a88
+0x0000000f = b8783365b873681cd204be50cdfb448d
diff --git a/numpy/core/code_generators/numpy_api.py b/numpy/core/code_generators/numpy_api.py
index 3813c6ad7..d12d62d8f 100644
--- a/numpy/core/code_generators/numpy_api.py
+++ b/numpy/core/code_generators/numpy_api.py
@@ -19,6 +19,7 @@ from code_generators.genapi import StealRef, NonNull
multiarray_global_vars = {
'NPY_NUMUSERTYPES': (7, 'int'),
'NPY_DEFAULT_ASSIGN_CASTING': (292, 'NPY_CASTING'),
+ 'PyDataMem_DefaultHandler': (306, 'PyObject*'),
}
multiarray_scalar_bool_values = {
diff --git a/numpy/core/src/multiarray/alloc.c b/numpy/core/src/multiarray/alloc.c
index d1173410d..0a694cf62 100644
--- a/numpy/core/src/multiarray/alloc.c
+++ b/numpy/core/src/multiarray/alloc.c
@@ -379,6 +379,8 @@ PyDataMem_Handler default_handler = {
default_free /* free */
}
};
+/* singleton capsule of the default handler */
+PyObject *PyDataMem_DefaultHandler;
#if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600)
PyObject *current_handler;
@@ -519,16 +521,9 @@ PyDataMem_SetHandler(PyObject *handler)
return NULL;
}
if (handler == NULL) {
- handler = PyCapsule_New(&default_handler, "mem_handler", NULL);
- if (handler == NULL) {
- return NULL;
- }
- }
- else {
- Py_INCREF(handler);
+ handler = PyDataMem_DefaultHandler;
}
token = PyContextVar_Set(current_handler, handler);
- Py_DECREF(handler);
if (token == NULL) {
Py_DECREF(old_handler);
return NULL;
@@ -543,26 +538,13 @@ PyDataMem_SetHandler(PyObject *handler)
}
old_handler = PyDict_GetItemString(p, "current_allocator");
if (old_handler == NULL) {
- old_handler = PyCapsule_New(&default_handler, "mem_handler", NULL);
- if (old_handler == NULL) {
- return NULL;
- }
- }
- else {
- Py_INCREF(old_handler);
+ old_handler = PyDataMem_DefaultHandler
}
+ Py_INCREF(old_handler);
if (handler == NULL) {
- handler = PyCapsule_New(&default_handler, "mem_handler", NULL);
- if (handler == NULL) {
- Py_DECREF(old_handler);
- return NULL;
- }
- }
- else {
- Py_INCREF(handler);
+ handler = PyDataMem_DefaultHandler;
}
const int error = PyDict_SetItemString(p, "current_allocator", handler);
- Py_DECREF(handler);
if (error) {
Py_DECREF(old_handler);
return NULL;
diff --git a/numpy/core/src/multiarray/alloc.h b/numpy/core/src/multiarray/alloc.h
index f1ccf0bcd..13c828458 100644
--- a/numpy/core/src/multiarray/alloc.h
+++ b/numpy/core/src/multiarray/alloc.h
@@ -40,9 +40,9 @@ npy_free_cache_dim_array(PyArrayObject * arr)
npy_free_cache_dim(PyArray_DIMS(arr), PyArray_NDIM(arr));
}
+extern PyDataMem_Handler default_handler;
#if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600)
extern PyObject *current_handler; /* PyContextVar/PyCapsule */
-extern PyDataMem_Handler default_handler;
#endif
NPY_NO_EXPORT PyObject *
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index a854bcb3b..1520ff7ce 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4919,16 +4919,19 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
if (initumath(m) != 0) {
goto err;
}
-#if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600)
/*
- * Initialize the context-local PyDataMem_Handler capsule.
+ * Initialize the default PyDataMem_Handler capsule singleton.
*/
- c_api = PyCapsule_New(&default_handler, "mem_handler", NULL);
- if (c_api == NULL) {
+ PyDataMem_DefaultHandler = PyCapsule_New(&default_handler, "mem_handler", NULL);
+ if (PyDataMem_DefaultHandler == NULL) {
goto err;
}
- current_handler = PyContextVar_New("current_allocator", c_api);
- Py_DECREF(c_api);
+#if (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x07030600)
+ /*
+ * Initialize the context-local current handler
+ * with the default PyDataMem_Handler capsule.
+ */
+ current_handler = PyContextVar_New("current_allocator", PyDataMem_DefaultHandler);
if (current_handler == NULL) {
goto err;
}
diff --git a/numpy/core/tests/test_mem_policy.py b/numpy/core/tests/test_mem_policy.py
index abf340062..3dae36d5a 100644
--- a/numpy/core/tests/test_mem_policy.py
+++ b/numpy/core/tests/test_mem_policy.py
@@ -19,6 +19,10 @@ def get_module(tmp_path):
if sys.platform.startswith('cygwin'):
pytest.skip('link fails on cygwin')
functions = [
+ ("get_default_policy", "METH_NOARGS", """
+ Py_INCREF(PyDataMem_DefaultHandler);
+ return PyDataMem_DefaultHandler;
+ """),
("set_secret_data_policy", "METH_NOARGS", """
PyObject *secret_data =
PyCapsule_New(&secret_data_handler, "mem_handler", NULL);
@@ -37,11 +41,7 @@ def get_module(tmp_path):
else {
old = PyDataMem_SetHandler(NULL);
}
- if (old == NULL) {
- return NULL;
- }
- Py_DECREF(old);
- Py_RETURN_NONE;
+ return old;
"""),
("get_array", "METH_NOARGS", """
char *buf = (char *)malloc(20);
@@ -238,6 +238,27 @@ def test_set_policy(get_module):
assert get_handler_name() == orig_policy_name
+def test_default_policy_singleton(get_module):
+ get_handler_name = np.core.multiarray.get_handler_name
+
+ # set the policy to default
+ orig_policy = get_module.set_old_policy(None)
+
+ assert get_handler_name() == 'default_allocator'
+
+ # re-set the policy to default
+ def_policy_1 = get_module.set_old_policy(None)
+
+ assert get_handler_name() == 'default_allocator'
+
+ # set the policy to original
+ def_policy_2 = get_module.set_old_policy(orig_policy)
+
+ # since default policy is a singleton,
+ # these should be the same object
+ assert def_policy_1 is def_policy_2 is get_module.get_default_policy()
+
+
def test_policy_propagation(get_module):
# The memory policy goes hand-in-hand with flags.owndata