aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/kernel_tests/tensordot_op_test.py17
-rw-r--r--tensorflow/python/ops/math_ops.py4
2 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index 71230ba000..f375157287 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -84,6 +85,22 @@ class TensordotTest(test_lib.TestCase):
b_ph: b,
axes_ph: axes_value})
+ # Test case for 11950
+ def test_valid_axis(self):
+ for axes_value in [1, 2], [[1], [2]]:
+ with self.test_session() as sess:
+ np_a = np.ones((3,3))
+ np_b = np.array([2, 3, 1])[None, None]
+ np_ans = np.tensordot(np_a, np_b, axes_value)
+
+ tf_a = array_ops.ones((3,3), dtype=dtypes.float32)
+ tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None]
+ tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value).eval()
+
+ self.assertAllEqual(tf_ans.shape, np_ans.shape)
+ self.assertAllEqual(tf_ans, np_ans)
+
+
def test_partial_shape_inference(self):
a = array_ops.placeholder(dtypes.float32)
b = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 3e91ec0684..ae354e92d2 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2448,6 +2448,10 @@ def tensordot(a, b, axes, name=None):
raise ValueError("'axes' must be an integer or have length 2.")
a_axes = axes[0]
b_axes = axes[1]
+ if isinstance(a_axes, compat.integral_types) and \
+ isinstance(b_axes, compat.integral_types):
+ a_axes = [a_axes]
+ b_axes = [b_axes]
if len(a_axes) != len(b_axes):
raise ValueError(
"Different number of contraction axes 'a' and 'b', %s != %s.",