aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-15 10:52:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-15 11:04:22 -0800
commitb8e47b00507d2e9783821c13063e4d94c5cd9809 (patch)
tree2d8b7f0fb5be0b2159e8a9ccc36016f341d41850
parent0a5d4db87c5cba5731ad265ad94e7a3a93c57c46 (diff)
Add tensor contraction op 'tensordot' to opensource TensorFlow. Original implementation by Moritz Hardt (mrtz@google.com). Fixes Github issue #5231.
Change: 142161582
-rw-r--r--tensorflow/python/kernel_tests/BUILD8
-rw-r--r--tensorflow/python/kernel_tests/tensordot_op_test.py190
-rw-r--r--tensorflow/python/ops/math_ops.py176
3 files changed, 368 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index e0b9235e91..843904ce31 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2245,6 +2245,14 @@ cuda_py_test(
shard_count = 20,
)
# Registers tensordot_op_test.py as a medium-sized test target split into
# 20 shards; it depends on the top-level TensorFlow Python package.
+cuda_py_test(
+ name = "tensordot_op_test",
+ size = "medium",
+ srcs = ["tensordot_op_test.py"],
+ additional_deps = ["//tensorflow:tensorflow_py"],
+ shard_count = 20,
+)
+
sycl_py_test(
name = "basic_gpu_test",
size = "small",
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
new file mode 100644
index 0000000000..51dcf82431
--- /dev/null
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -0,0 +1,190 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.math_ops.tensordot."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+_MAXDIM = 5
+
+
+def _add_test(test, test_name, fn):
+ test_name = "_".join(["test", test_name])
+ if hasattr(test, test_name):
+ raise RuntimeError("Test %s defined more than once" % test_name)
+ setattr(test, test_name, fn)
+
+
class TensordotTest(tf.test.TestCase):
  """Argument-validation and shape-inference checks for `tf.tensordot`.

  The numerical comparison tests are produced by `_get_tensordot_tests` and
  attached to this class by the `__main__` block of this file.
  """

  def test_invalid_shape(self):
    """Contracting axes of different sizes fails statically and at run time."""
    lhs = [[1, 2], [3, 4]]
    rhs = [[1, 2], [3, 4], [5, 6]]
    lhs_axes, rhs_axes = [1], [0]
    # Invalid static shapes: the mismatch is reported while building the graph.
    with self.assertRaises(ValueError):
      tf.tensordot(lhs, rhs, (lhs_axes, rhs_axes))
    # Invalid dynamic shapes: nothing is known until values are fed, so the
    # failure only shows up when the graph is executed.
    with self.test_session() as sess:
      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                   "Matrix size-incompatible"):
        lhs_ph = tf.placeholder(tf.float32)
        rhs_ph = tf.placeholder(tf.float32)
        axes_ph = tf.placeholder(tf.int32)
        result = tf.tensordot(lhs_ph, rhs_ph, axes_ph)
        feeds = {lhs_ph: lhs, rhs_ph: rhs, axes_ph: (lhs_axes, rhs_axes)}
        _ = sess.run([result], feed_dict=feeds)

  def test_invalid_axes(self):
    """Malformed `axes` arguments are rejected statically and at run time."""
    lhs = [[1, 2], [3, 4]]
    rhs = [[1, 2], [3, 4]]
    # Invalid static axes.
    bad_static_axes = (-1, 0, [1], [[1]], [[1], [0, 1]])
    for axes_value in bad_static_axes:
      with self.assertRaises(ValueError):
        tf.tensordot(lhs, rhs, axes_value)

    # An axis index outside the operand's rank is reported as an IndexError.
    with self.assertRaises(IndexError):
      tf.tensordot(lhs, rhs, [[0], [7]])

    # Invalid dynamic axes: the graph is built once and fed bad values.
    lhs_ph = tf.placeholder(tf.float32)
    rhs_ph = tf.placeholder(tf.float32)
    axes_ph = tf.placeholder(tf.int32)
    result = tf.tensordot(lhs_ph, rhs_ph, axes_ph)
    # Note: We don't support scalar Tensor values for axes.
    bad_dynamic_axes = (1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]])
    for axes_value in bad_dynamic_axes:
      with self.test_session() as sess:
        with self.assertRaises(tf.errors.InvalidArgumentError):
          feeds = {lhs_ph: lhs, rhs_ph: rhs, axes_ph: axes_value}
          _ = sess.run([result], feed_dict=feeds)

  def test_no_partial_shape_inference(self):
    """The output shape is unknown unless both input shapes are fully known."""
    contraction = ([1], [0])

    # Neither operand has any shape information.
    lhs = tf.placeholder(tf.float32)
    rhs = tf.placeholder(tf.float32)
    result = tf.tensordot(lhs, rhs, contraction)
    self.assertEqual(result.get_shape().ndims, None)

    # First operand only partially defined.
    lhs.set_shape([None, 2])
    rhs.set_shape([2, 3])
    result = tf.tensordot(lhs, rhs, contraction)
    self.assertEqual(result.get_shape().ndims, None)

    # Second operand only partially defined (fresh placeholders).
    lhs = tf.placeholder(tf.float32)
    rhs = tf.placeholder(tf.float32)
    lhs.set_shape([2, 2])
    rhs.set_shape([2, None])
    result = tf.tensordot(lhs, rhs, contraction)
    self.assertEqual(result.get_shape().ndims, None)
+
+
+def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
+
+ # Select a random subset of size m from [0, 1, ..., n-1].
+ def _random_subset(m, n):
+ assert m <= n
+ return (np.random.permutation(n)[:m]).astype(np.int32)
+
+ def _generate_random_tensors_and_dims():
+ a_shape = np.random.random_integers(1, _MAXDIM, rank_a_)
+ b_shape = np.random.random_integers(1, _MAXDIM, rank_b_)
+ shared_shape = np.random.random_integers(1, _MAXDIM, num_dims_)
+ a_dims = _random_subset(num_dims_, rank_a_)
+ b_dims = _random_subset(num_dims_, rank_b_)
+ for i in range(num_dims_):
+ a_shape[a_dims[i]] = shared_shape[i]
+ b_shape[b_dims[i]] = shared_shape[i]
+ a = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(a_shape)).reshape(a_shape).astype(dtype_)
+ b = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(b_shape)).reshape(b_shape).astype(dtype_)
+ return a, b, a_dims, b_dims
+
+ def test_tensordot(self):
+ num_trials = min(30, num_dims_ * num_dims_)
+ if dtype_ == np.float16:
+ tol = 0.05
+ elif dtype_ == np.float32 or dtype_ == np.complex64:
+ tol = 1e-5
+ else:
+ tol = 1e-12
+ for _ in range(num_trials):
+ a_np, b_np, a_dims_np, b_dims_np = _generate_random_tensors_and_dims()
+ np_ans = np.tensordot(a_np, b_np, axes=(a_dims_np, b_dims_np))
+ with self.test_session(use_gpu=True) as sess:
+ if dynamic_shape_:
+ a = tf.placeholder(dtype_)
+ b = tf.placeholder(dtype_)
+ axes = tf.placeholder(tf.int32)
+ c = tf.tensordot(a, b, axes)
+ tf_ans = sess.run(
+ c, feed_dict={a: a_np,
+ b: b_np,
+ axes: (a_dims_np, b_dims_np)})
+ else:
+ tf_ans = tf.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval()
+ self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
+ self.assertAllEqual(tf_ans.shape, np_ans.shape)
+
+ def test_tensordot_scalar_axes(self):
+ if num_dims_ < 1:
+ self.skipTest("Not a test")
+ if dtype_ == np.float16:
+ tol = 0.05
+ elif dtype_ == np.float32 or dtype_ == np.complex64:
+ tol = 1e-5
+ else:
+ tol = 1e-12
+ shape = [5] * num_dims_
+ a_np = np.random.uniform(
+ low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
+ b_np = np.random.uniform(
+ low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
+ all_axes = [1]
+ if a_np.ndim > 1:
+ all_axes.append(a_np.ndim - 1)
+ for axes in all_axes:
+ np_ans = np.tensordot(a_np, b_np, axes=axes)
+ with self.test_session(use_gpu=True):
+ tf_ans = tf.tensordot(a_np, b_np, axes=axes).eval()
+ self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol)
+ self.assertAllEqual(tf_ans.shape, np_ans.shape)
+
+ return [test_tensordot, test_tensordot_scalar_axes]
+
+
if __name__ == "__main__":
  # Attach one generated test to TensordotTest for every combination of
  # dtype, operand ranks, number of contracted axes and static/dynamic
  # shapes, then hand control to the TensorFlow test runner.
  for dtype in (np.float16, np.float32, np.float64, np.complex64,
                np.complex128):
    for rank_a in (1, 2, 4, 5):
      for rank_b in (1, 2, 4, 5):
        # No more axes can be contracted than the lower-rank operand has.
        max_contracted = min(rank_a, rank_b)
        for num_dims in range(max_contracted + 1):
          for dynamic_shape in (False, True):
            generated = _get_tensordot_tests(dtype, rank_a, rank_b, num_dims,
                                             dynamic_shape)
            for testcase in generated:
              name_parts = (testcase.__name__, dtype.__name__, rank_a, rank_b,
                            num_dims, dynamic_shape)
              name = "%s_%s_%s_%s_%s_%s" % name_parts
              _add_test(TensordotTest, name, testcase)
  tf.test.main()
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index e4880278de..4dc027b58b 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -109,6 +109,15 @@ functions on matrices to your graph.
@@self_adjoint_eigvals
@@svd
+
+## Tensor Math Function
+
+TensorFlow provides operations that you can use to add tensor functions to your
+graph.
+
+@@tensordot
+
+
## Complex Number Functions
TensorFlow provides several operations that you can use to add complex number
@@ -216,6 +225,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
@@ -235,6 +245,8 @@ from tensorflow.python.ops import state_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.util import compat
+
# Aliases for some automatically-generated names.
linspace = gen_math_ops.lin_space
@@ -250,8 +262,8 @@ def argmax(input, axis=None, name=None, dimension=None):
return gen_math_ops.arg_max(input, axis, name)
-argmax.__doc__ = (gen_math_ops.arg_max.__doc__
- .replace("dimensions", "axes").replace("dimension", "axis"))
+argmax.__doc__ = (gen_math_ops.arg_max.__doc__.replace(
+ "dimensions", "axes").replace("dimension", "axis"))
# TODO(aselle:deprecate arg_min)
@@ -263,8 +275,8 @@ def argmin(input, axis=None, name=None, dimension=None):
return gen_math_ops.arg_min(input, axis, name)
-argmin.__doc__ = (gen_math_ops.arg_min.__doc__
- .replace("dimensions", "axes").replace("dimension", "axis"))
+argmin.__doc__ = (gen_math_ops.arg_min.__doc__.replace(
+ "dimensions", "axes").replace("dimension", "axis"))
# pylint: enable=redefined-builtin
@@ -510,8 +522,9 @@ def pow(x, y, name=None):
return gen_math_ops._pow(x, y, name=name)
+# pylint: disable=redefined-outer-name,redefined-builtin
def complex(real, imag, name=None):
- """Converts two real numbers to a complex number.
+ r"""Converts two real numbers to a complex number.
Given a tensor `real` representing the real part of a complex number, and a
tensor `imag` representing the imaginary part of a complex number, this
@@ -551,7 +564,7 @@ def complex(real, imag, name=None):
def real(input, name=None):
- """Returns the real part of a complex number.
+ r"""Returns the real part of a complex number.
Given a tensor `input` of complex numbers, this operation returns a tensor of
type `float32` or `float64` that is the real part of each element in `input`.
@@ -610,6 +623,9 @@ def imag(input, name=None):
return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
+# pylint: enable=redefined-outer-name,redefined-builtin
+
+
def round(x, name=None):
"""Rounds the values of a tensor to the nearest integer, element-wise.
@@ -1042,6 +1058,7 @@ def _mul_dispatch(x, y, name=None):
y.dense_shape, x, name)
return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
+
# NOTE(aselle): When integer division is added for sparse_dense_cwise,
# div, truediv, and floordiv should be delegated appropriately for
# Python semantics, analogous to dense cwise tensor operations.
@@ -2165,3 +2182,150 @@ def reduced_shape(input_shape, axes):
input_shape, # [2, 3, 5, 7]
array_ops.fill(axes_shape, 1)
]) # [1, 1]
+
+
def tensordot(a, b, axes, name=None):
  r"""Tensor contraction of a and b along specified axes.

  Tensordot (also known as tensor contraction) sums the product of elements
  from `a` and `b` over the indices specified by `a_axes` and `b_axes`.
  The lists `a_axes` and `b_axes` specify those pairs of axes along which to
  contract the tensors. The axis `a_axes[i]` of `a` must have the same dimension
  as axis `b_axes[i]` of `b` for all `i` in `range(0, len(a_axes))`. The lists
  `a_axes` and `b_axes` must have identical length and consist of unique
  integers that specify valid axes for each of the tensors.

  This operation corresponds to `numpy.tensordot(a, b, axes)`.

  Example 1: When `a` and `b` are matrices (order 2), the case `axes = 1`
  is equivalent to matrix multiplication.

  Example 2: When `a` and `b` are matrices (order 2), the case
  `axes = [[1], [0]]` is equivalent to matrix multiplication.

  Example 3: Suppose that \\(a_{ijk}\\) and \\(b_{lmn}\\) represent two
  tensors of order 3. Then, `tensordot(a, b, [[0], [2]])` is the order 4 tensor
  \\(c_{jklm}\\) whose entry
  corresponding to the indices \\((j,k,l,m)\\) is given by:

  \\( c_{jklm} = \sum_i a_{ijk} b_{lmi} \\).

  In general, `order(c) = order(a) + order(b) - 2*len(axes[0])`.

  Args:
    a: `Tensor` of type `float32` or `float64`.
    b: `Tensor` with the same type as `a`.
    axes: Either a scalar `N`, or a list or an `int32` `Tensor` of shape [2, k].
     If axes is a scalar, sum over the last N axes of a and the first N axes
     of b in order.
     If axes is a list or `Tensor` the first and second row contain the set of
     unique integers specifying axes along which the contraction is computed,
     for `a` and `b`, respectively. The number of axes for `a` and `b` must
     be equal.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `a`.

  Raises:
    ValueError: If the shapes of `a`, `b`, and `axes` are incompatible.
    IndexError: If the values in axes exceed the rank of the corresponding
      tensor.
  """

  def _tensordot_reshape(a, axes, flipped=False):
    """Helper method to perform transpose and reshape for contraction op.

    This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul`
    using `array_ops.transpose` and `array_ops.reshape`. The method takes a
    tensor and performs the correct transpose and reshape operation for a given
    set of indices. It returns the reshaped tensor as well as a list of indices
    necessary to reshape the tensor again after matrix multiplication.

    Args:
      a: `Tensor`.
      axes: List or `int32` `Tensor` of unique indices specifying valid axes of
       `a`.
      flipped: An optional `bool`. Defaults to `False`. If `True`, the method
        assumes that `a` is the second argument in the contraction operation.

    Returns:
      A pair `(reshaped_a, free_dims)` where `reshaped_a` is the tensor `a`
      reshaped to allow contraction via `matmul` and `free_dims` is either a
      list of integers or an `int32` `Tensor`, depending on if `axes` is a list
      and the shape of `a` is fully defined.
    """
    # TODO(b/33084409): Implement partial shape inference.
    if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)):
      # Everything is known statically: compute permutation and target shape
      # in Python so that the result carries a fully defined static shape.
      shape_a = a.get_shape().as_list()
      # Map negative axis indices onto their non-negative equivalents.
      axes = [i if i >= 0 else i + len(shape_a) for i in axes]
      free = [i for i in xrange(len(shape_a)) if i not in axes]
      free_dims = [shape_a[i] for i in free]
      prod_free = int(np.prod(free_dims))
      prod_axes = int(np.prod([shape_a[i] for i in axes]))
      # The first operand is laid out as [free, contracted]; the second
      # (flipped) operand as [contracted, free], so that matmul contracts.
      perm = list(axes) + free if flipped else free + list(axes)
      new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
      reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
      return reshaped_a, free_dims
    else:
      # Shape and/or axes are only known at run time: build the same
      # computation out of graph ops.
      shape_a = array_ops.shape(a)
      rank_a = array_ops.rank(a)
      axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
      # Branch-free normalization of negative axis indices.
      axes = cast(axes >= 0, dtypes.int32) * axes + cast(
          axes < 0, dtypes.int32) * (axes + rank_a)
      # `range` is this module's TensorFlow op, not the Python builtin.
      free, _ = array_ops.setdiff1d(range(rank_a), axes)
      free_dims = array_ops.gather(shape_a, free)
      axes_dims = array_ops.gather(shape_a, axes)
      prod_free_dims = reduce_prod(free_dims)
      prod_axes_dims = reduce_prod(axes_dims)
      if flipped:
        perm = array_ops.concat(0, [axes, free])
        new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
      else:
        perm = array_ops.concat(0, [free, axes])
        new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
      reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
      return reshaped_a, free_dims

  def _tensordot_axes(a, axes):
    """Generates two sets of contraction axes for the two tensor arguments."""
    a_shape = a.get_shape()
    if isinstance(axes, compat.integral_types):
      if axes < 1:
        raise ValueError("'axes' must be at least 1.")
      if a_shape.ndims is not None:
        # Return plain Python lists (not Tensors) so that
        # `_tensordot_reshape` can take its static-shape path.
        return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
                list(xrange(axes)))
      else:
        # Rank of `a` is unknown until run time: last `axes` axes of `a`
        # against the first `axes` axes of `b`.
        rank = array_ops.rank(a)
        return (range(rank - axes, rank, dtype=dtypes.int32),
                range(axes, dtype=dtypes.int32))
    elif isinstance(axes, (list, tuple)):
      if len(axes) != 2:
        raise ValueError("'axes' must be an integer or have length 2.")
      a_axes = axes[0]
      b_axes = axes[1]
      if len(a_axes) != len(b_axes):
        raise ValueError(
            "Different number of contraction axes 'a' and 'b', %s != %s." %
            (len(a_axes), len(b_axes)))
      return a_axes, b_axes
    else:
      axes = ops.convert_to_tensor(axes, name="axes", dtype=dtypes.int32)
      return axes[0], axes[1]

  with ops.name_scope(name, "Tensordot", [a, b, axes]) as name:
    a = ops.convert_to_tensor(a, name="a")
    b = ops.convert_to_tensor(b, name="b")
    a_axes, b_axes = _tensordot_axes(a, axes)
    a_reshape, a_free_dims = _tensordot_reshape(a, a_axes)
    b_reshape, b_free_dims = _tensordot_reshape(b, b_axes, True)
    ab_matmul = matmul(a_reshape, b_reshape)
    if isinstance(a_free_dims, list) and isinstance(b_free_dims, list):
      # Both operands were handled statically: reshape to a Python shape.
      return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name)
    else:
      a_free_dims = ops.convert_to_tensor(a_free_dims)
      b_free_dims = ops.convert_to_tensor(b_free_dims)
      return array_ops.reshape(
          ab_matmul, array_ops.concat(0, [a_free_dims, b_free_dims]), name=name)