aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/special_math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/special_math_ops.py')
-rw-r--r--tensorflow/python/ops/special_math_ops.py131
1 files changed, 69 insertions, 62 deletions
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index b561203bb4..87561cff92 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -82,7 +82,7 @@ def lbeta(x, name='lbeta'):
return result
-def einsum(equation, *inputs):
+def einsum(equation, *inputs, **kwargs):
"""A generalized contraction between tensors of arbitrary dimension.
This function returns a tensor whose elements are defined by `equation`,
@@ -138,6 +138,7 @@ def einsum(equation, *inputs):
`numpy.einsum`.
*inputs: the inputs to contract (each one a `Tensor`), whose shapes should
be consistent with `equation`.
+ name: A name for the operation (optional).
Returns:
The contracted `Tensor`, with shape determined by `equation`.
@@ -151,70 +152,76 @@ def einsum(equation, *inputs):
indices in its subscript, or
- the input shapes are inconsistent along a particular axis.
"""
- if '...' in equation:
- raise ValueError('Subscripts with ellipses are not yet supported.')
-
- match = re.match('([a-z,]+)(->[a-z]*)?', equation)
- if not match:
- raise ValueError(
- 'Indices have incorrect format: %s' % equation
- )
-
- inputs = list(inputs)
- input_axis_labels = match.group(1).split(',')
-
- if len(inputs) != len(input_axis_labels):
- raise ValueError('Got %d arguments for equation "%s", expecting %d' % (
- len(inputs), equation, len(input_axis_labels)))
+ name = kwargs.pop("name", None)
+ if kwargs:
+ raise TypeError("invalid keyword arguments for this function: " +
+ ", ".join([format(key)
+ for key in sorted(list(kwargs.keys()))]))
+ with ops.name_scope(name, "einsum", [equation, inputs]) as name:
+ if '...' in equation:
+ raise ValueError('Subscripts with ellipses are not yet supported.')
+
+ match = re.match('([a-z,]+)(->[a-z]*)?', equation)
+ if not match:
+ raise ValueError(
+ 'Indices have incorrect format: %s' % equation
+ )
- axis_labels = set(''.join(input_axis_labels))
- if match.group(2):
- output_axis_labels = match.group(2)[2:]
- else:
- # infer the output subscripts if not given, assume alphabetical order
- indices = ''.join(sorted(axis_labels))
- counts = {ax: 0 for ax in indices}
- for axes_ in input_axis_labels:
- for ax in axes_:
- counts[ax] += 1
+ inputs = list(inputs)
+ input_axis_labels = match.group(1).split(',')
- output_axis_labels = ''.join(sorted(
- ax for ax in indices
- if counts[ax] == 1
- ))
+ if len(inputs) != len(input_axis_labels):
+ raise ValueError('Got %d arguments for equation "%s", expecting %d' % (
+ len(inputs), equation, len(input_axis_labels)))
- for a in axis_labels:
- input_count = sum(1 for s in input_axis_labels if a in s)
- if input_count > 2 and a not in output_axis_labels:
- logging.warn(
- 'Falling back to exponential-space implementation of einsum() because'
- ' index "%s" is summed over more than two inputs.', a)
- return _exponential_space_einsum(equation, *inputs)
-
- temp = inputs[0]
- temp_axis_labels = input_axis_labels[0]
- for i in xrange(len(inputs)-1):
- axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1])
- - set(output_axis_labels))
- temp, temp_axis_labels = _einsum_reduction(temp,
- temp_axis_labels,
- inputs[i+1],
- input_axis_labels[i+1],
- axes_to_sum)
-
- missing_indices = set(temp_axis_labels) - set(output_axis_labels)
- if missing_indices:
- reduction_indices = [i for i, a in enumerate(temp_axis_labels)
- if a not in output_axis_labels]
- temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices)
- temp_axis_labels = ''.join(a for a in temp_axis_labels
- if a in output_axis_labels)
-
- if sorted(temp_axis_labels) != sorted(output_axis_labels):
- raise ValueError('Invalid equation: %s' % equation)
-
- perm = [temp_axis_labels.index(a) for a in output_axis_labels]
- return _transpose_if_necessary(temp, perm)
+ axis_labels = set(''.join(input_axis_labels))
+ if match.group(2):
+ output_axis_labels = match.group(2)[2:]
+ else:
+ # infer the output subscripts if not given, assume alphabetical order
+ indices = ''.join(sorted(axis_labels))
+ counts = {ax: 0 for ax in indices}
+ for axes_ in input_axis_labels:
+ for ax in axes_:
+ counts[ax] += 1
+
+ output_axis_labels = ''.join(sorted(
+ ax for ax in indices
+ if counts[ax] == 1
+ ))
+
+ for a in axis_labels:
+ input_count = sum(1 for s in input_axis_labels if a in s)
+ if input_count > 2 and a not in output_axis_labels:
+ logging.warn(
+ 'Falling back to exponential-space implementation of einsum() because'
+ ' index "%s" is summed over more than two inputs.', a)
+ return _exponential_space_einsum(equation, *inputs)
+
+ temp = inputs[0]
+ temp_axis_labels = input_axis_labels[0]
+ for i in xrange(len(inputs)-1):
+ axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1])
+ - set(output_axis_labels))
+ temp, temp_axis_labels = _einsum_reduction(temp,
+ temp_axis_labels,
+ inputs[i+1],
+ input_axis_labels[i+1],
+ axes_to_sum)
+
+ missing_indices = set(temp_axis_labels) - set(output_axis_labels)
+ if missing_indices:
+ reduction_indices = [i for i, a in enumerate(temp_axis_labels)
+ if a not in output_axis_labels]
+ temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices)
+ temp_axis_labels = ''.join(a for a in temp_axis_labels
+ if a in output_axis_labels)
+
+ if sorted(temp_axis_labels) != sorted(output_axis_labels):
+ raise ValueError('Invalid equation: %s' % equation)
+
+ perm = [temp_axis_labels.index(a) for a in output_axis_labels]
+ return _transpose_if_necessary(temp, perm)
def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):