summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/einsumfunc.py300
-rw-r--r--numpy/core/tests/test_einsum.py57
2 files changed, 265 insertions, 92 deletions
diff --git a/numpy/core/einsumfunc.py b/numpy/core/einsumfunc.py
index 2fac3caf3..163f125c2 100644
--- a/numpy/core/einsumfunc.py
+++ b/numpy/core/einsumfunc.py
@@ -4,6 +4,8 @@ Implementation of optimized einsum.
"""
from __future__ import division, absolute_import, print_function
+import itertools
+
from numpy.compat import basestring
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asarray, asanyarray, result_type, tensordot, dot
@@ -14,6 +16,44 @@ einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
einsum_symbols_set = set(einsum_symbols)
+def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
+ """
+ Computes the number of FLOPS in the contraction.
+
+ Parameters
+ ----------
+ idx_contraction : iterable
+ The indices involved in the contraction
+ inner : bool
+ Does this contraction require an inner product?
+ num_terms : int
+ The number of terms in a contraction
+ size_dictionary : dict
+ The size of each of the indices in idx_contraction
+
+ Returns
+ -------
+ flop_count : int
+ The total number of FLOPS required for the contraction.
+
+ Examples
+ --------
+
+ >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
+ 30
+
+ >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
+ 60
+
+ """
+
+ overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
+ op_factor = max(1, num_terms - 1)
+ if inner:
+ op_factor += 1
+
+ return overall_size * op_factor
+
def _compute_size_by_dict(indices, idx_dict):
"""
Computes the product of the elements in indices based on the dictionary
@@ -139,14 +179,9 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
iter_results = []
# Compute all unique pairs
- comb_iter = []
- for x in range(len(input_sets) - iteration):
- for y in range(x + 1, len(input_sets) - iteration):
- comb_iter.append((x, y))
-
for curr in full_results:
cost, positions, remaining = curr
- for con in comb_iter:
+ for con in itertools.combinations(range(len(input_sets) - iteration), 2):
# Find the contraction
cont = _find_contraction(con, remaining, output_set)
@@ -157,15 +192,10 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
if new_size > memory_limit:
continue
- # Find cost
- new_cost = _compute_size_by_dict(idx_contract, idx_dict)
- if idx_removed:
- new_cost *= 2
-
# Build (total_cost, positions, indices_remaining)
- new_cost += cost
+ total_cost = cost + _flop_count(idx_contract, idx_removed, len(con), idx_dict)
new_pos = positions + [con]
- iter_results.append((new_cost, new_pos, new_input_sets))
+ iter_results.append((total_cost, new_pos, new_input_sets))
# Update combinatorial list, if we did not find anything return best
# path + remaining contractions
@@ -183,6 +213,102 @@ def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
path = min(full_results, key=lambda x: x[0])[1]
return path
+def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
+ """Compute the cost (removed size + flops) and resultant indices for
+ performing the contraction specified by ``positions``.
+
+ Parameters
+ ----------
+ positions : tuple of int
+ The locations of the proposed tensors to contract.
+ input_sets : list of sets
+ The indices found on each tensor.
+ output_set : set
+ The output indices of the expression.
+ idx_dict : dict
+ Mapping of each index to its size.
+ memory_limit : int
+ The total allowed size for an intermediary tensor.
+ path_cost : int
+ The contraction cost so far.
+ naive_cost : int
+ The cost of the unoptimized expression.
+
+ Returns
+ -------
+ cost : (int, int)
+ A tuple containing the size of any indices removed, and the flop cost.
+ positions : tuple of int
+ The locations of the proposed tensors to contract.
+ new_input_sets : list of sets
+ The resulting new list of indices if this proposed contraction is performed.
+
+ """
+
+ # Find the contraction
+ contract = _find_contraction(positions, input_sets, output_set)
+ idx_result, new_input_sets, idx_removed, idx_contract = contract
+
+ # Sieve the results based on memory_limit
+ new_size = _compute_size_by_dict(idx_result, idx_dict)
+ if new_size > memory_limit:
+ return None
+
+ # Build sort tuple
+ old_sizes = (_compute_size_by_dict(input_sets[p], idx_dict) for p in positions)
+ removed_size = sum(old_sizes) - new_size
+
+ # NB: removed_size used to be just the size of any removed indices i.e.:
+ # helpers.compute_size_by_dict(idx_removed, idx_dict)
+ cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
+ sort = (-removed_size, cost)
+
+ # Sieve based on total cost as well
+ if (path_cost + cost) > naive_cost:
+ return None
+
+ # Add contraction to possible choices
+ return [sort, positions, new_input_sets]
+
+
+def _update_other_results(results, best):
+ """Update the positions and provisional input_sets of ``results`` based on
+ performing the contraction result ``best``. Remove any involving the tensors
+ contracted.
+
+ Parameters
+ ----------
+ results : list
+ List of contraction results produced by ``_parse_possible_contraction``.
+ best : list
+ The best contraction of ``results`` i.e. the one that will be performed.
+
+ Returns
+ -------
+ mod_results : list
+ The list of modified results, updated with the outcome of the ``best`` contraction.
+ """
+
+ best_con = best[1]
+ bx, by = best_con
+ mod_results = []
+
+ for cost, (x, y), con_sets in results:
+
+ # Ignore results involving tensors just contracted
+ if x in best_con or y in best_con:
+ continue
+
+ # Update the input_sets
+ del con_sets[by - int(by > x) - int(by > y)]
+ del con_sets[bx - int(bx > x) - int(bx > y)]
+ con_sets.insert(-1, best[2][-1])
+
+ # Update the position indices
+ mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
+ mod_results.append((cost, mod_con, con_sets))
+
+ return mod_results
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
"""
@@ -219,46 +345,68 @@ def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
[(0, 2), (0, 1)]
"""
+ # Handle trivial cases that leaked through
if len(input_sets) == 1:
return [(0,)]
+ elif len(input_sets) == 2:
+ return [(0, 1)]
+
+ # Build up a naive cost
+ contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
+ idx_result, new_input_sets, idx_removed, idx_contract = contract
+ naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)
+ # Initially iterate over all pairs
+ comb_iter = itertools.combinations(range(len(input_sets)), 2)
+ known_contractions = []
+
+ path_cost = 0
path = []
- for iteration in range(len(input_sets) - 1):
- iteration_results = []
- comb_iter = []
- # Compute all unique pairs
- for x in range(len(input_sets)):
- for y in range(x + 1, len(input_sets)):
- comb_iter.append((x, y))
+ for iteration in range(len(input_sets) - 1):
+ # Iterate over all pairs on first step, only previously found pairs on subsequent steps
for positions in comb_iter:
- # Find the contraction
- contract = _find_contraction(positions, input_sets, output_set)
- idx_result, new_input_sets, idx_removed, idx_contract = contract
-
- # Sieve the results based on memory_limit
- if _compute_size_by_dict(idx_result, idx_dict) > memory_limit:
+ # Always initially ignore outer products
+ if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
continue
- # Build sort tuple
- removed_size = _compute_size_by_dict(idx_removed, idx_dict)
- cost = _compute_size_by_dict(idx_contract, idx_dict)
- sort = (-removed_size, cost)
+ result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
+ naive_cost)
+ if result is not None:
+ known_contractions.append(result)
- # Add contraction to possible choices
- iteration_results.append([sort, positions, new_input_sets])
+ # If we do not have an inner contraction, rescan pairs including outer products
+ if len(known_contractions) == 0:
- # If we did not find a new contraction contract remaining
- if len(iteration_results) == 0:
- path.append(tuple(range(len(input_sets))))
- break
+ # Then check the outer products
+ for positions in itertools.combinations(range(len(input_sets)), 2):
+ result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
+ path_cost, naive_cost)
+ if result is not None:
+ known_contractions.append(result)
+
+ # If we still did not find any remaining contractions, default back to einsum like behavior
+ if len(known_contractions) == 0:
+ path.append(tuple(range(len(input_sets))))
+ break
# Sort based on first index
- best = min(iteration_results, key=lambda x: x[0])
- path.append(best[1])
+ best = min(known_contractions, key=lambda x: x[0])
+
+ # Now propagate as many unused contractions as possible to next iteration
+ known_contractions = _update_other_results(known_contractions, best)
+
+ # Next iteration only compute contractions with the new tensor
+ # All other contractions have been accounted for
input_sets = best[2]
+ new_tensor_pos = len(input_sets) - 1
+ comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))
+
+ # Update path and total cost
+ path.append(best[1])
+ path_cost += best[0][1]
return path
@@ -314,26 +462,27 @@ def _can_dot(inputs, result, idx_removed):
if len(inputs) != 2:
return False
- # Build a few temporaries
input_left, input_right = inputs
+
+ for c in set(input_left + input_right):
+ # can't deal with repeated indices on same input or more than 2 total
+ nl, nr = input_left.count(c), input_right.count(c)
+ if (nl > 1) or (nr > 1) or (nl + nr > 2):
+ return False
+
+ # can't do implicit summation or dimension collapse e.g.
+ # "ab,bc->c" (implicitly sum over 'a')
+ # "ab,ca->ca" (take diagonal of 'a')
+ if nl + nr - 1 == int(c in result):
+ return False
+
+ # Build a few temporaries
set_left = set(input_left)
set_right = set(input_right)
keep_left = set_left - idx_removed
keep_right = set_right - idx_removed
rs = len(idx_removed)
- # Indices must overlap between the two operands
- if not len(set_left & set_right):
- return False
-
- # We cannot have duplicate indices ("ijj, jk -> ik")
- if (len(set_left) != len(input_left)) or (len(set_right) != len(input_right)):
- return False
-
- # Cannot handle partial inner ("ij, ji -> i")
- if len(keep_left & keep_right):
- return False
-
# At this point we are a DOT, GEMV, or GEMM operation
# Handle inner products
@@ -698,6 +847,7 @@ def einsum_path(*operands, **kwargs):
# Get length of each unique dimension and ensure all dimensions are correct
dimension_dict = {}
+ broadcast_indices = [[] for x in range(len(input_list))]
for tnum, term in enumerate(input_list):
sh = operands[tnum].shape
if len(sh) != len(term):
@@ -706,6 +856,11 @@ def einsum_path(*operands, **kwargs):
% (input_subscripts[tnum], tnum))
for cnum, char in enumerate(term):
dim = sh[cnum]
+
+ # Build out broadcast indices
+ if dim == 1:
+ broadcast_indices[tnum].append(char)
+
if char in dimension_dict.keys():
# For broadcasting cases we always want the largest dim size
if dimension_dict[char] == 1:
@@ -717,6 +872,9 @@ def einsum_path(*operands, **kwargs):
else:
dimension_dict[char] = dim
+ # Convert broadcast inds to sets
+ broadcast_indices = [set(x) for x in broadcast_indices]
+
# Compute size of each input array plus the output array
size_list = []
for term in input_list + [output_subscript]:
@@ -730,20 +888,14 @@ def einsum_path(*operands, **kwargs):
# Compute naive cost
# This isn't quite right, need to look into exactly how einsum does this
- naive_cost = _compute_size_by_dict(indices, dimension_dict)
- indices_in_input = input_subscripts.replace(',', '')
- mult = max(len(input_list) - 1, 1)
- if (len(indices_in_input) - len(set(indices_in_input))):
- mult *= 2
- naive_cost *= mult
+ inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
+ naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)
# Compute the path
if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
# Nothing to be optimized, leave it to einsum
path = [tuple(range(len(input_list)))]
elif path_type == "greedy":
- # Maximum memory should be at most out_size for this algorithm
- memory_arg = min(memory_arg, max_size)
path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
elif path_type == "optimal":
path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
@@ -762,18 +914,24 @@ def einsum_path(*operands, **kwargs):
contract = _find_contraction(contract_inds, input_sets, output_set)
out_inds, input_sets, idx_removed, idx_contract = contract
- cost = _compute_size_by_dict(idx_contract, dimension_dict)
- if idx_removed:
- cost *= 2
+ cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
cost_list.append(cost)
scale_list.append(len(idx_contract))
size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
+ bcast = set()
tmp_inputs = []
for x in contract_inds:
tmp_inputs.append(input_list.pop(x))
+ bcast |= broadcast_indices.pop(x)
- do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
+ new_bcast_inds = bcast - idx_removed
+
+ # If we're broadcasting, nix blas
+ if not len(idx_removed & bcast):
+ do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
+ else:
+ do_blas = False
# Last contraction
if (cnum - len(path)) == -1:
@@ -783,6 +941,7 @@ def einsum_path(*operands, **kwargs):
idx_result = "".join([x[1] for x in sorted(sort_result)])
input_list.append(idx_result)
+ broadcast_indices.append(new_bcast_inds)
einsum_str = ",".join(tmp_inputs) + "->" + idx_result
contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
@@ -1200,25 +1359,14 @@ def einsum(*operands, **kwargs):
tmp_operands.append(operands.pop(x))
# Do we need to deal with the output?
- if specified_out and ((num + 1) == len(contraction_list)):
- handle_out = True
+ handle_out = specified_out and ((num + 1) == len(contraction_list))
- # Handle broadcasting vs BLAS cases
+ # Call tensordot if still possible
if blas:
# Checks have already been handled
input_str, results_index = einsum_str.split('->')
input_left, input_right = input_str.split(',')
- if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape:
- left_dims = {dim: size for dim, size in
- zip(input_left, tmp_operands[0].shape)}
- right_dims = {dim: size for dim, size in
- zip(input_right, tmp_operands[1].shape)}
- # If dims do not match we are broadcasting, BLAS off
- if any(left_dims[ind] != right_dims[ind] for ind in idx_rm):
- blas = False
- # Call tensordot if still possible
- if blas:
tensor_result = input_left + input_right
for s in idx_rm:
tensor_result = tensor_result.replace(s, "")
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index a72079218..8ce374a75 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -16,7 +16,7 @@ for size, char in zip(sizes, chars):
global_size_dict[char] = size
-class TestEinSum(object):
+class TestEinsum(object):
def test_einsum_errors(self):
for do_opt in [True, False]:
# Need enough arguments
@@ -614,7 +614,7 @@ class TestEinSum(object):
np.einsum(a, [0, 51], b, [51, 2], [0, 2], optimize=False)
assert_raises(ValueError, lambda: np.einsum(a, [0, 52], b, [52, 2], [0, 2], optimize=False))
assert_raises(ValueError, lambda: np.einsum(a, [-1, 5], b, [5, 2], [-1, 2], optimize=False))
-
+
def test_einsum_broadcast(self):
# Issue #2455 change in handling ellipsis
# remove the 'middle broadcast' error
@@ -735,19 +735,22 @@ class TestEinSum(object):
res = np.einsum('...ij,...jk->...ik', a, a, out=a)
assert res is a
- def optimize_compare(self, string):
+ def optimize_compare(self, subscripts, operands=None):
# Tests all paths of the optimization function against
# conventional einsum
- operands = [string]
- terms = string.split('->')[0].split(',')
- for term in terms:
- dims = [global_size_dict[x] for x in term]
- operands.append(np.random.rand(*dims))
-
- noopt = np.einsum(*operands, optimize=False)
- opt = np.einsum(*operands, optimize='greedy')
+ if operands is None:
+ args = [subscripts]
+ terms = subscripts.split('->')[0].split(',')
+ for term in terms:
+ dims = [global_size_dict[x] for x in term]
+ args.append(np.random.rand(*dims))
+ else:
+ args = [subscripts] + operands
+
+ noopt = np.einsum(*args, optimize=False)
+ opt = np.einsum(*args, optimize='greedy')
assert_almost_equal(opt, noopt)
- opt = np.einsum(*operands, optimize='optimal')
+ opt = np.einsum(*args, optimize='optimal')
assert_almost_equal(opt, noopt)
def test_hadamard_like_products(self):
@@ -833,8 +836,28 @@ class TestEinSum(object):
b = np.einsum('bbcdc->d', a)
assert_equal(b, [12])
+ def test_broadcasting_dot_cases(self):
+ # Ensures broadcasting cases are not mistaken for GEMM
-class TestEinSumPath(object):
+ a = np.random.rand(1, 5, 4)
+ b = np.random.rand(4, 6)
+ c = np.random.rand(5, 6)
+ d = np.random.rand(10)
+
+ self.optimize_compare('ijk,kl,jl', operands=[a, b, c])
+ self.optimize_compare('ijk,kl,jl,i->i', operands=[a, b, c, d])
+
+ e = np.random.rand(1, 1, 5, 4)
+ f = np.random.rand(7, 7)
+ self.optimize_compare('abjk,kl,jl', operands=[e, b, c])
+ self.optimize_compare('abjk,kl,jl,ab->ab', operands=[e, b, c, f])
+
+ # Edge case found in gh-11308
+ g = np.arange(64).reshape(2, 4, 8)
+ self.optimize_compare('obk,ijk->ioj', operands=[g, g])
+
+
+class TestEinsumPath(object):
def build_operands(self, string, size_dict=global_size_dict):
# Builds views based off initial operands
@@ -880,7 +903,7 @@ class TestEinSumPath(object):
long_test1 = self.build_operands('acdf,jbje,gihb,hfac,gfac,gifabc,hfac')
path, path_str = np.einsum_path(*long_test1, optimize='greedy')
self.assert_path_equal(path, ['einsum_path',
- (1, 4), (2, 4), (1, 4), (1, 3), (1, 2), (0, 1)])
+ (3, 6), (3, 4), (2, 4), (2, 3), (0, 2), (0, 1)])
path, path_str = np.einsum_path(*long_test1, optimize='optimal')
self.assert_path_equal(path, ['einsum_path',
@@ -889,10 +912,10 @@
# Long test 2
long_test2 = self.build_operands('chd,bde,agbc,hiad,bdi,cgh,agdb')
path, path_str = np.einsum_path(*long_test2, optimize='greedy')
self.assert_path_equal(path, ['einsum_path',
(3, 4), (0, 3), (3, 4), (1, 3), (1, 2), (0, 1)])
path, path_str = np.einsum_path(*long_test2, optimize='optimal')
self.assert_path_equal(path, ['einsum_path',
(0, 5), (1, 4), (3, 4), (1, 3), (1, 2), (0, 1)])
@@ -926,7 +951,7 @@ class TestEinSumPath(object):
# Edge test4
edge_test4 = self.build_operands('dcc,fce,ea,dbf->ab')
path, path_str = np.einsum_path(*edge_test4, optimize='greedy')
- self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)])
+ self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 1), (0, 1)])
path, path_str = np.einsum_path(*edge_test4, optimize='optimal')
self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 2), (0, 1)])
@@ -949,7 +974,7 @@ class TestEinSumPath(object):
self.assert_path_equal(path, ['einsum_path', (0, 1, 2, 3)])
path, path_str = np.einsum_path(*path_test, optimize=True)
- self.assert_path_equal(path, ['einsum_path', (0, 3), (0, 2), (0, 1)])
+ self.assert_path_equal(path, ['einsum_path', (1, 2), (0, 1), (0, 1)])
exp_path = ['einsum_path', (0, 2), (0, 2), (0, 1)]
path, path_str = np.einsum_path(*path_test, optimize=exp_path)