diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2017-05-05 17:41:21 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-05 17:41:21 -0600 |
commit | 857125603f3cf72e2078ed2effd0f6ce40f0ca28 (patch) | |
tree | b768a8f6806b0178f561323847a56a86abc05d32 | |
parent | d761fd6ccbc038970798eb7dfb1a5de825653ea8 (diff) | |
parent | 6dbaf77e56d3ace8d2aeb5cb4049cccc9398705d (diff) | |
download | numpy-857125603f3cf72e2078ed2effd0f6ce40f0ca28.tar.gz |
Merge pull request #8911 from eric-wieser/fix-check_api_dict
BUG: check_api_dict does not correctly handle tuple values
-rw-r--r-- | numpy/core/code_generators/genapi.py | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/numpy/core/code_generators/genapi.py b/numpy/core/code_generators/genapi.py index 544597786..b618dedf5 100644 --- a/numpy/core/code_generators/genapi.py +++ b/numpy/core/code_generators/genapi.py @@ -426,28 +426,32 @@ def merge_api_dicts(dicts): def check_api_dict(d): """Check that an api dict is valid (does not use the same index twice).""" + # remove the extra value fields that aren't the index + index_d = {k: v[0] for k, v in d.items()} + # We have if a same index is used twice: we 'revert' the dict so that index # become keys. If the length is different, it means one index has been used # at least twice - revert_dict = dict([(v, k) for k, v in d.items()]) - if not len(revert_dict) == len(d): + revert_dict = {v: k for k, v in index_d.items()} + if not len(revert_dict) == len(index_d): # We compute a dict index -> list of associated items doubled = {} - for name, index in d.items(): + for name, index in index_d.items(): try: doubled[index].append(name) except KeyError: doubled[index] = [name] - msg = """\ -Same index has been used twice in api definition: %s -""" % ['index %d -> %s' % (index, names) for index, names in doubled.items() \ - if len(names) != 1] - raise ValueError(msg) + fmt = "Same index has been used twice in api definition: {}" + val = ''.join( + '\n\tindex {} -> {}'.format(index, names) + for index, names in doubled.items() if len(names) != 1 + ) + raise ValueError(fmt.format(val)) # No 'hole' in the indexes may be allowed, and it must starts at 0 - indexes = set(v[0] for v in d.values()) + indexes = set(index_d.values()) expected = set(range(len(indexes))) - if not indexes == expected: + if indexes != expected: diff = expected.symmetric_difference(indexes) msg = "There are some holes in the API indexing: " \ "(symmetric diff is %s)" % diff |