diff options
author | 2016-12-15 10:52:57 -0800 | |
---|---|---|
committer | 2016-12-15 11:04:22 -0800 | |
commit | b8e47b00507d2e9783821c13063e4d94c5cd9809 (patch) | |
tree | 2d8b7f0fb5be0b2159e8a9ccc36016f341d41850 | |
parent | 0a5d4db87c5cba5731ad265ad94e7a3a93c57c46 (diff) |
Add tensor contraction op 'tensordot' to open-source TensorFlow. Original implementation by Moritz Hardt (mrtz@google.com). Fixes GitHub issue #5231.
Change: 142161582
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 8 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/tensordot_op_test.py | 190 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 176 |
3 files changed, 368 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index e0b9235e91..843904ce31 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2245,6 +2245,14 @@ cuda_py_test( shard_count = 20, ) +cuda_py_test( + name = "tensordot_op_test", + size = "medium", + srcs = ["tensordot_op_test.py"], + additional_deps = ["//tensorflow:tensorflow_py"], + shard_count = 20, +) + sycl_py_test( name = "basic_gpu_test", size = "small", diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py new file mode 100644 index 0000000000..51dcf82431 --- /dev/null +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -0,0 +1,190 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.math_ops.matmul.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +_MAXDIM = 5 + + +def _add_test(test, test_name, fn): + test_name = "_".join(["test", test_name]) + if hasattr(test, test_name): + raise RuntimeError("Test %s defined more than once" % test_name) + setattr(test, test_name, fn) + + +class TensordotTest(tf.test.TestCase): + + def test_invalid_shape(self): + a = [[1, 2], [3, 4]] + b = [[1, 2], [3, 4], [5, 6]] + a_axes = [1] + b_axes = [0] + # Invalid static shapes. + with self.assertRaises(ValueError): + tf.tensordot(a, b, (a_axes, b_axes)) + # Invalid dynamic shapes. + with self.test_session() as sess: + with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, + "Matrix size-incompatible"): + a_ph = tf.placeholder(tf.float32) + b_ph = tf.placeholder(tf.float32) + axes_ph = tf.placeholder(tf.int32) + output = tf.tensordot(a_ph, b_ph, axes_ph) + _ = sess.run([output], + feed_dict={a_ph: a, + b_ph: b, + axes_ph: (a_axes, b_axes)}) + + def test_invalid_axes(self): + a = [[1, 2], [3, 4]] + b = [[1, 2], [3, 4]] + # Invalid static axes. + for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]: + with self.assertRaises(ValueError): + tf.tensordot(a, b, axes_value) + + with self.assertRaises(IndexError): + tf.tensordot(a, b, [[0], [7]]) + + # Invalid dynamic axes. + a_ph = tf.placeholder(tf.float32) + b_ph = tf.placeholder(tf.float32) + axes_ph = tf.placeholder(tf.int32) + output = tf.tensordot(a_ph, b_ph, axes_ph) + # Note: We don't support scalar Tensor values for axes. 
+ for axes_value in 1, [1], [0,1], [[1]], [[0,1]], [[0], [7]]: + with self.test_session() as sess: + with self.assertRaises(tf.errors.InvalidArgumentError): + _ = sess.run([output], + feed_dict={a_ph: a, + b_ph: b, + axes_ph: axes_value}) + + def test_no_partial_shape_inference(self): + # If one of the shapes is only partially defined, the output shape is + # unknown. + a = tf.placeholder(tf.float32) + b = tf.placeholder(tf.float32) + axes = ([1], [0]) + output = tf.tensordot(a, b, axes) + self.assertEqual(output.get_shape().ndims, None) + a.set_shape([None, 2]) + b.set_shape([2, 3]) + output = tf.tensordot(a, b, axes) + self.assertEqual(output.get_shape().ndims, None) + a = tf.placeholder(tf.float32) + b = tf.placeholder(tf.float32) + a.set_shape([2, 2]) + b.set_shape([2, None]) + output = tf.tensordot(a, b, axes) + self.assertEqual(output.get_shape().ndims, None) + + +def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): + + # Select a random subset of size m from [0, 1, ..., n-1]. 
+ def _random_subset(m, n): + assert m <= n + return (np.random.permutation(n)[:m]).astype(np.int32) + + def _generate_random_tensors_and_dims(): + a_shape = np.random.random_integers(1, _MAXDIM, rank_a_) + b_shape = np.random.random_integers(1, _MAXDIM, rank_b_) + shared_shape = np.random.random_integers(1, _MAXDIM, num_dims_) + a_dims = _random_subset(num_dims_, rank_a_) + b_dims = _random_subset(num_dims_, rank_b_) + for i in range(num_dims_): + a_shape[a_dims[i]] = shared_shape[i] + b_shape[b_dims[i]] = shared_shape[i] + a = np.random.uniform( + low=-1.0, high=1.0, + size=np.prod(a_shape)).reshape(a_shape).astype(dtype_) + b = np.random.uniform( + low=-1.0, high=1.0, + size=np.prod(b_shape)).reshape(b_shape).astype(dtype_) + return a, b, a_dims, b_dims + + def test_tensordot(self): + num_trials = min(30, num_dims_ * num_dims_) + if dtype_ == np.float16: + tol = 0.05 + elif dtype_ == np.float32 or dtype_ == np.complex64: + tol = 1e-5 + else: + tol = 1e-12 + for _ in range(num_trials): + a_np, b_np, a_dims_np, b_dims_np = _generate_random_tensors_and_dims() + np_ans = np.tensordot(a_np, b_np, axes=(a_dims_np, b_dims_np)) + with self.test_session(use_gpu=True) as sess: + if dynamic_shape_: + a = tf.placeholder(dtype_) + b = tf.placeholder(dtype_) + axes = tf.placeholder(tf.int32) + c = tf.tensordot(a, b, axes) + tf_ans = sess.run( + c, feed_dict={a: a_np, + b: b_np, + axes: (a_dims_np, b_dims_np)}) + else: + tf_ans = tf.tensordot(a_np, b_np, (a_dims_np, b_dims_np)).eval() + self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol) + self.assertAllEqual(tf_ans.shape, np_ans.shape) + + def test_tensordot_scalar_axes(self): + if num_dims_ < 1: + self.skipTest("Not a test") + if dtype_ == np.float16: + tol = 0.05 + elif dtype_ == np.float32 or dtype_ == np.complex64: + tol = 1e-5 + else: + tol = 1e-12 + shape = [5] * num_dims_ + a_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) + b_np = np.random.uniform( + low=-1.0, 
high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) + all_axes = [1] + if a_np.ndim > 1: + all_axes.append(a_np.ndim - 1) + for axes in all_axes: + np_ans = np.tensordot(a_np, b_np, axes=axes) + with self.test_session(use_gpu=True): + tf_ans = tf.tensordot(a_np, b_np, axes=axes).eval() + self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol) + self.assertAllEqual(tf_ans.shape, np_ans.shape) + + return [test_tensordot, test_tensordot_scalar_axes] + + +if __name__ == "__main__": + for dtype in np.float16, np.float32, np.float64, np.complex64, np.complex128: + for rank_a in 1, 2, 4, 5: + for rank_b in 1, 2, 4, 5: + for num_dims in range(0, min(rank_a, rank_b) + 1): + for dynamic_shape in False, True: + for testcase in _get_tensordot_tests(dtype, rank_a, rank_b, + num_dims, dynamic_shape): + name = "%s_%s_%s_%s_%s_%s" % (testcase.__name__, dtype.__name__, + rank_a, rank_b, num_dims, + dynamic_shape) + _add_test(TensordotTest, name, testcase) + tf.test.main() diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index e4880278de..4dc027b58b 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -109,6 +109,15 @@ functions on matrices to your graph. @@self_adjoint_eigvals @@svd + +## Tensor Math Function + +TensorFlow provides operations that you can use to add tensor functions to your +graph. 
+ +@@tensordot + + ## Complex Number Functions TensorFlow provides several operations that you can use to add complex number @@ -216,6 +225,7 @@ from __future__ import division from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op @@ -235,6 +245,8 @@ from tensorflow.python.ops import state_ops # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import +from tensorflow.python.util import compat + # Aliases for some automatically-generated names. linspace = gen_math_ops.lin_space @@ -250,8 +262,8 @@ def argmax(input, axis=None, name=None, dimension=None): return gen_math_ops.arg_max(input, axis, name) -argmax.__doc__ = (gen_math_ops.arg_max.__doc__ - .replace("dimensions", "axes").replace("dimension", "axis")) +argmax.__doc__ = (gen_math_ops.arg_max.__doc__.replace( + "dimensions", "axes").replace("dimension", "axis")) # TODO(aselle:deprecate arg_min) @@ -263,8 +275,8 @@ def argmin(input, axis=None, name=None, dimension=None): return gen_math_ops.arg_min(input, axis, name) -argmin.__doc__ = (gen_math_ops.arg_min.__doc__ - .replace("dimensions", "axes").replace("dimension", "axis")) +argmin.__doc__ = (gen_math_ops.arg_min.__doc__.replace( + "dimensions", "axes").replace("dimension", "axis")) # pylint: enable=redefined-builtin @@ -510,8 +522,9 @@ def pow(x, y, name=None): return gen_math_ops._pow(x, y, name=name) +# pylint: disable=redefined-outer-name,redefined-builtin def complex(real, imag, name=None): - """Converts two real numbers to a complex number. + r"""Converts two real numbers to a complex number. 
Given a tensor `real` representing the real part of a complex number, and a tensor `imag` representing the imaginary part of a complex number, this @@ -551,7 +564,7 @@ def complex(real, imag, name=None): def real(input, name=None): - """Returns the real part of a complex number. + r"""Returns the real part of a complex number. Given a tensor `input` of complex numbers, this operation returns a tensor of type `float32` or `float64` that is the real part of each element in `input`. @@ -610,6 +623,9 @@ def imag(input, name=None): return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name) +# pylint: enable=redefined-outer-name,redefined-builtin + + def round(x, name=None): """Rounds the values of a tensor to the nearest integer, element-wise. @@ -1042,6 +1058,7 @@ def _mul_dispatch(x, y, name=None): y.dense_shape, x, name) return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape) + # NOTE(aselle): When integer division is added for sparse_dense_cwise, # div, truediv, and floordiv should be delegated appropriately for # Python sematnics, analogous to dense cwise tensor operations. @@ -2165,3 +2182,150 @@ def reduced_shape(input_shape, axes): input_shape, # [2, 3, 5, 7] array_ops.fill(axes_shape, 1) ]) # [1, 1] + + +def tensordot(a, b, axes, name=None): + r"""Tensor contraction of a and b along specified axes. + + Tensordot (also known as tensor contraction) sums the product of elements + from `a` and `b` over the indices specified by `a_axes` and `b_axes`. + The lists `a_axes` and `b_axes` specify those pairs of axes along which to + contract the tensors. The axis `a_axes[i]` of `a` must have the same dimension + as axis `b_axes[i]` of `b` for all `i` in `range(0, len(a_axes))`. The lists + `a_axes` and `b_axes` must have identical length and consist of unique + integers that specify valid axes for each of the tensors. + + This operation corresponds to `numpy.tensordot(a, b, axes)`. 
+ + Example 1: When `a` and `b` are matrices (order 2), the case `axes = 1` + is equivalent to matrix multiplication. + + Example 2: When `a` and `b` are matrices (order 2), the case + `axes = [[1], [0]]` is equivalent to matrix multiplication. + + Example 3: Suppose that \\(a_{ijk}\\) and \\(b_{lmn}\\) represent two + tensors of order 3. Then, `tensordot(a, b, [[0], [2]])` is the order 4 tensor + \\(c_{jklm}\\) whose entry + corresponding to the indices \\((j,k,l,m)\\) is given by: + + \\( c_{jklm} = \sum_i a_{ijk} b_{lmi} \\). + + In general, `order(c) = order(a) + order(b) - 2*len(axes[0])`. + + Args: + a: `Tensor` of type `float32` or `float64`. + b: `Tensor` with the same type as `a`. + axes: Either a scalar `N`, or a list or an `int32` `Tensor` of shape [2, k]. + If axes is a scalar, sum over the last N axes of a and the first N axes + of b in order. + If axes is a list or `Tensor` the first and second row contain the set of + unique integers specifying axes along which the contraction is computed, + for `a` and `b`, respectively. The number of axes for `a` and `b` must + be equal. + name: A name for the operation (optional). + + Returns: + A `Tensor` with the same type as `a`. + + Raises: + ValueError: If the shapes of `a`, `b`, and `axes` are incompatible. + IndexError: If the values in axes exceed the rank of the corresponding + tensor. + """ + + def _tensordot_reshape(a, axes, flipped=False): + """Helper method to perform transpose and reshape for contraction op. + + This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul` + using `array_ops.transpose` and `array_ops.reshape`. The method takes a + tensor and performs the correct transpose and reshape operation for a given + set of indices. It returns the reshaped tensor as well as a list of indices + necessary to reshape the tensor again after matrix multiplication. + + Args: + a: `Tensor`. + axes: List or `int32` `Tensor` of unique indices specifying valid axes of + `a`. 
+ flipped: An optional `bool`. Defaults to `False`. If `True`, the method + assumes that `a` is the second argument in the contraction operation. + + Returns: + A pair `(reshaped_a, free_dims)` where `reshaped_a` is the tensor `a` + reshaped to allow contraction via `matmul` and `free_dims` is either a + list of integers or an `int32` `Tensor`, depending on if `axes` is a list + and the shape of `a` is fully defined. + """ + # TODO(b/33084409): Implement partial shape inference. + if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)): + shape_a = a.get_shape().as_list() + axes = [i if i >= 0 else i + len(shape_a) for i in axes] + free = [i for i in xrange(len(shape_a)) if i not in axes] + free_dims = [shape_a[i] for i in free] + prod_free = int(np.prod([shape_a[i] for i in free])) + prod_axes = int(np.prod([shape_a[i] for i in axes])) + perm = list(axes) + free if flipped else free + list(axes) + new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes] + reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape) + return reshaped_a, free_dims + else: + shape_a = array_ops.shape(a) + rank_a = array_ops.rank(a) + axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes") + axes = cast(axes >= 0, dtypes.int32) * axes + cast( + axes < 0, dtypes.int32) * (axes + rank_a) + free, _ = array_ops.setdiff1d(range(rank_a), axes) + free_dims = array_ops.gather(shape_a, free) + axes_dims = array_ops.gather(shape_a, axes) + prod_free_dims = reduce_prod(free_dims) + prod_axes_dims = reduce_prod(axes_dims) + perm = array_ops.concat(0, [axes_dims, free_dims]) + if flipped: + perm = array_ops.concat(0, [axes, free]) + new_shape = array_ops.stack([prod_axes_dims, prod_free_dims]) + else: + perm = array_ops.concat(0, [free, axes]) + new_shape = array_ops.stack([prod_free_dims, prod_axes_dims]) + reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape) + return reshaped_a, free_dims + + def _tensordot_axes(a, axes): 
+ """Generates two sets of contraction axes for the two tensor arguments.""" + a_shape = a.get_shape() + if isinstance(axes, compat.integral_types): + if axes < 1: + raise ValueError("'axes' must be at least 1.") + if a_shape.ndims is not None: + return range(a_shape.ndims - axes, a_shape.ndims), range(axes) + else: + rank = array_ops.rank(a) + return (array_ops.range( + rank - axes, rank, dtype=dtypes.int32), array_ops.range( + rank, dtype=dtypes.int32)) + elif isinstance(axes, (list, tuple)): + if len(axes) != 2: + raise ValueError("'axes' must be an integer or have length 2.") + a_axes = axes[0] + b_axes = axes[1] + if len(a_axes) != len(b_axes): + raise ValueError( + "Different number of contraction axes 'a' and 'b', %s != %s.", + len(a_axes), len(b_axes)) + return a_axes, b_axes + else: + axes = ops.convert_to_tensor(axes, name="axes", dtype=dtypes.int32) + return axes[0], axes[1] + + with ops.name_scope(name, "Tensordot", [a, b, axes]) as name: + a = ops.convert_to_tensor(a, name="a") + b = ops.convert_to_tensor(b, name="b") + a_axes, b_axes = _tensordot_axes(a, axes) + a_reshape, a_free_dims = _tensordot_reshape(a, a_axes) + b_reshape, b_free_dims = _tensordot_reshape(b, b_axes, True) + ab_matmul = matmul(a_reshape, b_reshape) + if isinstance(a_free_dims, list) and isinstance(b_free_dims, list): + return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name) + else: + a_free_dims = ops.convert_to_tensor(a_free_dims) + b_free_dims = ops.convert_to_tensor(b_free_dims) + return array_ops.reshape( + ab_matmul, array_ops.concat(0, [a_free_dims, b_free_dims]), name=name) |