diff options
author | Matti Picus <matti.picus@gmail.com> | 2019-12-02 20:08:45 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2019-12-02 12:08:45 -0600 |
commit | 7b2d968d5a4730489d9e9148afe2277b1bc32477 (patch) | |
tree | 3da8067d704a681cc48dfc372e3ed855d6cb1d0b /numpy/lib | |
parent | 48481c6e2ed958cd0d30cc0e8d5611c8e68f7df3 (diff) | |
download | numpy-7b2d968d5a4730489d9e9148afe2277b1bc32477.tar.gz |
BUG: warn when saving dtype with metadata (#14994)
Address gh-14142 for the 1.18 release: warn when saving a dtype with metadata that cannot be loaded.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/format.py | 14 | ||||
-rw-r--r-- | numpy/lib/tests/test_format.py | 30 |
2 files changed, 44 insertions, 0 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 1ecd72815..20e2e9c72 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -242,6 +242,16 @@ def read_magic(fp): major, minor = magic_str[-2:] return major, minor +def _has_metadata(dt): + if dt.metadata is not None: + return True + elif dt.names is not None: + return any(_has_metadata(dt[k]) for k in dt.names) + elif dt.subdtype is not None: + return _has_metadata(dt.base) + else: + return False + def dtype_to_descr(dtype): """ Get a serializable descriptor from the dtype. @@ -265,6 +275,10 @@ def dtype_to_descr(dtype): replicate the input dtype. """ + if _has_metadata(dtype): + warnings.warn("metadata on a dtype may be saved or ignored, but will " + "raise if saved when read. Use another form of storage.", + UserWarning, stacklevel=2) if dtype.names is not None: # This is a record array. The .descr is fine. XXX: parts of the # record array with an empty name, like padding bytes, still get diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py index 062c21725..0592e0b12 100644 --- a/numpy/lib/tests/test_format.py +++ b/numpy/lib/tests/test_format.py @@ -963,3 +963,33 @@ def test_unicode_field_names(): with open(fname, 'wb') as f: with assert_warns(UserWarning): format.write_array(f, arr, version=None) + + +@pytest.mark.parametrize('dt, fail', [ + (np.dtype({'names': ['a', 'b'], 'formats': [float, np.dtype('S3', + metadata={'some': 'stuff'})]}), True), + (np.dtype(int, metadata={'some': 'stuff'}), False), + (np.dtype([('subarray', (int, (2,)))], metadata={'some': 'stuff'}), False), + # recursive: metadata on the field of a dtype + (np.dtype({'names': ['a', 'b'], 'formats': [ + float, np.dtype({'names': ['c'], 'formats': [np.dtype(int, metadata={})]}) + ]}), False) + ]) +def test_metadata_dtype(dt, fail): + # gh-14142 + arr = np.ones(10, dtype=dt) + buf = BytesIO() + with assert_warns(UserWarning): + np.save(buf, arr) + buf.seek(0) + if fail: + with assert_raises(ValueError): + np.load(buf) + else: + arr2 = np.load(buf) + # BUG: assert_array_equal does not check metadata + from numpy.lib.format import _has_metadata + assert_array_equal(arr, arr2) + assert _has_metadata(arr.dtype) + assert not _has_metadata(arr2.dtype) + |