summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/numeric.py4
-rw-r--r--numpy/core/tests/test_multiarray.py34
2 files changed, 37 insertions, 1 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 12e690f1e..5d7407ce0 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -133,7 +133,9 @@ def zeros_like(a, dtype=None, order='K', subok=True):
"""
res = empty_like(a, dtype=dtype, order=order, subok=subok)
- multiarray.copyto(res, 0, casting='unsafe')
+ # needed instead of a 0 to get same result as zeros for for string dtypes
+ z = zeros(1, dtype=res.dtype)
+ multiarray.copyto(res, z, casting='unsafe')
return res
def ones(shape, dtype=None, order='C'):
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index a2667b38a..70398ee84 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -552,6 +552,40 @@ class TestCreation(TestCase):
d = zeros(10, dtype=[('k', object, 2)])
assert_array_equal(d['k'], 0)
+ def test_zeros_like_like_zeros(self):
+ # test zeros_like returns the same as zeros
+ for c in np.typecodes['All']:
+ if c == 'V':
+ continue
+ d = zeros((3,3), dtype=c)
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+ # explicitly check some special cases
+ d = zeros((3,3), dtype='S5')
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+ d = zeros((3,3), dtype='U5')
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+
+ d = zeros((3,3), dtype='<i4')
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+ d = zeros((3,3), dtype='>i4')
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+
+ d = zeros((3,3), dtype='<M8[s]')
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+ d = zeros((3,3), dtype='>M8[s]')
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+
+ d = zeros((3,3), dtype='f4,f4')
+ assert_array_equal(zeros_like(d), d)
+ assert_equal(zeros_like(d).dtype, d.dtype)
+
def test_sequence_non_homogenous(self):
assert_equal(np.array([4, 2**80]).dtype, np.object)
assert_equal(np.array([4, 2**80, 4]).dtype, np.object)