diff options
author | ovillellas <oscar.villellas@continuum.io> | 2013-03-22 17:46:51 +0100 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2013-04-10 22:47:44 +0300 |
commit | 5dc27acdafa572e12d693ca926498c7a5681f548 (patch) | |
tree | b609d302563d603ace8c5eb00753920c17241710 | |
parent | 60f54b6c1585d022428b3bb57f638536cd1e639e (diff) | |
download | numpy-5dc27acdafa572e12d693ca926498c7a5681f548.tar.gz |
FIX: matrix_multiply now works when given a column matrix
-rw-r--r-- | numpy/core/src/umath/umath_linalg.c.src | 116 | ||||
-rw-r--r-- | numpy/core/tests/test_gufuncs_linalg.py | 5 |
2 files changed, 82 insertions, 39 deletions
diff --git a/numpy/core/src/umath/umath_linalg.c.src b/numpy/core/src/umath/umath_linalg.c.src index b83fad1aa..a63bcce9a 100644 --- a/numpy/core/src/umath/umath_linalg.c.src +++ b/numpy/core/src/umath/umath_linalg.c.src @@ -1112,63 +1112,101 @@ dump_gemm_params(const char* name, const GEMM_PARAMS_t* params) "TRANSA", ' ', params->TRANSA, "TRANSB", ' ', params->TRANSB); } + +typedef struct _matrix_desc { + int need_buff; + fortran_int lead; + size_t size; + void *buff; +} matrix_desc; + +static inline void +matrix_desc_init(matrix_desc *mtxd, + npy_intp* steps, + size_t sot, + fortran_int rows, + fortran_int columns) +/* initialized a matrix desc based on the information in steps + it will fill the matrix desc with the values needed. + steps[0] contains column step + steps[1] contains row step +*/ +{ + /* if column step is not element step, a copy will be needed + to arrange the array in a way compatible with blas + */ + mtxd->need_buff = steps[0] != sot; + + if (!mtxd->need_buff) { + if (steps[1]) { + mtxd->lead = (fortran_int) steps[1] / sot; + } else { + /* step is 0, this means either it is only 1 column or + there is a step trick around*/ + if (columns > 1) { + /* step tricks not supported in gemm... make a copy */ + mtxd->need_buff = 1; + } else { + /* lead must be at least 1 */ + mtxd->lead = rows; + } + } + } + + /* if a copy is performed, the lead is the number of columns */ + if (mtxd->need_buff) { + mtxd->lead = rows; + } + + mtxd->size = rows*columns*sot*mtxd->need_buff; + mtxd->buff = NULL; +} + +static inline npy_int8 * +matrix_desc_assign_buff(matrix_desc* mtxd, npy_int8 *p) +{ + if (mtxd->need_buff) { + mtxd->buff = p; + return p + mtxd->size; + } else { + return p; + } +} + static inline int init_gemm_params(GEMM_PARAMS_t *params, fortran_int m, fortran_int k, fortran_int n, npy_intp* steps, - size_t size_of_type) + size_t sot) { + matrix_desc a, b, c; + matrix_desc_init(&a, steps + 0, sot, m, k); + matrix_desc_init(&b, steps + 2, sot, k, n); + matrix_desc_init(&c, steps + 4, sot, m, n); npy_uint8 *mem_buff = NULL; - int needs_a, needs_b, needs_c; - size_t a_size, b_size, c_size; - void *a, *b, *c; - - needs_a = steps[0] != size_of_type; - needs_b = steps[2] != size_of_type; - needs_c = steps[4] != size_of_type; - a_size = m*k*size_of_type*needs_a; - b_size = k*n*size_of_type*needs_b; - c_size = m*n*size_of_type*needs_c; - - if (a_size + b_size + c_size) + if (a.size + b.size + c.size) { npy_uint8 *current; /* need to allocate some buffer */ - mem_buff = malloc(a_size + b_size + c_size); + mem_buff = malloc(a.size + b.size + c.size); if (!mem_buff) goto error; current = mem_buff; - if (needs_a) { - a = current; - current += a_size; - } else { - a = NULL; - } - - if (needs_b) { - b = current; - current += b_size; - } else { - b = NULL; - } - - if (needs_c) { - c = current; - } else { - c = NULL; - } + current = matrix_desc_assign_buff(&a, current); + current = matrix_desc_assign_buff(&b, current); + current = matrix_desc_assign_buff(&c, current); } - params->A = a; - params->B = b; - params->C = c; - params->LDA = (fortran_int)(needs_a ? m : steps[1]/size_of_type); - params->LDB = (fortran_int)(needs_b ? k : steps[3]/size_of_type); - params->LDC = (fortran_int)(needs_c ? m : steps[5]/size_of_type); + params->A = a.buff; + params->B = b.buff; + params->C = c.buff; + params->LDA = a.lead; + params->LDB = b.lead; + params->LDC = c.lead; params->M = m; params->N = n; params->K = k; diff --git a/numpy/core/tests/test_gufuncs_linalg.py b/numpy/core/tests/test_gufuncs_linalg.py index 513ac9dcd..34af58a7c 100644 --- a/numpy/core/tests/test_gufuncs_linalg.py +++ b/numpy/core/tests/test_gufuncs_linalg.py @@ -271,6 +271,11 @@ class TestMatrixMultiply(GeneralTestCase): else: assert_almost_equal(res[0], np.dot(a[0],b[0])) + def test_column_matrix(self): + A = np.arange(2*2).reshape((2,2)) + B = np.arange(2*1).reshape((2,1)) + res = gula.matrix_multiply(A,B) + assert_almost_equal(res, np.dot(A,B)) class TestInv(GeneralTestCase, TestCase): def do(self, a, b): |