diff options
-rw-r--r-- | numpy/add_newdocs.py | 19 | ||||
-rw-r--r-- | numpy/core/src/arraymethods.c | 59 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 12 |
3 files changed, 76 insertions, 14 deletions
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 17d21cd7c..3529938a5 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -2198,14 +2198,27 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('var', add_newdoc('numpy.core.multiarray', 'ndarray', ('view', - """a.view(dtype=None) + """a.view(dtype=None, type=None) New view of array with the same data. Parameters ---------- - dtype : sub-type or data-descriptor - Data-type of the returned view. + dtype : data-type + Data-type descriptor of the returned view, e.g. float32 or int16. + type : python type + Type of the returned view, e.g. ndarray or matrix. + + Examples + -------- + >>> x = np.array([(1,2)],dtype=[('a',np.int8),('b',np.int8)]) + >>> y = x.view(dtype=np.int16, type=np.matrix) + + >>> print y.dtype + int16 + + >>> print type(y) + <class 'numpy.core.defmatrix.matrix'> """)) diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c index cfd912644..1af6a02b7 100644 --- a/numpy/core/src/arraymethods.c +++ b/numpy/core/src/arraymethods.c @@ -103,26 +103,63 @@ array_squeeze(PyArrayObject *self, PyObject *args) static PyObject * array_view(PyArrayObject *self, PyObject *args, PyObject *kwds) { - PyObject *otype=NULL; - PyArray_Descr *type=NULL; + PyObject *out_dtype_or_type=NULL; + PyObject *out_dtype=NULL; + PyObject *out_type=NULL; + PyArray_Descr *dtype=NULL; - static char *kwlist[] = {"dtype", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, &otype)) + static char *kwlist[] = {"dtype_or_type", "dtype", "type", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO", kwlist, + &out_dtype_or_type, + &out_dtype, + &out_type)) return NULL; - if (otype) { - if (PyType_Check(otype) && \ - PyType_IsSubtype((PyTypeObject *)otype, + /* If user specified a positional argument, guess whether it + represents a type or a dtype for backward compatibility. */ + if (out_dtype_or_type) { + + /* type specified? */ + if (PyType_Check(out_dtype_or_type) && + PyType_IsSubtype((PyTypeObject *)out_dtype_or_type, &PyArray_Type)) { - return PyArray_View(self, NULL, - (PyTypeObject *)otype); + if (out_type) { + PyErr_SetString(PyExc_ValueError, + "Cannot specify output type twice."); + return NULL; + } + + out_type = out_dtype_or_type; } + + /* dtype specified */ else { - if (PyArray_DescrConverter(otype, &type) == PY_FAIL) + if (out_dtype) { + PyErr_SetString(PyExc_ValueError, + "Cannot specify output dtype twice."); return NULL; + } + + out_dtype = out_dtype_or_type; } } - return PyArray_View(self, type, NULL); + + if ((out_type) && (!PyType_Check(out_type) || + !PyType_IsSubtype((PyTypeObject *)out_type, + &PyArray_Type))) { + PyErr_SetString(PyExc_ValueError, + "Type must be a Python type object"); + return NULL; + } + + if ((out_dtype) && + (PyArray_DescrConverter(out_dtype, &dtype) == PY_FAIL)) { + PyErr_SetString(PyExc_ValueError, + "Dtype must be a numpy data-type"); + return NULL; + } + + return PyArray_View(self, dtype, (PyTypeObject*)out_type); } static PyObject * diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 9a7f8c9ff..ae6c43f10 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -842,6 +842,18 @@ class TestView(NumpyTestCase): assert_array_equal(y,z) assert_array_equal(y, [67305985, 134678021]) + def test_type(self): + x = np.array([1,2,3]) + assert(isinstance(x.view(np.matrix),np.matrix)) + + def test_keywords(self): + x = np.array([(1,2)],dtype=[('a',np.int8),('b',np.int8)]) + y = x.view(dtype=np.int16, type=np.matrix) + assert_array_equal(y,[[513]]) + + assert(isinstance(y,np.matrix)) + assert_equal(y.dtype,np.int16) + # Import tests without matching module names set_local_path() |