summaryrefslogtreecommitdiff
path: root/numpy/polynomial/hermite.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2011-12-20 18:00:39 -0700
committerCharles Harris <charlesr.harris@gmail.com>2012-01-09 10:45:13 -0700
commit266915b2d5480474f9eeb1cb7a11e6753e4fcf2e (patch)
tree050cc866d7f69b283e678f25f1d9a95eebd5e6ca /numpy/polynomial/hermite.py
parent26e2ae4b8b55bc82cd46d5d309ac5eced4bc8fe4 (diff)
downloadnumpy-266915b2d5480474f9eeb1cb7a11e6753e4fcf2e.tar.gz
BUG: Small fixes and additions
Where xxx is one of poly, cheb, leg, lag, herm, herme: Refactor xxxval2d, xxxval3d, xxxgrid2d, and xxxgrid3d for clarity. Check that coordinate arrays are compatible in xxxval2d, xxxval3d. Work around einsum bug that affected xxxvander3d.
Diffstat (limited to 'numpy/polynomial/hermite.py')
-rw-r--r--numpy/polynomial/hermite.py42
1 files changed, 31 insertions, 11 deletions
diff --git a/numpy/polynomial/hermite.py b/numpy/polynomial/hermite.py
index 340e6009b..fadd866ee 100644
--- a/numpy/polynomial/hermite.py
+++ b/numpy/polynomial/hermite.py
@@ -965,7 +965,14 @@ def hermval2d(x, y, c):
hermval, hermgrid2d, hermval3d, hermgrid3d
"""
- return hermval(y, hermval(x, c), False)
+ try:
+ x, y = np.array((x, y), copy=0)
+ except:
+ raise ValueError('x, y are incompatible')
+
+ c = hermval(x, c)
+ c = hermval(y, c, tensor=False)
+ return c
def hermgrid2d(x, y, c):
@@ -1005,7 +1012,9 @@ def hermgrid2d(x, y, c):
hermval, hermval2d, hermval3d, hermgrid3d
"""
- return hermval(y, hermval(x, c))
+ c = hermval(x, c)
+ c = hermval(y, c)
+ return c
def hermval3d(x, y, z, c):
@@ -1042,7 +1051,15 @@ def hermval3d(x, y, z, c):
hermval, hermval2d, hermgrid2d, hermgrid3d
"""
- return hermval(z, hermval2d(x, y, c), False)
+ try:
+ x, y, z = np.array((x, y, z), copy=0)
+ except:
+ raise ValueError('x, y, z are incompatible')
+
+ c = hermval(x, c)
+ c = hermval(y, c, tensor=False)
+ c = hermval(z, c, tensor=False)
+ return c
def hermgrid3d(x, y, z, c):
@@ -1084,7 +1101,10 @@ def hermgrid3d(x, y, z, c):
hermval, hermval2d, hermgrid2d, hermval3d
"""
- return hermval(z, hermgrid2d(x, y, c))
+ c = hermval(x, c)
+ c = hermval(y, c)
+ c = hermval(z, c)
+ return c
def hermvander(x, deg) :
@@ -1174,12 +1194,12 @@ def hermvander2d(x, y, deg) :
is_valid = [id == d and id >= 0 for id, d in zip(ideg, deg)]
if is_valid != [1, 1]:
raise ValueError("degrees must be non-negative integers")
- degx, degy = deg
+ degx, degy = ideg
x, y = np.array((x, y), copy=0) + 0.0
vx = hermvander(x, degx)
vy = hermvander(y, degy)
- v = np.einsum("...i,...j->...ij", vx, vy)
+ v = vx[..., None]*vy[..., None, :]
return v.reshape(v.shape[:-2] + (-1,))
@@ -1220,13 +1240,13 @@ def hermvander3d(x, y, z, deg) :
is_valid = [id == d and id >= 0 for id, d in zip(ideg, deg)]
if is_valid != [1, 1, 1]:
raise ValueError("degrees must be non-negative integers")
- degx, degy, degz = deg
+ degx, degy, degz = ideg
x, y, z = np.array((x, y, z), copy=0) + 0.0
- vx = hermvander(x, deg_x)
- vy = hermvander(y, deg_y)
- vz = hermvander(z, deg_z)
- v = np.einsum("...i,...j,...k->...ijk", vx, vy, vz)
+ vx = hermvander(x, degx)
+ vy = hermvander(y, degy)
+ vz = hermvander(z, degz)
+ v = vx[..., None, None]*vy[..., None, :, None]*vz[..., None, None, :]
return v.reshape(v.shape[:-3] + (-1,))