summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorMax Kellermeier <max.kellermeier@hotmail.de>2020-08-20 14:34:22 +0200
committerMax Kellermeier <max.kellermeier@hotmail.de>2020-08-20 14:34:22 +0200
commitc3ea9b680e458a91e0a7e97ce61d6853c379385f (patch)
tree54ed40358f5ae044f1b5e3e6860338bd55bdc945 /numpy/lib/function_base.py
parent11ae4340173c644768368755bba93ced112b4505 (diff)
downloadnumpy-c3ea9b680e458a91e0a7e97ce61d6853c379385f.tar.gz
Integer input returning integer output
Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index ceea92849..491b25915 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1550,13 +1550,21 @@ def unwrap(p, discont=None, axis=-1, *, period=2*pi):
slice1 = [slice(None, None)]*nd # full slices
slice1[axis] = slice(1, None)
slice1 = tuple(slice1)
- ddmod = mod(dd + period/2, period) - period/2
- # for `mask = (abs(dd) == period/2)`, the above line made `ddmod[mask] == -period/2`.
- # correct these such that `ddmod[mask] == sign(dd[mask])*period/2`.
- _nx.copyto(ddmod, period/2, where=(ddmod == -period/2) & (dd > 0))
+ dtype = np.result_type(dd, period)
+ if _nx.issubdtype(dtype, _nx.integer):
+ interval_low = -(period // 2)
+ interval_high = -interval_low
+ else:
+ interval_low = -period / 2
+ interval_high = -interval_low
+ ddmod = mod(dd - interval_low, period) + interval_low
+ if period % 2 == 0:
+ # for `mask = (abs(dd) == period/2)`, the above line made `ddmod[mask] == -period/2`.
+ # correct these such that `ddmod[mask] == sign(dd[mask])*period/2`.
+ _nx.copyto(ddmod, interval_high, where=(ddmod == interval_low) & (dd > 0))
ph_correct = ddmod - dd
_nx.copyto(ph_correct, 0, where=abs(dd) < discont)
- up = array(p, copy=True, dtype='d')
+ up = array(p, copy=True, dtype=dtype)
up[slice1] = p[slice1] + ph_correct.cumsum(axis)
return up