diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2019-08-08 13:37:37 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-08-08 13:37:37 -0500 |
commit | a754567bf3bd02ea45fed55178d843657fa35be5 (patch) | |
tree | 20398331cdaac42e201e1c7a753255a67d2a4012 /numpy | |
parent | 5ed0ebe7ab6cfc8a0be36d40e3876d1c02946240 (diff) | |
parent | abf3fe8e3814c41c9a79eac87f5b137a617d1184 (diff) | |
download | numpy-a754567bf3bd02ea45fed55178d843657fa35be5.tar.gz |
Merge pull request #14234 from larsoner/fix-norm
MAINT: Better error message for norm
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/fft/pocketfft.py | 5 | ||||
-rw-r--r-- | numpy/fft/tests/test_pocketfft.py | 10 |
2 files changed, 10 insertions, 5 deletions
diff --git a/numpy/fft/pocketfft.py b/numpy/fft/pocketfft.py index b7f6f1434..77ea6e3ba 100644 --- a/numpy/fft/pocketfft.py +++ b/numpy/fft/pocketfft.py @@ -271,9 +271,10 @@ def ifft(a, n=None, axis=-1, norm=None): a = asarray(a) if n is None: n = a.shape[axis] - fct = 1/n if norm is not None and _unitary(norm): - fct = 1/sqrt(n) + fct = 1/sqrt(max(n, 1)) + else: + fct = 1/max(n, 1) output = _raw_fft(a, n, axis, False, False, fct) return output diff --git a/numpy/fft/tests/test_pocketfft.py b/numpy/fft/tests/test_pocketfft.py index 7a06d4c45..453e964fa 100644 --- a/numpy/fft/tests/test_pocketfft.py +++ b/numpy/fft/tests/test_pocketfft.py @@ -45,12 +45,16 @@ class TestFFT1D(object): assert_allclose(fft1(x) / np.sqrt(30), np.fft.fft(x, norm="ortho"), atol=1e-6) - def test_ifft(self): + @pytest.mark.parametrize('norm', (None, 'ortho')) + def test_ifft(self, norm): x = random(30) + 1j*random(30) - assert_allclose(x, np.fft.ifft(np.fft.fft(x)), atol=1e-6) assert_allclose( - x, np.fft.ifft(np.fft.fft(x, norm="ortho"), norm="ortho"), + x, np.fft.ifft(np.fft.fft(x, norm=norm), norm=norm), atol=1e-6) + # Ensure we get the correct error message + with pytest.raises(ValueError, + match='Invalid number of FFT data points'): + np.fft.ifft([], norm=norm) def test_fft2(self): x = random((30, 20)) + 1j*random((30, 20)) |