summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-04-17 23:44:21 +0100
committerEric Wieser <wieser.eric@gmail.com>2018-11-12 09:14:50 -0800
commitaedc758ad6d1db4f05c7b60f8ac29018c9152999 (patch)
treefbe027e920af7668f0f3b385ed153d8957b8fc83 /numpy/core
parentd3b2036949255e48ecbcfcc70ed2ea95c755cf2a (diff)
downloadnumpy-aedc758ad6d1db4f05c7b60f8ac29018c9152999.tar.gz
ENH: Correct identities for logical ufuncs and logaddexp
Fixes #7702
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/code_generators/generate_umath.py55
-rw-r--r--numpy/core/tests/test_umath.py4
2 files changed, 43 insertions, 16 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index 199ad831b..9d4e72c0e 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -10,11 +10,14 @@ sys.path.insert(0, os.path.dirname(__file__))
import ufunc_docstrings as docstrings
sys.path.pop(0)
-Zero = "PyUFunc_Zero"
-One = "PyUFunc_One"
-None_ = "PyUFunc_None"
-AllOnes = "PyUFunc_MinusOne"
-ReorderableNone = "PyUFunc_ReorderableNone"
+Zero = "PyInt_FromLong(0)"
+One = "PyInt_FromLong(1)"
+True_ = "(Py_INCREF(Py_True), Py_True)"
+False_ = "(Py_INCREF(Py_False), Py_False)"
+None_ = object()
+AllOnes = "PyInt_FromLong(-1)"
+MinusInfinity = 'PyFloat_FromDouble(-NPY_INFINITY)'
+ReorderableNone = "(Py_INCREF(Py_None), Py_None)"
# Sentinel value to specify using the full type description in the
# function name
@@ -458,7 +461,7 @@ defdict = {
[TypeDescription('O', FullTypeDescr, 'OO', 'O')],
),
'logical_and':
- Ufunc(2, 1, One,
+ Ufunc(2, 1, True_,
docstrings.get('numpy.core.umath.logical_and'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(nodatetime_or_obj, out='?', simd=[('avx2', ints)]),
@@ -472,14 +475,14 @@ defdict = {
TD(O, f='npy_ObjectLogicalNot'),
),
'logical_or':
- Ufunc(2, 1, Zero,
+ Ufunc(2, 1, False_,
docstrings.get('numpy.core.umath.logical_or'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(nodatetime_or_obj, out='?', simd=[('avx2', ints)]),
TD(O, f='npy_ObjectLogicalOr'),
),
'logical_xor':
- Ufunc(2, 1, Zero,
+ Ufunc(2, 1, False_,
docstrings.get('numpy.core.umath.logical_xor'),
'PyUFunc_SimpleBinaryComparisonTypeResolver',
TD(nodatetime_or_obj, out='?'),
@@ -514,7 +517,7 @@ defdict = {
TD(O, f='npy_ObjectMin')
),
'logaddexp':
- Ufunc(2, 1, None,
+ Ufunc(2, 1, MinusInfinity,
docstrings.get('numpy.core.umath.logaddexp'),
None,
TD(flts, f="logaddexp", astype={'e':'f'})
@@ -1048,18 +1051,38 @@ def make_ufuncs(funcdict):
# do not play well with \n
docstring = '\\n\"\"'.join(docstring.split(r"\n"))
fmt = textwrap.dedent("""\
- f = PyUFunc_FromFuncAndData(
+ identity = {identity_expr};
+ if ({has_identity} && identity == NULL) {{
+ return -1;
+ }}
+ f = PyUFunc_FromFuncAndDataAndSignatureAndIdentity(
{name}_functions, {name}_data, {name}_signatures, {nloops},
{nin}, {nout}, {identity}, "{name}",
- "{doc}", 0
+ "{doc}", 0, NULL, identity
);
+ if ({has_identity}) {{
+ Py_DECREF(identity);
+ }}
if (f == NULL) {{
return -1;
- }}""")
- mlist.append(fmt.format(
+ }}
+ """)
+ args = dict(
name=name, nloops=len(uf.type_descriptions),
- nin=uf.nin, nout=uf.nout, identity=uf.identity, doc=docstring
- ))
+ nin=uf.nin, nout=uf.nout,
+ has_identity='0' if uf.identity is None_ else '1',
+ identity='PyUFunc_IdentityValue',
+ identity_expr=uf.identity,
+ doc=docstring
+ )
+
+ # Only PyUFunc_None means don't reorder - we pass this using the old
+ # argument
+ if uf.identity is None_:
+ args['identity'] = 'PyUFunc_None'
+ args['identity_expr'] = 'NULL'
+
+ mlist.append(fmt.format(**args))
if uf.typereso is not None:
mlist.append(
r"((PyUFuncObject *)f)->type_resolver = &%s;" % uf.typereso)
@@ -1087,7 +1110,7 @@ def make_code(funcdict, filename):
static int
InitOperators(PyObject *dictionary) {
- PyObject *f;
+ PyObject *f, *identity;
%s
%s
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index bd7985dfb..f32c2968f 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -685,6 +685,10 @@ class TestLogAddExp(_FilterInvalids):
assert_(np.isnan(np.logaddexp(0, np.nan)))
assert_(np.isnan(np.logaddexp(np.nan, np.nan)))
+ def test_reduce(self):
+ assert_equal(np.logaddexp.identity, -np.inf)
+ assert_equal(np.logaddexp.reduce([]), -np.inf)
+
class TestLog1p(object):
def test_log1p(self):