diff options
Diffstat (limited to 'tensorflow/python/ops/special_math_ops.py')
-rw-r--r-- | tensorflow/python/ops/special_math_ops.py | 131 |
1 files changed, 69 insertions, 62 deletions
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index b561203bb4..87561cff92 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -82,7 +82,7 @@ def lbeta(x, name='lbeta'): return result -def einsum(equation, *inputs): +def einsum(equation, *inputs, **kwargs): """A generalized contraction between tensors of arbitrary dimension. This function returns a tensor whose elements are defined by `equation`, @@ -138,6 +138,7 @@ def einsum(equation, *inputs): `numpy.einsum`. *inputs: the inputs to contract (each one a `Tensor`), whose shapes should be consistent with `equation`. + name: A name for the operation (optional). Returns: The contracted `Tensor`, with shape determined by `equation`. @@ -151,70 +152,76 @@ def einsum(equation, *inputs): indices in its subscript, or - the input shapes are inconsistent along a particular axis. """ - if '...' in equation: - raise ValueError('Subscripts with ellipses are not yet supported.') - - match = re.match('([a-z,]+)(->[a-z]*)?', equation) - if not match: - raise ValueError( - 'Indices have incorrect format: %s' % equation - ) - - inputs = list(inputs) - input_axis_labels = match.group(1).split(',') - - if len(inputs) != len(input_axis_labels): - raise ValueError('Got %d arguments for equation "%s", expecting %d' % ( - len(inputs), equation, len(input_axis_labels))) + name = kwargs.pop("name", None) + if kwargs: + raise TypeError("invalid keyword arguments for this function: " + + ", ".join([format(key) + for key in sorted(list(kwargs.keys()))])) + with ops.name_scope(name, "einsum", [equation, inputs]) as name: + if '...' in equation: + raise ValueError('Subscripts with ellipses are not yet supported.') + + match = re.match('([a-z,]+)(->[a-z]*)?', equation) + if not match: + raise ValueError( + 'Indices have incorrect format: %s' % equation + ) - axis_labels = set(''.join(input_axis_labels)) - if match.group(2): - output_axis_labels = match.group(2)[2:] - else: - # infer the output subscripts if not given, assume alphabetical order - indices = ''.join(sorted(axis_labels)) - counts = {ax: 0 for ax in indices} - for axes_ in input_axis_labels: - for ax in axes_: - counts[ax] += 1 + inputs = list(inputs) + input_axis_labels = match.group(1).split(',') - output_axis_labels = ''.join(sorted( - ax for ax in indices - if counts[ax] == 1 - )) + if len(inputs) != len(input_axis_labels): + raise ValueError('Got %d arguments for equation "%s", expecting %d' % ( + len(inputs), equation, len(input_axis_labels))) - for a in axis_labels: - input_count = sum(1 for s in input_axis_labels if a in s) - if input_count > 2 and a not in output_axis_labels: - logging.warn( - 'Falling back to exponential-space implementation of einsum() because' - ' index "%s" is summed over more than two inputs.', a) - return _exponential_space_einsum(equation, *inputs) - - temp = inputs[0] - temp_axis_labels = input_axis_labels[0] - for i in xrange(len(inputs)-1): - axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1]) - - set(output_axis_labels)) - temp, temp_axis_labels = _einsum_reduction(temp, - temp_axis_labels, - inputs[i+1], - input_axis_labels[i+1], - axes_to_sum) - - missing_indices = set(temp_axis_labels) - set(output_axis_labels) - if missing_indices: - reduction_indices = [i for i, a in enumerate(temp_axis_labels) - if a not in output_axis_labels] - temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices) - temp_axis_labels = ''.join(a for a in temp_axis_labels - if a in output_axis_labels) - - if sorted(temp_axis_labels) != sorted(output_axis_labels): - raise ValueError('Invalid equation: %s' % equation) - - perm = [temp_axis_labels.index(a) for a in output_axis_labels] - return _transpose_if_necessary(temp, perm) + axis_labels = set(''.join(input_axis_labels)) + if match.group(2): + output_axis_labels = match.group(2)[2:] + else: + # infer the output subscripts if not given, assume alphabetical order + indices = ''.join(sorted(axis_labels)) + counts = {ax: 0 for ax in indices} + for axes_ in input_axis_labels: + for ax in axes_: + counts[ax] += 1 + + output_axis_labels = ''.join(sorted( + ax for ax in indices + if counts[ax] == 1 + )) + + for a in axis_labels: + input_count = sum(1 for s in input_axis_labels if a in s) + if input_count > 2 and a not in output_axis_labels: + logging.warn( + 'Falling back to exponential-space implementation of einsum() because' + ' index "%s" is summed over more than two inputs.', a) + return _exponential_space_einsum(equation, *inputs) + + temp = inputs[0] + temp_axis_labels = input_axis_labels[0] + for i in xrange(len(inputs)-1): + axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1]) + - set(output_axis_labels)) + temp, temp_axis_labels = _einsum_reduction(temp, + temp_axis_labels, + inputs[i+1], + input_axis_labels[i+1], + axes_to_sum) + + missing_indices = set(temp_axis_labels) - set(output_axis_labels) + if missing_indices: + reduction_indices = [i for i, a in enumerate(temp_axis_labels) + if a not in output_axis_labels] + temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices) + temp_axis_labels = ''.join(a for a in temp_axis_labels + if a in output_axis_labels) + + if sorted(temp_axis_labels) != sorted(output_axis_labels): + raise ValueError('Invalid equation: %s' % equation) + + perm = [temp_axis_labels.index(a) for a in output_axis_labels] + return _transpose_if_necessary(temp, perm) def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum): |