summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorovillellas <oscar.villellas@continuum.io>2013-03-22 17:46:51 +0100
committerPauli Virtanen <pav@iki.fi>2013-04-10 22:47:44 +0300
commit5dc27acdafa572e12d693ca926498c7a5681f548 (patch)
treeb609d302563d603ace8c5eb00753920c17241710
parent60f54b6c1585d022428b3bb57f638536cd1e639e (diff)
downloadnumpy-5dc27acdafa572e12d693ca926498c7a5681f548.tar.gz
FIX: matrix_multiply now works when given a column matrix
-rw-r--r--numpy/core/src/umath/umath_linalg.c.src116
-rw-r--r--numpy/core/tests/test_gufuncs_linalg.py5
2 files changed, 82 insertions, 39 deletions
diff --git a/numpy/core/src/umath/umath_linalg.c.src b/numpy/core/src/umath/umath_linalg.c.src
index b83fad1aa..a63bcce9a 100644
--- a/numpy/core/src/umath/umath_linalg.c.src
+++ b/numpy/core/src/umath/umath_linalg.c.src
@@ -1112,63 +1112,101 @@ dump_gemm_params(const char* name, const GEMM_PARAMS_t* params)
"TRANSA", ' ', params->TRANSA, "TRANSB", ' ', params->TRANSB);
}
+
+typedef struct _matrix_desc {
+ int need_buff; /* nonzero when the matrix must be copied into a scratch buffer */
+ fortran_int lead; /* leading dimension; init_gemm_params stores it in LDA/LDB/LDC */
+ size_t size; /* scratch bytes required for the copy; 0 when need_buff is 0 */
+ void *buff; /* scratch slice; NULL until matrix_desc_assign_buff sets it */
+} matrix_desc;
+
+/*
+ * Initialize a matrix descriptor from the steps of one operand.
+ *
+ *   mtxd    - descriptor to fill in; every field is written.
+ *   steps   - steps[0] contains the column step, steps[1] the row step.
+ *             Both are compared against / divided by sot, so they are
+ *             expected to be expressed in the same unit as sot.
+ *   sot     - size of one element ("size of type").
+ *   rows    - number of rows of the matrix.
+ *   columns - number of columns of the matrix.
+ *
+ * On return:
+ *   need_buff - nonzero when the operand must be copied before use.
+ *   lead      - leading dimension to use for the operand.
+ *   size      - bytes needed for the copy, 0 when no copy is needed.
+ *   buff      - always NULL; matrix_desc_assign_buff() sets it later.
+ */
+static inline void
+matrix_desc_init(matrix_desc *mtxd,
+                 npy_intp* steps,
+                 size_t sot,
+                 fortran_int rows,
+                 fortran_int columns)
+{
+    /* keep the step arithmetic in the (signed) type of the steps */
+    const npy_intp elem = (npy_intp) sot;
+
+    /* if column step is not element step, a copy will be needed
+       to arrange the array in a way compatible with blas
+     */
+    mtxd->need_buff = steps[0] != elem;
+
+    /* when a copy is performed, or when there is a single column with
+       a row step of 0, the lead is the number of rows (it must be at
+       least 1)
+     */
+    mtxd->lead = rows;
+
+    if (!mtxd->need_buff) {
+        if (steps[1]) {
+            /* divide first, narrow afterwards: a cast binds tighter
+               than '/', so "(fortran_int) steps[1] / sot" would
+               truncate the step before the division
+             */
+            mtxd->lead = (fortran_int) (steps[1] / elem);
+        } else if (columns > 1) {
+            /* row step is 0 with several columns: there is a step
+               trick around, which gemm does not support... make a copy
+             */
+            mtxd->need_buff = 1;
+        }
+    }
+
+    /* widen before multiplying so that rows*columns cannot overflow
+       the (int-sized) fortran_int
+     */
+    mtxd->size = mtxd->need_buff
+        ? (size_t) rows * (size_t) columns * sot
+        : 0;
+    mtxd->buff = NULL;
+}
+
+/*
+ * Give a matrix descriptor its slice of a scratch buffer.
+ *
+ *   mtxd - descriptor previously filled in by matrix_desc_init().
+ *   p    - current position inside the scratch buffer.
+ *
+ * When the descriptor needs a copy, its buff is set to p and the
+ * position just past its slice (p + mtxd->size) is returned. Otherwise
+ * the descriptor is left untouched and p is returned unchanged.
+ *
+ * The pointer type is npy_uint8 to match the scratch buffer cursor
+ * that init_gemm_params passes in and assigns the result to.
+ */
+static inline npy_uint8 *
+matrix_desc_assign_buff(matrix_desc* mtxd, npy_uint8 *p)
+{
+    if (!mtxd->need_buff) {
+        return p;
+    }
+
+    mtxd->buff = p;
+    return p + mtxd->size;
+}
+
static inline int
init_gemm_params(GEMM_PARAMS_t *params,
fortran_int m,
fortran_int k,
fortran_int n,
npy_intp* steps,
- size_t size_of_type)
+ size_t sot)
{
+ matrix_desc a, b, c;
+ matrix_desc_init(&a, steps + 0, sot, m, k);
+ matrix_desc_init(&b, steps + 2, sot, k, n);
+ matrix_desc_init(&c, steps + 4, sot, m, n);
npy_uint8 *mem_buff = NULL;
- int needs_a, needs_b, needs_c;
- size_t a_size, b_size, c_size;
- void *a, *b, *c;
-
- needs_a = steps[0] != size_of_type;
- needs_b = steps[2] != size_of_type;
- needs_c = steps[4] != size_of_type;
- a_size = m*k*size_of_type*needs_a;
- b_size = k*n*size_of_type*needs_b;
- c_size = m*n*size_of_type*needs_c;
-
- if (a_size + b_size + c_size)
+ if (a.size + b.size + c.size)
{
npy_uint8 *current;
/* need to allocate some buffer */
- mem_buff = malloc(a_size + b_size + c_size);
+ mem_buff = malloc(a.size + b.size + c.size);
if (!mem_buff)
goto error;
current = mem_buff;
- if (needs_a) {
- a = current;
- current += a_size;
- } else {
- a = NULL;
- }
-
- if (needs_b) {
- b = current;
- current += b_size;
- } else {
- b = NULL;
- }
-
- if (needs_c) {
- c = current;
- } else {
- c = NULL;
- }
+ current = matrix_desc_assign_buff(&a, current);
+ current = matrix_desc_assign_buff(&b, current);
+ current = matrix_desc_assign_buff(&c, current);
}
- params->A = a;
- params->B = b;
- params->C = c;
- params->LDA = (fortran_int)(needs_a ? m : steps[1]/size_of_type);
- params->LDB = (fortran_int)(needs_b ? k : steps[3]/size_of_type);
- params->LDC = (fortran_int)(needs_c ? m : steps[5]/size_of_type);
+ params->A = a.buff;
+ params->B = b.buff;
+ params->C = c.buff;
+ params->LDA = a.lead;
+ params->LDB = b.lead;
+ params->LDC = c.lead;
params->M = m;
params->N = n;
params->K = k;
diff --git a/numpy/core/tests/test_gufuncs_linalg.py b/numpy/core/tests/test_gufuncs_linalg.py
index 513ac9dcd..34af58a7c 100644
--- a/numpy/core/tests/test_gufuncs_linalg.py
+++ b/numpy/core/tests/test_gufuncs_linalg.py
@@ -271,6 +271,11 @@ class TestMatrixMultiply(GeneralTestCase):
else:
assert_almost_equal(res[0], np.dot(a[0],b[0]))
+ def test_column_matrix(self):  # matrix_multiply with a single-column right operand
+ A = np.arange(2*2).reshape((2,2))
+ B = np.arange(2*1).reshape((2,1))  # column matrix: 2 rows, 1 column
+ res = gula.matrix_multiply(A,B)
+ assert_almost_equal(res, np.dot(A,B))  # np.dot provides the expected product
class TestInv(GeneralTestCase, TestCase):
def do(self, a, b):