summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/npyio.py42
-rw-r--r--numpy/lib/tests/test_io.py33
2 files changed, 65 insertions, 10 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py
index 3e3d00c57..3a575a048 100644
--- a/numpy/lib/npyio.py
+++ b/numpy/lib/npyio.py
@@ -880,7 +880,15 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
fmt : str or sequence of strs, optional
A single format (%10.5f), a sequence of formats, or a
multi-format string, e.g. 'Iteration %d -- %10.5f', in which
- case `delimiter` is ignored.
+ case `delimiter` is ignored. For complex `X`, the legal options
+ for `fmt` are:
+ a) a single specifier, `fmt='%.4e'`, resulting in numbers formatted
+ like `' (%s+%sj)' % (fmt, fmt)`
+ b) a full string specifying every real and imaginary part, e.g.
+ `' %.4e %+.4j %.4e %+.4j %.4e %+.4j'` for 3 columns
+ c) a list of specifiers, one per column - in this case, the real
+ and imaginary part must have separate specifiers,
+ e.g. `['%.3e + %.3ej', '(%.15e%+.15ej)']` for 2 columns
delimiter : str, optional
Character separating columns.
newline : str, optional
@@ -895,11 +903,10 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
String that will be prepended to the ``header`` and ``footer`` strings,
to mark them as comments. Default: '# ', as expected by e.g.
``numpy.loadtxt``.
- .. versionadded:: 2.0.0
+ .. versionadded:: 1.7.0
Character separating lines.
-
See Also
--------
save : Save an array to a binary file in NumPy ``.npy`` format
@@ -1003,6 +1010,7 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
else:
ncol = X.shape[1]
+ iscomplex_X = np.iscomplexobj(X)
# `fmt` can be a string with multiple insertion points or a
# list of formats. E.g. '%10.5f\t%10d' or ('%10.5f', '$10d')
if type(fmt) in (list, tuple):
@@ -1010,20 +1018,34 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
raise AttributeError('fmt has wrong shape. %s' % str(fmt))
format = asstr(delimiter).join(map(asstr, fmt))
elif type(fmt) is str:
- if fmt.count('%') == 1:
- fmt = [fmt, ]*ncol
+ n_fmt_chars = fmt.count('%')
+ error = ValueError('fmt has wrong number of %% formats: %s' % fmt)
+ if n_fmt_chars == 1:
+ if iscomplex_X:
+ fmt = [' (%s+%sj)' % (fmt, fmt),] * ncol
+ else:
+ fmt = [fmt, ] * ncol
format = delimiter.join(fmt)
- elif fmt.count('%') != ncol:
- raise AttributeError('fmt has wrong number of %% formats. %s'
- % fmt)
+ elif iscomplex_X and n_fmt_chars != (2 * ncol):
+ raise error
+ elif ((not iscomplex_X) and n_fmt_chars != ncol):
+ raise error
else:
format = fmt
if len(header) > 0:
header = header.replace('\n', '\n' + comments)
fh.write(asbytes(comments + header + newline))
- for row in X:
- fh.write(asbytes(format % tuple(row) + newline))
+ if iscomplex_X:
+ for row in X:
+ row2 = []
+ for number in row:
+ row2.append(number.real)
+ row2.append(number.imag)
+ fh.write(asbytes(format % tuple(row2) + newline))
+ else:
+ for row in X:
+ fh.write(asbytes(format % tuple(row) + newline))
if len(footer) > 0:
footer = footer.replace('\n', '\n' + comments)
fh.write(asbytes(comments + footer + newline))
diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py
index 54c2ddabb..356534a25 100644
--- a/numpy/lib/tests/test_io.py
+++ b/numpy/lib/tests/test_io.py
@@ -273,6 +273,39 @@ class TestSaveTxt(TestCase):
finally:
os.unlink(name)
+ def test_complex_arrays(self):
+ ncols = 2
+ nrows = 2
+ a = np.zeros((ncols, nrows), dtype=np.complex128)
+ re = np.pi
+ im = np.e
+ a[:] = re + 1.0j * im
+ # One format only
+ c = StringIO()
+ np.savetxt(c, a, fmt=' %+.3e')
+ c.seek(0)
+ lines = c.readlines()
+ assert_equal(lines, asbytes_nested([
+ ' ( +3.142e+00+ +2.718e+00j) ( +3.142e+00+ +2.718e+00j)\n',
+ ' ( +3.142e+00+ +2.718e+00j) ( +3.142e+00+ +2.718e+00j)\n']))
+ # One format for each real and imaginary part
+ c = StringIO()
+ np.savetxt(c, a, fmt=' %+.3e' * 2 * ncols)
+ c.seek(0)
+ lines = c.readlines()
+ assert_equal(lines, asbytes_nested([
+ ' +3.142e+00 +2.718e+00 +3.142e+00 +2.718e+00\n',
+ ' +3.142e+00 +2.718e+00 +3.142e+00 +2.718e+00\n']))
+ # One format for each complex number
+ c = StringIO()
+ np.savetxt(c, a, fmt=['(%.3e%+.3ej)'] * ncols)
+ c.seek(0)
+ lines = c.readlines()
+ assert_equal(lines, asbytes_nested([
+ '(3.142e+00+2.718e+00j) (3.142e+00+2.718e+00j)\n',
+ '(3.142e+00+2.718e+00j) (3.142e+00+2.718e+00j)\n']))
+
+
class TestLoadTxt(TestCase):
def test_record(self):