summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2019-12-02 20:08:45 +0200
committerSebastian Berg <sebastian@sipsolutions.net>2019-12-02 12:08:45 -0600
commit7b2d968d5a4730489d9e9148afe2277b1bc32477 (patch)
tree3da8067d704a681cc48dfc372e3ed855d6cb1d0b /numpy/lib
parent48481c6e2ed958cd0d30cc0e8d5611c8e68f7df3 (diff)
downloadnumpy-7b2d968d5a4730489d9e9148afe2277b1bc32477.tar.gz
BUG: warn when saving dtype with metadata (#14994)
Address gh-14142 for the 1.18 release: warn when saving a dtype with metadata that cannot be loaded.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/format.py14
-rw-r--r--numpy/lib/tests/test_format.py30
2 files changed, 44 insertions, 0 deletions
diff --git a/numpy/lib/format.py b/numpy/lib/format.py
index 1ecd72815..20e2e9c72 100644
--- a/numpy/lib/format.py
+++ b/numpy/lib/format.py
@@ -242,6 +242,16 @@ def read_magic(fp):
major, minor = magic_str[-2:]
return major, minor
+def _has_metadata(dt):
+ if dt.metadata is not None:
+ return True
+ elif dt.names is not None:
+ return any(_has_metadata(dt[k]) for k in dt.names)
+ elif dt.subdtype is not None:
+ return _has_metadata(dt.base)
+ else:
+ return False
+
def dtype_to_descr(dtype):
"""
Get a serializable descriptor from the dtype.
@@ -265,6 +275,10 @@ def dtype_to_descr(dtype):
replicate the input dtype.
"""
+ if _has_metadata(dtype):
+ warnings.warn("metadata on a dtype may be saved or ignored, but will "
+ "raise if saved when read. Use another form of storage.",
+ UserWarning, stacklevel=2)
if dtype.names is not None:
# This is a record array. The .descr is fine. XXX: parts of the
# record array with an empty name, like padding bytes, still get
diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py
index 062c21725..0592e0b12 100644
--- a/numpy/lib/tests/test_format.py
+++ b/numpy/lib/tests/test_format.py
@@ -963,3 +963,33 @@ def test_unicode_field_names():
with open(fname, 'wb') as f:
with assert_warns(UserWarning):
format.write_array(f, arr, version=None)
+
+
+@pytest.mark.parametrize('dt, fail', [
+ (np.dtype({'names': ['a', 'b'], 'formats': [float, np.dtype('S3',
+ metadata={'some': 'stuff'})]}), True),
+ (np.dtype(int, metadata={'some': 'stuff'}), False),
+ (np.dtype([('subarray', (int, (2,)))], metadata={'some': 'stuff'}), False),
+ # recursive: metadata on the field of a dtype
+ (np.dtype({'names': ['a', 'b'], 'formats': [
+ float, np.dtype({'names': ['c'], 'formats': [np.dtype(int, metadata={})]})
+ ]}), False)
+ ])
+def test_metadata_dtype(dt, fail):
+ # gh-14142
+ arr = np.ones(10, dtype=dt)
+ buf = BytesIO()
+ with assert_warns(UserWarning):
+ np.save(buf, arr)
+ buf.seek(0)
+ if fail:
+ with assert_raises(ValueError):
+ np.load(buf)
+ else:
+ arr2 = np.load(buf)
+ # BUG: assert_array_equal does not check metadata
+ from numpy.lib.format import _has_metadata
+ assert_array_equal(arr, arr2)
+ assert _has_metadata(arr.dtype)
+ assert not _has_metadata(arr2.dtype)
+