summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2017-05-05 17:41:21 -0600
committerGitHub <noreply@github.com>2017-05-05 17:41:21 -0600
commit857125603f3cf72e2078ed2effd0f6ce40f0ca28 (patch)
treeb768a8f6806b0178f561323847a56a86abc05d32
parentd761fd6ccbc038970798eb7dfb1a5de825653ea8 (diff)
parent6dbaf77e56d3ace8d2aeb5cb4049cccc9398705d (diff)
downloadnumpy-857125603f3cf72e2078ed2effd0f6ce40f0ca28.tar.gz
Merge pull request #8911 from eric-wieser/fix-check_api_dict
BUG: check_api_dict does not correctly handle tuple values
-rw-r--r--numpy/core/code_generators/genapi.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/numpy/core/code_generators/genapi.py b/numpy/core/code_generators/genapi.py
index 544597786..b618dedf5 100644
--- a/numpy/core/code_generators/genapi.py
+++ b/numpy/core/code_generators/genapi.py
@@ -426,28 +426,32 @@ def merge_api_dicts(dicts):
def check_api_dict(d):
"""Check that an api dict is valid (does not use the same index twice)."""
+ # remove the extra value fields that aren't the index
+ index_d = {k: v[0] for k, v in d.items()}
+
# We have if a same index is used twice: we 'revert' the dict so that index
# become keys. If the length is different, it means one index has been used
# at least twice
- revert_dict = dict([(v, k) for k, v in d.items()])
- if not len(revert_dict) == len(d):
+ revert_dict = {v: k for k, v in index_d.items()}
+ if not len(revert_dict) == len(index_d):
# We compute a dict index -> list of associated items
doubled = {}
- for name, index in d.items():
+ for name, index in index_d.items():
try:
doubled[index].append(name)
except KeyError:
doubled[index] = [name]
- msg = """\
-Same index has been used twice in api definition: %s
-""" % ['index %d -> %s' % (index, names) for index, names in doubled.items() \
- if len(names) != 1]
- raise ValueError(msg)
+ fmt = "Same index has been used twice in api definition: {}"
+ val = ''.join(
+ '\n\tindex {} -> {}'.format(index, names)
+ for index, names in doubled.items() if len(names) != 1
+ )
+ raise ValueError(fmt.format(val))
# No 'hole' in the indexes may be allowed, and it must starts at 0
- indexes = set(v[0] for v in d.values())
+ indexes = set(index_d.values())
expected = set(range(len(indexes)))
- if not indexes == expected:
+ if indexes != expected:
diff = expected.symmetric_difference(indexes)
msg = "There are some holes in the API indexing: " \
"(symmetric diff is %s)" % diff