diff options
-rw-r--r-- | tensorflow/python/kernel_tests/tensordot_op_test.py | 17 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 4 |
2 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index 71230ba000..f375157287 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -84,6 +85,22 @@ class TensordotTest(test_lib.TestCase): b_ph: b, axes_ph: axes_value}) + # Test case for 11950 + def test_valid_axis(self): + for axes_value in [1, 2], [[1], [2]]: + with self.test_session() as sess: + np_a = np.ones((3,3)) + np_b = np.array([2, 3, 1])[None, None] + np_ans = np.tensordot(np_a, np_b, axes_value) + + tf_a = array_ops.ones((3,3), dtype=dtypes.float32) + tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None] + tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval() + + self.assertAllEqual(tf_ans.shape, np_ans.shape) + self.assertAllEqual(tf_ans, np_ans) + + def test_partial_shape_inference(self): a = array_ops.placeholder(dtypes.float32) b = array_ops.placeholder(dtypes.float32) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 3e91ec0684..ae354e92d2 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2448,6 +2448,10 @@ def tensordot(a, b, axes, name=None): raise ValueError("'axes' must be an integer or have length 2.") a_axes = axes[0] b_axes = axes[1] + if isinstance(a_axes, compat.integral_types) and \ + isinstance(b_axes, compat.integral_types): + a_axes = [a_axes] + b_axes = [b_axes] if len(a_axes) != len(b_axes): raise ValueError( "Different number of contraction axes 'a' and 'b', %s != %s.", |