summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Reinecke <martin@mpa-garching.mpg.de>2019-01-07 16:06:32 +0100
committerMartin Reinecke <martin@mpa-garching.mpg.de>2019-01-08 08:35:49 +0100
commit38b5b7b13d04677a96ce140b6ecb29139f67a453 (patch)
treede16e9c2194d6fc3c4cf522f8b85b48f85d381bc
parent10bf4c63c0a4c5bb06002158fea16af0d0d79085 (diff)
downloadnumpy-38b5b7b13d04677a96ce140b6ecb29139f67a453.tar.gz
fix #12663
tweak and add test better fix fix cleanup, additional test fix test
-rw-r--r--numpy/fft/pocketfft.c18
-rw-r--r--numpy/fft/tests/test_pocketfft.py8
2 files changed, 21 insertions, 5 deletions
diff --git a/numpy/fft/pocketfft.c b/numpy/fft/pocketfft.c
index 10a741b6f..9d1218e6b 100644
--- a/numpy/fft/pocketfft.c
+++ b/numpy/fft/pocketfft.c
@@ -2192,7 +2192,11 @@ WARN_UNUSED_RESULT static int rfft_forward(rfft_plan plan, double c[], double fc
static PyObject *
execute_complex(PyObject *a1, int is_forward, double fct)
{
- PyArrayObject *data = (PyArrayObject *)PyArray_CopyFromObject(a1, NPY_CDOUBLE, 1, 0);
+ PyArrayObject *data = (PyArrayObject *)PyArray_FromAny(a1,
+ PyArray_DescrFromType(NPY_CDOUBLE), 1, 0,
+ NPY_ARRAY_ENSURECOPY | NPY_ARRAY_DEFAULT |
+ NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST,
+ NULL);
if (!data) return NULL;
int npts = PyArray_DIM(data, PyArray_NDIM(data) - 1);
@@ -2227,8 +2231,10 @@ execute_real_forward(PyObject *a1, double fct)
{
rfft_plan plan=NULL;
int fail = 0;
- PyArrayObject *data = (PyArrayObject *)PyArray_ContiguousFromObject(a1,
- NPY_DOUBLE, 1, 0);
+ PyArrayObject *data = (PyArrayObject *)PyArray_FromAny(a1,
+ PyArray_DescrFromType(NPY_DOUBLE), 1, 0,
+ NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST,
+ NULL);
if (!data) return NULL;
int ndim = PyArray_NDIM(data);
@@ -2281,8 +2287,10 @@ static PyObject *
execute_real_backward(PyObject *a1, double fct)
{
rfft_plan plan=NULL;
- PyArrayObject *data = (PyArrayObject *)PyArray_ContiguousFromObject(a1,
- NPY_CDOUBLE, 1, 0);
+ PyArrayObject *data = (PyArrayObject *)PyArray_FromAny(a1,
+ PyArray_DescrFromType(NPY_CDOUBLE), 1, 0,
+ NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST,
+ NULL);
if (!data) return NULL;
int npts = PyArray_DIM(data, PyArray_NDIM(data) - 1);
PyArrayObject *ret = (PyArrayObject *)PyArray_Empty(PyArray_NDIM(data),
diff --git a/numpy/fft/tests/test_pocketfft.py b/numpy/fft/tests/test_pocketfft.py
index 0552f6afd..89e8df271 100644
--- a/numpy/fft/tests/test_pocketfft.py
+++ b/numpy/fft/tests/test_pocketfft.py
@@ -156,6 +156,14 @@ class TestFFT1D(object):
assert_array_almost_equal(x_norm,
np.linalg.norm(tmp))
+ def test_dtypes(self):
+ # make sure that all input precisions are accepted and internally
+ # converted to 64bit
+ for c in "fdg":
+ x = random(30).astype(np.dtype(c))
+ assert_array_almost_equal(np.fft.ifft(np.fft.fft(x)), x)
+ assert_array_almost_equal(np.fft.irfft(np.fft.rfft(x)), x)
+
class TestFFTThreadSafe(object):
threads = 16
input_shape = (800, 200)