summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/numeric.py43
-rw-r--r--numpy/core/tests/test_numeric.py74
2 files changed, 99 insertions, 18 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index c1c555172..430f7a715 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -597,7 +597,9 @@ def require(a, dtype=None, requirements=None):
a : array_like
The object to be converted to a type-and-requirement-satisfying array.
dtype : data-type
- The required data-type, the default data-type is float64).
+ The required data-type. If None preserve the current dtype. If your
+ application requires the data to be in native byteorder, include
+ a byteorder specification as a part of the dtype specification.
requirements : str or list of str
The requirements list can be any of the following
@@ -606,6 +608,7 @@ def require(a, dtype=None, requirements=None):
* 'ALIGNED' ('A') - ensure a data-type aligned array
* 'WRITEABLE' ('W') - ensure a writable array
* 'OWNDATA' ('O') - ensure an array that owns its own data
+ * 'ENSUREARRAY', ('E') - ensure a base array, instead of a subclass
See Also
--------
@@ -642,34 +645,38 @@ def require(a, dtype=None, requirements=None):
UPDATEIFCOPY : False
"""
- if requirements is None:
- requirements = []
- else:
- requirements = [x.upper() for x in requirements]
-
+ possible_flags = {'C':'C', 'C_CONTIGUOUS':'C', 'CONTIGUOUS':'C',
+ 'F':'F', 'F_CONTIGUOUS':'F', 'FORTRAN':'F',
+ 'A':'A', 'ALIGNED':'A',
+ 'W':'W', 'WRITEABLE':'W',
+ 'O':'O', 'OWNDATA':'O',
+ 'E':'E', 'ENSUREARRAY':'E'}
if not requirements:
return asanyarray(a, dtype=dtype)
+ else:
+ requirements = set(possible_flags[x.upper()] for x in requirements)
- if 'ENSUREARRAY' in requirements or 'E' in requirements:
+ if 'E' in requirements:
+ requirements.remove('E')
subok = False
else:
subok = True
- arr = array(a, dtype=dtype, copy=False, subok=subok)
+ order = 'A'
+ if requirements >= set(['C', 'F']):
+ raise ValueError('Cannot specify both "C" and "F" order')
+ elif 'F' in requirements:
+ order = 'F'
+ requirements.remove('F')
+ elif 'C' in requirements:
+ order = 'C'
+ requirements.remove('C')
- copychar = 'A'
- if 'FORTRAN' in requirements or \
- 'F_CONTIGUOUS' in requirements or \
- 'F' in requirements:
- copychar = 'F'
- elif 'CONTIGUOUS' in requirements or \
- 'C_CONTIGUOUS' in requirements or \
- 'C' in requirements:
- copychar = 'C'
+ arr = array(a, dtype=dtype, order=order, copy=False, subok=subok)
for prop in requirements:
if not arr.flags[prop]:
- arr = arr.copy(copychar)
+ arr = arr.copy(order)
break
return arr
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index d8b01a532..b151e24f3 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2148,5 +2148,79 @@ def test_outer_out_param():
assert_equal(res1, out1)
assert_equal(np.outer(arr2, arr3, out2), out2)
+class TestRequire(object):
+ flag_names = ['C', 'C_CONTIGUOUS', 'CONTIGUOUS',
+ 'F', 'F_CONTIGUOUS', 'FORTRAN',
+ 'A', 'ALIGNED',
+ 'W', 'WRITEABLE',
+ 'O', 'OWNDATA']
+
+ def generate_all_false(self, dtype):
+ arr = np.zeros((2, 2), [('junk', 'i1'), ('a', dtype)])
+ arr.setflags(write=False)
+ a = arr['a']
+ assert_(not a.flags['C'])
+ assert_(not a.flags['F'])
+ assert_(not a.flags['O'])
+ assert_(not a.flags['W'])
+ assert_(not a.flags['A'])
+ return a
+
+ def set_and_check_flag(self, flag, dtype, arr):
+ if dtype is None:
+ dtype = arr.dtype
+ b = np.require(arr, dtype, [flag])
+ assert_(b.flags[flag])
+ assert_(b.dtype == dtype)
+
+ # a further call to np.require ought to return the same array
+ # unless OWNDATA is specified.
+ c = np.require(b, None, [flag])
+ if flag[0] != 'O':
+ assert_(c is b)
+ else:
+ assert_(c.flags[flag])
+
+ def test_require_each(self):
+
+ id = ['f8', 'i4']
+ fd = [None, 'f8', 'c16']
+ for idtype, fdtype, flag in itertools.product(id, fd, self.flag_names):
+ a = self.generate_all_false(idtype)
+ yield self.set_and_check_flag, flag, fdtype, a
+
+ def test_unknown_requirement(self):
+ a = self.generate_all_false('f8')
+ assert_raises(KeyError, np.require, a, None, 'Q')
+
+ def test_non_array_input(self):
+ a = np.require([1, 2, 3, 4], 'i4', ['C', 'A', 'O'])
+ assert_(a.flags['O'])
+ assert_(a.flags['C'])
+ assert_(a.flags['A'])
+ assert_(a.dtype == 'i4')
+ assert_equal(a, [1, 2, 3, 4])
+
+ def test_C_and_F_simul(self):
+ a = self.generate_all_false('f8')
+ assert_raises(ValueError, np.require, a, None, ['C', 'F'])
+
+ def test_ensure_array(self):
+ class ArraySubclass(ndarray):
+ pass
+
+ a = ArraySubclass((2,2))
+ b = np.require(a, None, ['E'])
+ assert_(type(b) is np.ndarray)
+
+ def test_preserve_subtype(self):
+ class ArraySubclass(ndarray):
+ pass
+
+ for flag in self.flag_names:
+ a = ArraySubclass((2,2))
+ yield self.set_and_check_flag, flag, None, a
+
+
if __name__ == "__main__":
run_module_suite()