summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/setup.py8
-rw-r--r--numpy/core/tests/test_ufunc.py64
2 files changed, 72 insertions, 0 deletions
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index 531df6a2e..8131317a8 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -339,6 +339,14 @@ def configuration(parent_package='',top_path=None):
extra_info = blas_info
)
+ config.add_extension('umath_tests',
+ sources = [join('src','umath_tests.c.src'),
+ ],
+ depends = [join('blasdot','cblas.h'),] + deps,
+ include_dirs = ['blasdot'],
+ extra_info = blas_info
+ )
+
config.add_data_dir('tests')
config.add_data_dir('tests/data')
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index d5e8ac76f..75bb1e2c2 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -1,5 +1,7 @@
import numpy as np
from numpy.testing import *
+from numpy.random import rand
+import numpy.core.umath_tests as umt
class TestUfunc(TestCase):
def test_reduceat_shifting_sum(self) :
@@ -229,6 +231,68 @@ class TestUfunc(TestCase):
"""
pass
+ def test_innerwt(self):
+ a = np.arange(6).reshape((2,3))
+ b = np.arange(10,16).reshape((2,3))
+ w = np.arange(20,26).reshape((2,3))
+ assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1))
+ a = np.arange(100,124).reshape((2,3,4))
+ b = np.arange(200,224).reshape((2,3,4))
+ w = np.arange(300,324).reshape((2,3,4))
+ assert_array_equal(umt.innerwt(a,b,w), np.sum(a*b*w,axis=-1))
+
+ def test_matrix_multiply(self):
+ self.compare_matrix_multiply_results(np.long)
+ self.compare_matrix_multiply_results(np.double)
+
+ def compare_matrix_multiply_results(self, tp):
+ d1 = np.array(rand(2,3,4), dtype=tp)
+ d2 = np.array(rand(2,3,4), dtype=tp)
+ msg = "matrix multiply on type %s" % d1.dtype.name
+
+ def permute_n(n):
+ if n == 1:
+ return ([0],)
+ ret = ()
+ base = permute_n(n-1)
+ for perm in base:
+ for i in xrange(n):
+ new = perm + [n-1]
+ new[n-1] = new[i]
+ new[i] = n-1
+ ret += (new,)
+ return ret
+ def slice_n(n):
+ if n == 0:
+ return ((),)
+ ret = ()
+ base = slice_n(n-1)
+ for sl in base:
+ ret += (sl+(slice(None),),)
+ ret += (sl+(slice(0,1),),)
+ return ret
+ def broadcastable(s1,s2):
+ return s1 == s2 or s1 == 1 or s2 == 1
+ permute_3 = permute_n(3)
+ slice_3 = slice_n(3) + ((slice(None,None,-1),)*3,)
+
+ ref = True
+ for p1 in permute_3:
+ for p2 in permute_3:
+ for s1 in slice_3:
+ for s2 in slice_3:
+ a1 = d1.transpose(p1)[s1]
+ a2 = d2.transpose(p2)[s2]
+ ref = ref and a1.base != None and a1.base.base != None
+ ref = ref and a2.base != None and a2.base.base != None
+ if broadcastable(a1.shape[-1], a2.shape[-2]) and \
+ broadcastable(a1.shape[0], a2.shape[0]):
+ assert_array_almost_equal(umt.matrix_multiply(a1,a2), \
+ np.sum(a2[...,np.newaxis].swapaxes(-3,-1) * \
+ a1[...,np.newaxis,:], axis=-1), \
+ err_msg = msg+' %s %s' % (str(a1.shape),str(a2.shape)))
+
+ assert_equal(ref, True, err_msg="reference check")
if __name__ == "__main__":
run_module_suite()