summaryrefslogtreecommitdiff
path: root/numpy/fft/tests/test_fftpack.py
diff options
context:
space:
mode:
authorGregory R. Lee <gregory.lee@cchmc.org>2017-01-03 19:22:13 -0500
committerGregory R. Lee <gregory.lee@cchmc.org>2017-01-04 18:26:13 -0500
commitbdd318bc00e1c0ba68ca81b3603f3d3cdaa0a10c (patch)
tree19c803de7da2279423cca2f8d6b1874a840eec01 /numpy/fft/tests/test_fftpack.py
parentb94c2b01ff7ef5b8dc44726512cfa232e9054882 (diff)
downloadnumpy-bdd318bc00e1c0ba68ca81b3603f3d3cdaa0a10c.tar.gz
BUG: correct norm='ortho' scaling for rfft when n != None
closes #8444
Diffstat (limited to 'numpy/fft/tests/test_fftpack.py')
-rw-r--r--numpy/fft/tests/test_fftpack.py28
1 files changed, 25 insertions, 3 deletions
diff --git a/numpy/fft/tests/test_fftpack.py b/numpy/fft/tests/test_fftpack.py
index 2e6294252..a2cbc0f63 100644
--- a/numpy/fft/tests/test_fftpack.py
+++ b/numpy/fft/tests/test_fftpack.py
@@ -71,9 +71,13 @@ class TestFFT1D(TestCase):
def test_rfft(self):
x = random(30)
- assert_array_almost_equal(np.fft.fft(x)[:16], np.fft.rfft(x))
- assert_array_almost_equal(np.fft.rfft(x) / np.sqrt(30),
- np.fft.rfft(x, norm="ortho"))
+ for n in [x.size, 2*x.size]:
+ for norm in [None, 'ortho']:
+ assert_array_almost_equal(
+ np.fft.fft(x, n=n, norm=norm)[:(n//2 + 1)],
+ np.fft.rfft(x, n=n, norm=norm))
+ assert_array_almost_equal(np.fft.rfft(x, n=n) / np.sqrt(n),
+ np.fft.rfft(x, n=n, norm="ortho"))
def test_irfft(self):
x = random(30)
@@ -122,6 +126,24 @@ class TestFFT1D(TestCase):
x_herm, np.fft.ihfft(np.fft.hfft(x_herm, norm="ortho"),
norm="ortho"))
+ def test_all_1d_norm_preserving(self):
+ # verify that round-trip transforms are norm-preserving
+ x = random(30)
+ x_norm = np.linalg.norm(x)
+ n = x.size * 2
+ func_pairs = [(np.fft.fft, np.fft.ifft),
+ (np.fft.rfft, np.fft.irfft),
+ # hfft: order so the first function takes x.size samples
+ # (necessary for comparison to x_norm above)
+ (np.fft.ihfft, np.fft.hfft),
+ ]
+ for forw, back in func_pairs:
+ for n in [x.size, 2*x.size]:
+ for norm in [None, 'ortho']:
+ tmp = forw(x, n=n, norm=norm)
+ tmp = back(tmp, n=n, norm=norm)
+ assert_array_almost_equal(x_norm,
+ np.linalg.norm(tmp))
class TestFFTThreadSafe(TestCase):
threads = 16