aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r--tensorflow/python/ops/math_ops.py36
1 files changed, 24 insertions, 12 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index d3d954f33d..fe4a47b9ae 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2298,12 +2298,14 @@ def tensordot(a, b, axes, name=None):
assumes that `a` is the second argument in the contraction operation.
Returns:
- A pair `(reshaped_a, free_dims)` where `reshaped_a` is the tensor `a`
- reshaped to allow contraction via `matmul` and `free_dims` is either a
- list of integers or an `int32` `Tensor`, depending on if `axes` is a list
- and the shape of `a` is fully defined.
+ A tuple `(reshaped_a, free_dims, free_dims_static)` where `reshaped_a` is
+ the tensor `a` reshaped to allow contraction via `matmul`, `free_dims` is
+ either a list of integers or an `int32` `Tensor`, depending on whether
+ the shape of a is fully specified, and free_dims_static is either a list
+ of integers and None values, or None, representing the inferred
+ static shape of the free dimensions
+
"""
- # TODO(b/33084409): Implement partial shape inference.
if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)):
shape_a = a.get_shape().as_list()
axes = [i if i >= 0 else i + len(shape_a) for i in axes]
@@ -2314,8 +2316,15 @@ def tensordot(a, b, axes, name=None):
perm = list(axes) + free if flipped else free + list(axes)
new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
- return reshaped_a, free_dims
+ return reshaped_a, free_dims, free_dims
else:
+ if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)):
+ shape_a = a.get_shape().as_list()
+ axes = [i if i >= 0 else i + len(shape_a) for i in axes]
+ free = [i for i in xrange(len(shape_a)) if i not in axes]
+ free_dims_static = [shape_a[i] for i in free]
+ else:
+ free_dims_static = None
shape_a = array_ops.shape(a)
rank_a = array_ops.rank(a)
axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
@@ -2334,7 +2343,7 @@ def tensordot(a, b, axes, name=None):
perm = array_ops.concat([free, axes], 0)
new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
- return reshaped_a, free_dims
+ return reshaped_a, free_dims, free_dims_static
def _tensordot_axes(a, axes):
"""Generates two sets of contraction axes for the two tensor arguments."""
@@ -2366,16 +2375,19 @@ def tensordot(a, b, axes, name=None):
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
a_axes, b_axes = _tensordot_axes(a, axes)
- a_reshape, a_free_dims = _tensordot_reshape(a, a_axes)
- b_reshape, b_free_dims = _tensordot_reshape(b, b_axes, True)
+ a_reshape, a_free_dims, a_free_dims_static = _tensordot_reshape(a, a_axes)
+ b_reshape, b_free_dims, b_free_dims_static = _tensordot_reshape(b, b_axes, True)
ab_matmul = matmul(a_reshape, b_reshape)
if isinstance(a_free_dims, list) and isinstance(b_free_dims, list):
return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name)
else:
- a_free_dims = ops.convert_to_tensor(a_free_dims)
- b_free_dims = ops.convert_to_tensor(b_free_dims)
- return array_ops.reshape(
+ a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32)
+ b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32)
+ product = array_ops.reshape(
ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)
+ if a_free_dims_static is not None and b_free_dims_static is not None:
+ product.set_shape(a_free_dims_static + b_free_dims_static)
+ return product
# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow