diff options
author | Ben Root <ben.v.root@gmail.com> | 2011-09-27 09:52:38 -0500 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2012-01-27 20:43:32 -0700 |
commit | 06d947c51726c303cc5c30f16643903d89da7207 (patch) | |
tree | 455809fe664f6c386c00152dfc52be58cc8ea2f6 /numpy | |
parent | 7bb277bacb92fcbb1ab2980234fe033dcc70d628 (diff) | |
download | numpy-06d947c51726c303cc5c30f16643903d89da7207.tar.gz |
ENH: Support datetime64, timedelta64 in gradient. Allow array-like input.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/function_base.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index f254bbacf..4ab1679e5 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -867,6 +867,7 @@ def gradient(f, *varargs): [ 1. , 1. , 1. ]])] """ + f = np.asanyarray(f) N = len(f.shape) # number of dimensions n = len(varargs) if n == 0: @@ -889,12 +890,20 @@ def gradient(f, *varargs): slice3 = [slice(None)]*N otype = f.dtype.char - if otype not in ['f', 'd', 'F', 'D']: + if otype not in ['f', 'd', 'F', 'D', 'm', 'M']: otype = 'd' + # Difference of datetime64 elements results in timedelta64 + if otype == 'M' : + # Need to use the full dtype name because it contains unit information + otype = f.dtype.name.replace('datetime', 'timedelta') + elif otype == 'm' : + # Needs to keep the specific units, can't be a general unit + otype = f.dtype + for axis in range(N): # select out appropriate parts for this dimension - out = np.zeros_like(f).astype(otype) + out = np.empty_like(f, dtype=otype) slice1[axis] = slice(1, -1) slice2[axis] = slice(2, None) slice3[axis] = slice(None, -2) |