aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/special_math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/special_math_ops.py')
-rw-r--r--tensorflow/python/ops/special_math_ops.py54
1 files changed, 26 insertions, 28 deletions
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index bf4d198209..5a8eb432d1 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -318,44 +318,28 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
# into a single axis, and combine multiple summed axes into a
# single axis.
- t0_shape = tuple(x.value for x in t0.get_shape())
+ t0_shape = _get_shape(t0)
num_broadcast_elements_t0 = _total_size(
t0_shape[len(preserved_axes):-len(axes_to_sum)])
num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
- new_shape = t0_shape[:len(preserved_axes)] + (num_broadcast_elements_t0,
- num_summed_elements)
+ new_shape = (t0_shape[:len(preserved_axes)]
+ + [num_broadcast_elements_t0, num_summed_elements])
t0 = _reshape_if_necessary(t0, new_shape)
- t1_shape = tuple(x.value for x in t1.get_shape())
+ t1_shape = _get_shape(t1)
num_broadcast_elements_t1 = _total_size(
t1_shape[len(preserved_axes)+len(axes_to_sum):])
- new_shape = t1_shape[:len(preserved_axes)] + (num_summed_elements,
- num_broadcast_elements_t1)
+ new_shape = (t1_shape[:len(preserved_axes)]
+ + [num_summed_elements, num_broadcast_elements_t1])
t1 = _reshape_if_necessary(t1, new_shape)
product = math_ops.matmul(t0, t1)
# Undo compaction of broadcast axes
uncompacted_shape = (
- t0_shape[:len(preserved_axes)+len(broadcast_axes[0])] +
- t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
+ t0_shape[:len(preserved_axes)+len(broadcast_axes[0])]
+ + t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
)
-
- # Check the number of None values and replace them with Tensors containing
- # corresponding dimensions if there exist two or more None values
- num_none_dims = sum(1 for d in uncompacted_shape if d is None)
- if num_none_dims > 1:
- uncompacted_shape = list(uncompacted_shape)
- for i in xrange(len(uncompacted_shape)):
- if uncompacted_shape[i] is None:
- if i < len(preserved_axes) + len(broadcast_axes[0]):
- uncompacted_shape[i] = array_ops.shape(inputs[0])[i]
- else:
- idx = (i - len(preserved_axes) - len(broadcast_axes[0])
- + len(t1_shape) - len(broadcast_axes[1]))
- uncompacted_shape[i] = array_ops.shape(inputs[1])[idx]
- uncompacted_shape = tuple(uncompacted_shape)
-
product = _reshape_if_necessary(product, uncompacted_shape)
product_axes = (
@@ -386,13 +370,27 @@ def _reshape_if_necessary(tensor, new_shape):
return array_ops.reshape(tensor, new_shape)
+def _get_shape(tensor):
+ """Like get_shape().as_list(), but explicitly queries the shape of a tensor
+ if necessary to ensure that the returned value contains no unknown value."""
+
+ shape = tensor.get_shape().as_list()
+ none_indices = [i for i, d in enumerate(shape) if d is None]
+ if none_indices:
+ # Query the shape if shape contains None values
+ shape_tensor = array_ops.shape(tensor)
+ for i in none_indices:
+ shape[i] = shape_tensor[i]
+ return shape
+
def _total_size(shape_values):
- """Given list of tensor shape values, returns total size or -1 if unknown."""
+ """Given list of tensor shape values, returns total size.
+ If shape_values contains tensor values (which are results of
+ array_ops.shape), then it returns a scalar tensor.
+ If not, it returns an integer."""
+
result = 1
for val in shape_values:
- if val is None:
- return -1
- assert isinstance(val, int)
result *= val
return result