From 26efcc65cd55ab975c91edc5c1466735c99b4807 Mon Sep 17 00:00:00 2001 From: Stefan van der Walt Date: Thu, 21 Aug 2008 19:45:02 +0000 Subject: Add tests [patch by Wenjie Fu and Hans-Andreas Engel]. --- numpy/core/setup.py | 8 ++++++ numpy/core/tests/test_ufunc.py | 64 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) (limited to 'numpy') diff --git a/numpy/core/setup.py b/numpy/core/setup.py index 531df6a2e..8131317a8 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -339,6 +339,14 @@ def configuration(parent_package='',top_path=None): extra_info = blas_info ) + config.add_extension('umath_tests', + sources = [join('src','umath_tests.c.src'), + ], + depends = [join('blasdot','cblas.h'),] + deps, + include_dirs = ['blasdot'], + extra_info = blas_info + ) + config.add_data_dir('tests') config.add_data_dir('tests/data') diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index d5e8ac76f..75bb1e2c2 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -1,5 +1,7 @@ import numpy as np from numpy.testing import * +from numpy.random import rand +import numpy.core.umath_tests as umt class TestUfunc(TestCase): def test_reduceat_shifting_sum(self) : @@ -229,6 +231,68 @@ class TestUfunc(TestCase): """ pass + def test_innerwt(self): + a = np.arange(6).reshape((2,3)) + b = np.arange(10,16).reshape((2,3)) + w = np.arange(20,26).reshape((2,3)) + assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1)) + a = np.arange(100,124).reshape((2,3,4)) + b = np.arange(200,224).reshape((2,3,4)) + w = np.arange(300,324).reshape((2,3,4)) + assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1)) + + def test_matrix_multiply(self): + self.compare_matrix_multiply_results(np.long) + self.compare_matrix_multiply_results(np.double) + + def compare_matrix_multiply_results(self, tp): + d1 = np.array(rand(2,3,4), dtype=tp) + d2 = np.array(rand(2,3,4), dtype=tp) + msg = "matrix multiply on type %s" % d1.dtype.name + + def permute_n(n): + if n == 1: + return ([0],) + ret = () + base = permute_n(n-1) + for perm in base: + for i in xrange(n): + new = perm + [n-1] + new[n-1] = new[i] + new[i] = n-1 + ret += (new,) + return ret + def slice_n(n): + if n == 0: + return ((),) + ret = () + base = slice_n(n-1) + for sl in base: + ret += (sl+(slice(None),),) + ret += (sl+(slice(0,1),),) + return ret + def broadcastable(s1,s2): + return s1 == s2 or s1 == 1 or s2 == 1 + permute_3 = permute_n(3) + slice_3 = slice_n(3) + ((slice(None,None,-1),)*3,) + + ref = True + for p1 in permute_3: + for p2 in permute_3: + for s1 in slice_3: + for s2 in slice_3: + a1 = d1.transpose(p1)[s1] + a2 = d2.transpose(p2)[s2] + ref = ref and a1.base != None and a1.base.base != None + ref = ref and a2.base != None and a2.base.base != None + if broadcastable(a1.shape[-1], a2.shape[-2]) and \ + broadcastable(a1.shape[0], a2.shape[0]): + assert_array_almost_equal(umt.matrix_multiply(a1,a2), \ + np.sum(a2[...,np.newaxis].swapaxes(-3,-1) * \ + a1[...,np.newaxis,:], axis=-1), \ + err_msg = msg+' %s %s' % (str(a1.shape),str(a2.shape))) + + assert_equal(ref, True, err_msg="reference check") if __name__ == "__main__": run_module_suite() -- cgit v1.2.1