summaryrefslogtreecommitdiff
path: root/numpy/matrixlib/defmatrix.py
diff options
context:
space:
mode:
authorZè Vinícius <jvmirca@gmail.com>2017-03-01 10:32:34 -0300
committerEric Wieser <wieser.eric@gmail.com>2017-03-01 13:32:34 +0000
commit35d752cbd34754667b55678f552dbdcfc5274d27 (patch)
tree87ffc78ed362740f379ed22c4f5f5658809c7594 /numpy/matrixlib/defmatrix.py
parentee3ab365cb55cce6d0b9b6ed5cfbd8e3ede8cc66 (diff)
downloadnumpy-35d752cbd34754667b55678f552dbdcfc5274d27.tar.gz
BUG: Fix creating a np.matrix from string syntax involving booleans (#8497)
Fixes #8459 * DOC: add release note [ci skip]
Diffstat (limited to 'numpy/matrixlib/defmatrix.py')
-rw-r--r--numpy/matrixlib/defmatrix.py44
1 files changed, 5 insertions, 39 deletions
diff --git a/numpy/matrixlib/defmatrix.py b/numpy/matrixlib/defmatrix.py
index bd14846c6..7026fad1a 100644
--- a/numpy/matrixlib/defmatrix.py
+++ b/numpy/matrixlib/defmatrix.py
@@ -3,49 +3,15 @@ from __future__ import division, absolute_import, print_function
__all__ = ['matrix', 'bmat', 'mat', 'asmatrix']
import sys
+import ast
import numpy.core.numeric as N
from numpy.core.numeric import concatenate, isscalar, binary_repr, identity, asanyarray
from numpy.core.numerictypes import issubdtype
-# make translation table
-_numchars = '0123456789.-+jeEL'
-
-if sys.version_info[0] >= 3:
- class _NumCharTable:
- def __getitem__(self, i):
- if chr(i) in _numchars:
- return chr(i)
- else:
- return None
- _table = _NumCharTable()
- def _eval(astr):
- str_ = astr.translate(_table)
- if not str_:
- raise TypeError("Invalid data string supplied: " + astr)
- else:
- return eval(str_)
-
-else:
- _table = [None]*256
- for k in range(256):
- _table[k] = chr(k)
- _table = ''.join(_table)
-
- _todelete = []
- for k in _table:
- if k not in _numchars:
- _todelete.append(k)
- _todelete = ''.join(_todelete)
- del k
-
- def _eval(astr):
- str_ = astr.translate(_table, _todelete)
- if not str_:
- raise TypeError("Invalid data string supplied: " + astr)
- else:
- return eval(str_)
-
def _convert_from_string(data):
+ for char in '[]':
+ data = data.replace(char, '')
+
rows = data.split(';')
newdata = []
count = 0
@@ -54,7 +20,7 @@ def _convert_from_string(data):
newrow = []
for col in trow:
temp = col.split()
- newrow.extend(map(_eval, temp))
+ newrow.extend(map(ast.literal_eval, temp))
if count == 0:
Ncols = len(newrow)
elif len(newrow) != Ncols: