diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/numeric.py | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 34 |
2 files changed, 37 insertions, 1 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 12e690f1e..5d7407ce0 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -133,7 +133,9 @@ def zeros_like(a, dtype=None, order='K', subok=True): """ res = empty_like(a, dtype=dtype, order=order, subok=subok) - multiarray.copyto(res, 0, casting='unsafe') + # needed instead of a 0 to get same result as zeros for for string dtypes + z = zeros(1, dtype=res.dtype) + multiarray.copyto(res, z, casting='unsafe') return res def ones(shape, dtype=None, order='C'): diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index a2667b38a..70398ee84 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -552,6 +552,40 @@ class TestCreation(TestCase): d = zeros(10, dtype=[('k', object, 2)]) assert_array_equal(d['k'], 0) + def test_zeros_like_like_zeros(self): + # test zeros_like returns the same as zeros + for c in np.typecodes['All']: + if c == 'V': + continue + d = zeros((3,3), dtype=c) + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + # explicitly check some special cases + d = zeros((3,3), dtype='S5') + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + d = zeros((3,3), dtype='U5') + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + + d = zeros((3,3), dtype='<i4') + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + d = zeros((3,3), dtype='>i4') + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + + d = zeros((3,3), dtype='<M8[s]') + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + d = zeros((3,3), dtype='>M8[s]') + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + + d = zeros((3,3), dtype='f4,f4') + assert_array_equal(zeros_like(d), d) + assert_equal(zeros_like(d).dtype, d.dtype) + def test_sequence_non_homogenous(self): assert_equal(np.array([4, 2**80]).dtype, np.object) assert_equal(np.array([4, 2**80, 4]).dtype, np.object) |