diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/npyio.py | 42 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 33 |
2 files changed, 65 insertions, 10 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 3e3d00c57..3a575a048 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -880,7 +880,15 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', fmt : str or sequence of strs, optional A single format (%10.5f), a sequence of formats, or a multi-format string, e.g. 'Iteration %d -- %10.5f', in which - case `delimiter` is ignored. + case `delimiter` is ignored. For complex `X`, the legal options + for `fmt` are: + a) a single specifier, `fmt='%.4e'`, resulting in numbers formatted + like `' (%s+%sj)' % (fmt, fmt)` + b) a full string specifying every real and imaginary part, e.g. + `' %.4e %+.4j %.4e %+.4j %.4e %+.4j'` for 3 columns + c) a list of specifiers, one per column - in this case, the real + and imaginary part must have separate specifiers, + e.g. `['%.3e + %.3ej', '(%.15e%+.15ej)']` for 2 columns delimiter : str, optional Character separating columns. newline : str, optional @@ -895,11 +903,10 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', String that will be prepended to the ``header`` and ``footer`` strings, to mark them as comments. Default: '# ', as expected by e.g. ``numpy.loadtxt``. - .. versionadded:: 2.0.0 + .. versionadded:: 1.7.0 Character separating lines. - See Also -------- save : Save an array to a binary file in NumPy ``.npy`` format @@ -1003,6 +1010,7 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', else: ncol = X.shape[1] + iscomplex_X = np.iscomplexobj(X) # `fmt` can be a string with multiple insertion points or a # list of formats. E.g. '%10.5f\t%10d' or ('%10.5f', '$10d') if type(fmt) in (list, tuple): @@ -1010,20 +1018,34 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', raise AttributeError('fmt has wrong shape. %s' % str(fmt)) format = asstr(delimiter).join(map(asstr, fmt)) elif type(fmt) is str: - if fmt.count('%') == 1: - fmt = [fmt, ]*ncol + n_fmt_chars = fmt.count('%') + error = ValueError('fmt has wrong number of %% formats: %s' % fmt) + if n_fmt_chars == 1: + if iscomplex_X: + fmt = [' (%s+%sj)' % (fmt, fmt),] * ncol + else: + fmt = [fmt, ] * ncol format = delimiter.join(fmt) - elif fmt.count('%') != ncol: - raise AttributeError('fmt has wrong number of %% formats. %s' - % fmt) + elif iscomplex_X and n_fmt_chars != (2 * ncol): + raise error + elif ((not iscomplex_X) and n_fmt_chars != ncol): + raise error else: format = fmt if len(header) > 0: header = header.replace('\n', '\n' + comments) fh.write(asbytes(comments + header + newline)) - for row in X: - fh.write(asbytes(format % tuple(row) + newline)) + if iscomplex_X: + for row in X: + row2 = [] + for number in row: + row2.append(number.real) + row2.append(number.imag) + fh.write(asbytes(format % tuple(row2) + newline)) + else: + for row in X: + fh.write(asbytes(format % tuple(row) + newline)) if len(footer) > 0: footer = footer.replace('\n', '\n' + comments) fh.write(asbytes(comments + footer + newline)) diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 54c2ddabb..356534a25 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -273,6 +273,39 @@ class TestSaveTxt(TestCase): finally: os.unlink(name) + def test_complex_arrays(self): + ncols = 2 + nrows = 2 + a = np.zeros((ncols, nrows), dtype=np.complex128) + re = np.pi + im = np.e + a[:] = re + 1.0j * im + # One format only + c = StringIO() + np.savetxt(c, a, fmt=' %+.3e') + c.seek(0) + lines = c.readlines() + assert_equal(lines, asbytes_nested([ + ' ( +3.142e+00+ +2.718e+00j) ( +3.142e+00+ +2.718e+00j)\n', + ' ( +3.142e+00+ +2.718e+00j) ( +3.142e+00+ +2.718e+00j)\n'])) + # One format for each real and imaginary part + c = StringIO() + np.savetxt(c, a, fmt=' %+.3e' * 2 * ncols) + c.seek(0) + lines = c.readlines() + assert_equal(lines, asbytes_nested([ + ' +3.142e+00 +2.718e+00 +3.142e+00 +2.718e+00\n', + ' +3.142e+00 +2.718e+00 +3.142e+00 +2.718e+00\n'])) + # One format for each complex number + c = StringIO() + np.savetxt(c, a, fmt=['(%.3e%+.3ej)'] * ncols) + c.seek(0) + lines = c.readlines() + assert_equal(lines, asbytes_nested([ + '(3.142e+00+2.718e+00j) (3.142e+00+2.718e+00j)\n', + '(3.142e+00+2.718e+00j) (3.142e+00+2.718e+00j)\n'])) + + class TestLoadTxt(TestCase): def test_record(self): |