diff options
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index d3d954f33d..fe4a47b9ae 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2298,12 +2298,14 @@ def tensordot(a, b, axes, name=None): assumes that `a` is the second argument in the contraction operation. Returns: - A pair `(reshaped_a, free_dims)` where `reshaped_a` is the tensor `a` - reshaped to allow contraction via `matmul` and `free_dims` is either a - list of integers or an `int32` `Tensor`, depending on if `axes` is a list - and the shape of `a` is fully defined. + A tuple `(reshaped_a, free_dims, free_dims_static)` where `reshaped_a` is + the tensor `a` reshaped to allow contraction via `matmul`, `free_dims` is + either a list of integers or an `int32` `Tensor`, depending on whether + the shape of a is fully specified, and free_dims_static is either a list + of integers and None values, or None, representing the inferred + static shape of the free dimensions + """ - # TODO(b/33084409): Implement partial shape inference. if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)): shape_a = a.get_shape().as_list() axes = [i if i >= 0 else i + len(shape_a) for i in axes] @@ -2314,8 +2316,15 @@ def tensordot(a, b, axes, name=None): perm = list(axes) + free if flipped else free + list(axes) new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes] reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape) - return reshaped_a, free_dims + return reshaped_a, free_dims, free_dims else: + if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)): + shape_a = a.get_shape().as_list() + axes = [i if i >= 0 else i + len(shape_a) for i in axes] + free = [i for i in xrange(len(shape_a)) if i not in axes] + free_dims_static = [shape_a[i] for i in free] + else: + free_dims_static = None shape_a = array_ops.shape(a) rank_a = array_ops.rank(a) axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes") @@ -2334,7 +2343,7 @@ def tensordot(a, b, axes, name=None): perm = array_ops.concat([free, axes], 0) new_shape = array_ops.stack([prod_free_dims, prod_axes_dims]) reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape) - return reshaped_a, free_dims + return reshaped_a, free_dims, free_dims_static def _tensordot_axes(a, axes): """Generates two sets of contraction axes for the two tensor arguments.""" @@ -2366,16 +2375,19 @@ def tensordot(a, b, axes, name=None): a = ops.convert_to_tensor(a, name="a") b = ops.convert_to_tensor(b, name="b") a_axes, b_axes = _tensordot_axes(a, axes) - a_reshape, a_free_dims = _tensordot_reshape(a, a_axes) - b_reshape, b_free_dims = _tensordot_reshape(b, b_axes, True) + a_reshape, a_free_dims, a_free_dims_static = _tensordot_reshape(a, a_axes) + b_reshape, b_free_dims, b_free_dims_static = _tensordot_reshape(b, b_axes, True) ab_matmul = matmul(a_reshape, b_reshape) if isinstance(a_free_dims, list) and isinstance(b_free_dims, list): return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name) else: - a_free_dims = ops.convert_to_tensor(a_free_dims) - b_free_dims = ops.convert_to_tensor(b_free_dims) - return array_ops.reshape( + a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32) + b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32) + product = array_ops.reshape( ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name) + if a_free_dims_static is not None and b_free_dims_static is not None: + product.set_shape(a_free_dims_static + b_free_dims_static) + return product # FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow |