summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-02-17 19:11:18 -0500
committerCharles Harris <charlesr.harris@gmail.com>2015-02-17 19:11:18 -0500
commit4338d11c104d867cf88d5ba2a87b032e58de7053 (patch)
tree225ff3efa45d4fd6cf97fb827964ba21363a21ea
parent4065adbcc4e9d320f41b78922a906f16c6add7cc (diff)
parent97c481ee90459e8e372b74144f666cf06ad9df61 (diff)
downloadnumpy-4338d11c104d867cf88d5ba2a87b032e58de7053.tar.gz
Merge pull request #5577 from charris/cleanup-gh-5263
BUG: financial.pmt warns of zero divide when rate == 0.
-rw-r--r--numpy/lib/financial.py10
-rw-r--r--numpy/lib/tests/test_financial.py41
2 files changed, 27 insertions, 24 deletions
diff --git a/numpy/lib/financial.py b/numpy/lib/financial.py
index baff8b0b6..a7e4e60b6 100644
--- a/numpy/lib/financial.py
+++ b/numpy/lib/financial.py
@@ -208,11 +208,11 @@ def pmt(rate, nper, pv, fv=0, when='end'):
"""
when = _convert_when(when)
(rate, nper, pv, fv, when) = map(np.asarray, [rate, nper, pv, fv, when])
- temp = (1+rate)**nper
- miter = np.broadcast(rate, nper, pv, fv, when)
- zer = np.zeros(miter.shape)
- fact = np.where(rate == zer, nper + zer,
- (1 + rate*when)*(temp - 1)/rate + zer)
+ temp = (1 + rate)**nper
+ mask = (rate == 0.0)
+ np.copyto(rate, 1.0, where=mask)
+ z = np.zeros(np.broadcast(rate, nper, pv, fv, when).shape)
+ fact = np.where(mask != z, nper + z, (1 + rate*when)*(temp - 1)/rate + z)
return -(fv + pv*temp) / fact
def nper(rate, pmt, pv, fv=0, when='end'):
diff --git a/numpy/lib/tests/test_financial.py b/numpy/lib/tests/test_financial.py
index a4b9cfe2e..baa785424 100644
--- a/numpy/lib/tests/test_financial.py
+++ b/numpy/lib/tests/test_financial.py
@@ -2,7 +2,8 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
- run_module_suite, TestCase, assert_, assert_almost_equal
+ run_module_suite, TestCase, assert_, assert_almost_equal,
+ assert_allclose
)
@@ -13,35 +14,37 @@ class TestFinancial(TestCase):
def test_irr(self):
v = [-150000, 15000, 25000, 35000, 45000, 60000]
- assert_almost_equal(np.irr(v),
- 0.0524, 2)
+ assert_almost_equal(np.irr(v), 0.0524, 2)
v = [-100, 0, 0, 74]
- assert_almost_equal(np.irr(v),
- -0.0955, 2)
+ assert_almost_equal(np.irr(v), -0.0955, 2)
v = [-100, 39, 59, 55, 20]
- assert_almost_equal(np.irr(v),
- 0.28095, 2)
+ assert_almost_equal(np.irr(v), 0.28095, 2)
v = [-100, 100, 0, -7]
- assert_almost_equal(np.irr(v),
- -0.0833, 2)
+ assert_almost_equal(np.irr(v), -0.0833, 2)
v = [-100, 100, 0, 7]
- assert_almost_equal(np.irr(v),
- 0.06206, 2)
+ assert_almost_equal(np.irr(v), 0.06206, 2)
v = [-5, 10.5, 1, -8, 1]
- assert_almost_equal(np.irr(v),
- 0.0886, 2)
+ assert_almost_equal(np.irr(v), 0.0886, 2)
def test_pv(self):
- assert_almost_equal(np.pv(0.07, 20, 12000, 0),
- -127128.17, 2)
+ assert_almost_equal(np.pv(0.07, 20, 12000, 0), -127128.17, 2)
def test_fv(self):
- assert_almost_equal(np.fv(0.075, 20, -2000, 0, 0),
- 86609.36, 2)
+ assert_almost_equal(np.fv(0.075, 20, -2000, 0, 0), 86609.36, 2)
def test_pmt(self):
- assert_almost_equal(np.pmt(0.08/12, 5*12, 15000),
- -304.146, 3)
+ res = np.pmt(0.08/12, 5*12, 15000)
+ tgt = -304.145914
+ assert_allclose(res, tgt)
+ # Test the edge case where rate == 0.0
+ res = np.pmt(0.0, 5*12, 15000)
+ tgt = -250.0
+ assert_allclose(res, tgt)
+ # Test the case where we use broadcast and
+ # the arguments passed in are arrays.
+ res = np.pmt([[0.0, 0.8],[0.3, 0.8]],[12, 3],[2000, 20000])
+ tgt = np.array([[-166.66667, -19311.258],[-626.90814, -19311.258]])
+ assert_allclose(res, tgt)
def test_ppmt(self):
np.round(np.ppmt(0.1/12, 1, 60, 55000), 2) == 710.25