diff options
author | Matti Picus <matti.picus@gmail.com> | 2019-01-08 21:29:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-01-08 21:29:43 +0200 |
commit | 9405d2b87e75de1d27b7866c4acf0e18ba3b811a (patch) | |
tree | 530a10faa9e8869160501a7b28195f1753c193fb | |
parent | 4d0732220ee56fd251665c43b6d43272ad5edef6 (diff) | |
parent | 4400a93b45e5c8efdb824e798d51c131eae60240 (diff) | |
download | numpy-9405d2b87e75de1d27b7866c4acf0e18ba3b811a.tar.gz |
Merge pull request #12685 from mreineck/add_pocketfft
BUG: Make pocketfft handle long doubles.
-rw-r--r-- | numpy/fft/pocketfft.c | 18 | ||||
-rw-r--r-- | numpy/fft/tests/test_pocketfft.py | 10 |
2 files changed, 23 insertions, 5 deletions
diff --git a/numpy/fft/pocketfft.c b/numpy/fft/pocketfft.c index 10a741b6f..9d1218e6b 100644 --- a/numpy/fft/pocketfft.c +++ b/numpy/fft/pocketfft.c @@ -2192,7 +2192,11 @@ WARN_UNUSED_RESULT static int rfft_forward(rfft_plan plan, double c[], double fc static PyObject * execute_complex(PyObject *a1, int is_forward, double fct) { - PyArrayObject *data = (PyArrayObject *)PyArray_CopyFromObject(a1, NPY_CDOUBLE, 1, 0); + PyArrayObject *data = (PyArrayObject *)PyArray_FromAny(a1, + PyArray_DescrFromType(NPY_CDOUBLE), 1, 0, + NPY_ARRAY_ENSURECOPY | NPY_ARRAY_DEFAULT | + NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST, + NULL); if (!data) return NULL; int npts = PyArray_DIM(data, PyArray_NDIM(data) - 1); @@ -2227,8 +2231,10 @@ execute_real_forward(PyObject *a1, double fct) { rfft_plan plan=NULL; int fail = 0; - PyArrayObject *data = (PyArrayObject *)PyArray_ContiguousFromObject(a1, - NPY_DOUBLE, 1, 0); + PyArrayObject *data = (PyArrayObject *)PyArray_FromAny(a1, + PyArray_DescrFromType(NPY_DOUBLE), 1, 0, + NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST, + NULL); if (!data) return NULL; int ndim = PyArray_NDIM(data); @@ -2281,8 +2287,10 @@ static PyObject * execute_real_backward(PyObject *a1, double fct) { rfft_plan plan=NULL; - PyArrayObject *data = (PyArrayObject *)PyArray_ContiguousFromObject(a1, - NPY_CDOUBLE, 1, 0); + PyArrayObject *data = (PyArrayObject *)PyArray_FromAny(a1, + PyArray_DescrFromType(NPY_CDOUBLE), 1, 0, + NPY_ARRAY_DEFAULT | NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST, + NULL); if (!data) return NULL; int npts = PyArray_DIM(data, PyArray_NDIM(data) - 1); PyArrayObject *ret = (PyArrayObject *)PyArray_Empty(PyArray_NDIM(data), diff --git a/numpy/fft/tests/test_pocketfft.py b/numpy/fft/tests/test_pocketfft.py index 0552f6afd..1029294a1 100644 --- a/numpy/fft/tests/test_pocketfft.py +++ b/numpy/fft/tests/test_pocketfft.py @@ -1,6 +1,7 @@ from __future__ import division, absolute_import, print_function import numpy as np +import pytest from numpy.random import random from numpy.testing import ( assert_array_almost_equal, assert_array_equal, assert_raises, @@ -156,6 +157,15 @@ class TestFFT1D(object): assert_array_almost_equal(x_norm, np.linalg.norm(tmp)) + @pytest.mark.parametrize("dtype", [np.half, np.single, np.double, + np.longdouble]) + def test_dtypes(self, dtype): + # make sure that all input precisions are accepted and internally + # converted to 64bit + x = random(30).astype(dtype) + assert_array_almost_equal(np.fft.ifft(np.fft.fft(x)), x) + assert_array_almost_equal(np.fft.irfft(np.fft.rfft(x)), x) + class TestFFTThreadSafe(object): threads = 16 input_shape = (800, 200) |