summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorahaldane <ealloc@gmail.com>2016-05-12 22:06:24 -0400
committerahaldane <ealloc@gmail.com>2016-05-12 22:06:24 -0400
commitf1a36312b0996495a499ec4e5a9de14148bf957c (patch)
tree0b79ef2ae35d85b6fe08ef52b23ac08a5e1b9f02 /numpy
parent383d800042769fc5aa159ff33988a3972e6d577f (diff)
parent6aa21ad951e8334ba3d3ac677390f7afe76cd242 (diff)
downloadnumpy-f1a36312b0996495a499ec4e5a9de14148bf957c.tar.gz
Merge pull request #6872 from pec27/complex-interp
ENH: linear interpolation of complex values in lib.interp
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/compiled_base.c176
-rw-r--r--numpy/core/src/multiarray/compiled_base.h2
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c2
-rw-r--r--numpy/lib/function_base.py48
-rw-r--r--numpy/lib/tests/test_function_base.py22
5 files changed, 236 insertions, 14 deletions
diff --git a/numpy/core/src/multiarray/compiled_base.c b/numpy/core/src/multiarray/compiled_base.c
index 136d0859e..213d74f44 100644
--- a/numpy/core/src/multiarray/compiled_base.c
+++ b/numpy/core/src/multiarray/compiled_base.c
@@ -664,6 +664,182 @@ fail:
return NULL;
}
+/* As for arr_interp but for complex fp values */
+NPY_NO_EXPORT PyObject *
+arr_interp_complex(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
+{
+
+ PyObject *fp, *xp, *x;
+ PyObject *left = NULL, *right = NULL;
+ PyArrayObject *afp = NULL, *axp = NULL, *ax = NULL, *af = NULL;
+ npy_intp i, lenx, lenxp;
+
+ const npy_double *dx, *dz;
+ const npy_cdouble *dy;
+ npy_cdouble lval, rval;
+ npy_cdouble *dres, *slopes = NULL;
+
+ static char *kwlist[] = {"x", "xp", "fp", "left", "right", NULL};
+
+ NPY_BEGIN_THREADS_DEF;
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwdict, "OOO|OO", kwlist,
+ &x, &xp, &fp, &left, &right)) {
+ return NULL;
+ }
+
+ afp = (PyArrayObject *)PyArray_ContiguousFromAny(fp, NPY_CDOUBLE, 1, 1);
+
+ if (afp == NULL) {
+ return NULL;
+ }
+
+ axp = (PyArrayObject *)PyArray_ContiguousFromAny(xp, NPY_DOUBLE, 1, 1);
+ if (axp == NULL) {
+ goto fail;
+ }
+ ax = (PyArrayObject *)PyArray_ContiguousFromAny(x, NPY_DOUBLE, 1, 0);
+ if (ax == NULL) {
+ goto fail;
+ }
+ lenxp = PyArray_SIZE(axp);
+ if (lenxp == 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "array of sample points is empty");
+ goto fail;
+ }
+ if (PyArray_SIZE(afp) != lenxp) {
+ PyErr_SetString(PyExc_ValueError,
+ "fp and xp are not of the same length.");
+ goto fail;
+ }
+
+ lenx = PyArray_SIZE(ax);
+ dx = (const npy_double *)PyArray_DATA(axp);
+ dz = (const npy_double *)PyArray_DATA(ax);
+
+ af = (PyArrayObject *)PyArray_SimpleNew(PyArray_NDIM(ax),
+ PyArray_DIMS(ax), NPY_CDOUBLE);
+ if (af == NULL) {
+ goto fail;
+ }
+
+ dy = (const npy_cdouble *)PyArray_DATA(afp);
+ dres = (npy_cdouble *)PyArray_DATA(af);
+ /* Get left and right fill values. */
+ if ((left == NULL) || (left == Py_None)) {
+ lval = dy[0];
+ }
+ else {
+ lval.real = PyComplex_RealAsDouble(left);
+ if ((lval.real == -1) && PyErr_Occurred()) {
+ goto fail;
+ }
+ lval.imag = PyComplex_ImagAsDouble(left);
+ if ((lval.imag == -1) && PyErr_Occurred()) {
+ goto fail;
+ }
+ }
+
+ if ((right == NULL) || (right == Py_None)) {
+ rval = dy[lenxp - 1];
+ }
+ else {
+ rval.real = PyComplex_RealAsDouble(right);
+ if ((rval.real == -1) && PyErr_Occurred()) {
+ goto fail;
+ }
+ rval.imag = PyComplex_ImagAsDouble(right);
+ if ((rval.imag == -1) && PyErr_Occurred()) {
+ goto fail;
+ }
+ }
+
+ /* binary_search_with_guess needs at least a 3 item long array */
+ if (lenxp == 1) {
+ const npy_double xp_val = dx[0];
+ const npy_cdouble fp_val = dy[0];
+
+ NPY_BEGIN_THREADS_THRESHOLDED(lenx);
+ for (i = 0; i < lenx; ++i) {
+ const npy_double x_val = dz[i];
+ dres[i] = (x_val < xp_val) ? lval :
+ ((x_val > xp_val) ? rval : fp_val);
+ }
+ NPY_END_THREADS;
+ }
+ else {
+ npy_intp j = 0;
+
+ /* only pre-calculate slopes if there are relatively few of them. */
+ if (lenxp <= lenx) {
+ slopes = PyArray_malloc((lenxp - 1) * sizeof(npy_cdouble));
+ if (slopes == NULL) {
+ goto fail;
+ }
+ }
+
+ NPY_BEGIN_THREADS;
+
+ if (slopes != NULL) {
+ for (i = 0; i < lenxp - 1; ++i) {
+ const double inv_dx = 1.0 / (dx[i+1] - dx[i]);
+ slopes[i].real = (dy[i+1].real - dy[i].real) * inv_dx;
+ slopes[i].imag = (dy[i+1].imag - dy[i].imag) * inv_dx;
+ }
+ }
+
+ for (i = 0; i < lenx; ++i) {
+ const npy_double x_val = dz[i];
+
+ if (npy_isnan(x_val)) {
+ dres[i].real = x_val;
+ dres[i].imag = 0.0;
+ continue;
+ }
+
+ j = binary_search_with_guess(x_val, dx, lenxp, j);
+ if (j == -1) {
+ dres[i] = lval;
+ }
+ else if (j == lenxp) {
+ dres[i] = rval;
+ }
+ else if (j == lenxp - 1) {
+ dres[i] = dy[j];
+ }
+ else {
+ if (slopes!=NULL) {
+ dres[i].real = slopes[j].real*(x_val - dx[j]) + dy[j].real;
+ dres[i].imag = slopes[j].imag*(x_val - dx[j]) + dy[j].imag;
+ }
+ else {
+ const npy_double inv_dx = 1.0 / (dx[j+1] - dx[j]);
+ dres[i].real = (dy[j+1].real - dy[j].real)*(x_val - dx[j])*
+ inv_dx + dy[j].real;
+ dres[i].imag = (dy[j+1].imag - dy[j].imag)*(x_val - dx[j])*
+ inv_dx + dy[j].imag;
+ }
+ }
+ }
+
+ NPY_END_THREADS;
+ }
+ PyArray_free(slopes);
+
+ Py_DECREF(afp);
+ Py_DECREF(axp);
+ Py_DECREF(ax);
+ return (PyObject *)af;
+
+fail:
+ Py_XDECREF(afp);
+ Py_XDECREF(axp);
+ Py_XDECREF(ax);
+ Py_XDECREF(af);
+ return NULL;
+}
+
/*
* Converts a Python sequence into 'count' PyArrayObjects
*
diff --git a/numpy/core/src/multiarray/compiled_base.h b/numpy/core/src/multiarray/compiled_base.h
index 19e3778ad..51508531c 100644
--- a/numpy/core/src/multiarray/compiled_base.h
+++ b/numpy/core/src/multiarray/compiled_base.h
@@ -11,6 +11,8 @@ arr_digitize(PyObject *, PyObject *, PyObject *kwds);
NPY_NO_EXPORT PyObject *
arr_interp(PyObject *, PyObject *, PyObject *);
NPY_NO_EXPORT PyObject *
+arr_interp_complex(PyObject *, PyObject *, PyObject *);
+NPY_NO_EXPORT PyObject *
arr_ravel_multi_index(PyObject *, PyObject *, PyObject *);
NPY_NO_EXPORT PyObject *
arr_unravel_index(PyObject *, PyObject *, PyObject *);
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index e2731068b..62b562856 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4233,6 +4233,8 @@ static struct PyMethodDef array_module_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"interp", (PyCFunction)arr_interp,
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"interp_complex", (PyCFunction)arr_interp_complex,
+ METH_VARARGS | METH_KEYWORDS, NULL},
{"ravel_multi_index", (PyCFunction)arr_ravel_multi_index,
METH_VARARGS | METH_KEYWORDS, NULL},
{"unravel_index", (PyCFunction)arr_unravel_index,
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 8f15fc547..3533a59fc 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -23,8 +23,10 @@ from numpy.core.fromnumeric import (
from numpy.core.numerictypes import typecodes, number
from numpy.lib.twodim_base import diag
from .utils import deprecate
-from numpy.core.multiarray import _insert, add_docstring
-from numpy.core.multiarray import digitize, bincount, interp as compiled_interp
+from numpy.core.multiarray import (
+ _insert, add_docstring, digitize, bincount,
+ interp as compiled_interp, interp_complex as compiled_interp_complex
+ )
from numpy.core.umath import _add_newdoc_ufunc as add_newdoc_ufunc
from numpy.compat import long
from numpy.compat.py3k import basestring
@@ -1663,13 +1665,13 @@ def interp(x, xp, fp, left=None, right=None, period=None):
`period` is not specified. Otherwise, `xp` is internally sorted after
normalizing the periodic boundaries with ``xp = xp % period``.
- fp : 1-D sequence of floats
+ fp : 1-D sequence of float or complex
The y-coordinates of the data points, same length as `xp`.
- left : float, optional
+ left : optional float or complex corresponding to fp
Value to return for `x < xp[0]`, default is `fp[0]`.
- right : float, optional
+ right : optional float or complex corresponding to fp
Value to return for `x > xp[-1]`, default is `fp[-1]`.
period : None or float, optional
@@ -1681,7 +1683,7 @@ def interp(x, xp, fp, left=None, right=None, period=None):
Returns
-------
- y : float or ndarray
+ y : float or complex (corresponding to fp) or ndarray
The interpolated values, same shape as `x`.
Raises
@@ -1732,14 +1734,31 @@ def interp(x, xp, fp, left=None, right=None, period=None):
>>> np.interp(x, xp, fp, period=360)
array([7.5, 5., 8.75, 6.25, 3., 3.25, 3.5, 3.75])
+ Complex interpolation
+ >>> x = [1.5, 4.0]
+ >>> xp = [2,3,5]
+ >>> fp = [1.0j, 0, 2+3j]
+ >>> np.interp(x, xp, fp)
+ array([ 0.+1.j , 1.+1.5j])
+
"""
+
+ fp = np.asarray(fp)
+
+ if np.iscomplexobj(fp):
+ interp_func = compiled_interp_complex
+ input_dtype = np.complex128
+ else:
+ interp_func = compiled_interp
+ input_dtype = np.float64
+
if period is None:
if isinstance(x, (float, int, number)):
- return compiled_interp([x], xp, fp, left, right).item()
+ return interp_func([x], xp, fp, left, right).item()
elif isinstance(x, np.ndarray) and x.ndim == 0:
- return compiled_interp([x], xp, fp, left, right).item()
+ return interp_func([x], xp, fp, left, right).item()
else:
- return compiled_interp(x, xp, fp, left, right)
+ return interp_func(x, xp, fp, left, right)
else:
if period == 0:
raise ValueError("period must be a non-zero value")
@@ -1752,7 +1771,8 @@ def interp(x, xp, fp, left=None, right=None, period=None):
x = [x]
x = np.asarray(x, dtype=np.float64)
xp = np.asarray(xp, dtype=np.float64)
- fp = np.asarray(fp, dtype=np.float64)
+ fp = np.asarray(fp, dtype=input_dtype)
+
if xp.ndim != 1 or fp.ndim != 1:
raise ValueError("Data points must be 1-D sequences")
if xp.shape[0] != fp.shape[0]:
@@ -1765,12 +1785,12 @@ def interp(x, xp, fp, left=None, right=None, period=None):
fp = fp[asort_xp]
xp = np.concatenate((xp[-1:]-period, xp, xp[0:1]+period))
fp = np.concatenate((fp[-1:], fp, fp[0:1]))
+
if return_array:
- return compiled_interp(x, xp, fp, left, right)
+ return interp_func(x, xp, fp, left, right)
else:
- return compiled_interp(x, xp, fp, left, right).item()
-
-
+ return interp_func(x, xp, fp, left, right).item()
+
def angle(z, deg=0):
"""
Return the angle of the complex argument.
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 044279294..0f71393ad 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2235,6 +2235,28 @@ class TestInterp(TestCase):
assert_almost_equal(np.interp(x0, x, y), x0)
x0 = np.nan
assert_almost_equal(np.interp(x0, x, y), x0)
+
+ def test_complex_interp(self):
+ # test complex interpolation
+ x = np.linspace(0, 1, 5)
+ y = np.linspace(0, 1, 5) + (1 + np.linspace(0, 1, 5))*1.0j
+ x0 = 0.3
+ y0 = x0 + (1+x0)*1.0j
+ assert_almost_equal(np.interp(x0, x, y), y0)
+ # test complex left and right
+ x0 = -1
+ left = 2 + 3.0j
+ assert_almost_equal(np.interp(x0, x, y, left=left), left)
+ x0 = 2.0
+ right = 2 + 3.0j
+ assert_almost_equal(np.interp(x0, x, y, right=right), right)
+ # test complex periodic
+ x = [-180, -170, -185, 185, -10, -5, 0, 365]
+ xp = [190, -190, 350, -350]
+ fp = [5+1.0j, 10+2j, 3+3j, 4+4j]
+ y = [7.5+1.5j, 5.+1.0j, 8.75+1.75j, 6.25+1.25j, 3.+3j, 3.25+3.25j,
+ 3.5+3.5j, 3.75+3.75j]
+ assert_almost_equal(np.interp(x, xp, fp, period=360), y)
def test_zero_dimensional_interpolation_point(self):
x = np.linspace(0, 1, 5)