aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/tensordot_op_test.py
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/python/kernel_tests/tensordot_op_test.py
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/python/kernel_tests/tensordot_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/tensordot_op_test.py54
1 files changed, 27 insertions, 27 deletions
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index f1670a47f5..8ad29afd0a 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -66,7 +66,7 @@ class TensordotTest(test_lib.TestCase):
a = [[1, 2], [3, 4]]
b = [[1, 2], [3, 4]]
# Invalid static axes.
- for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]:
+ for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]:
with self.assertRaises(ValueError):
math_ops.tensordot(a, b, axes_value)
@@ -91,7 +91,7 @@ class TensordotTest(test_lib.TestCase):
# Test case for 11950
def test_valid_axis(self):
- for axes_value in [1, 2], [[1], [2]]:
+ for axes_value in [1, 2], [[1], [2]], [[], []], 0:
with self.test_session() as sess:
np_a = np.ones((3, 3))
np_b = np.array([2, 3, 1])[None, None]
@@ -105,29 +105,29 @@ class TensordotTest(test_lib.TestCase):
self.assertAllEqual(tf_ans, np_ans)
def test_partial_shape_inference(self):
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- axes = ([1], [0])
- output = math_ops.tensordot(a, b, axes)
- self.assertEqual(output.get_shape().ndims, None)
- a.set_shape([None, 2])
- b.set_shape([2, 3])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], None)
- self.assertEqual(output_shape[1], 3)
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- a.set_shape([2, 2])
- b.set_shape([2, None])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], 2)
- self.assertEqual(output_shape[1], None)
+ for axes in ([1], [0]), 1:
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ output = math_ops.tensordot(a, b, axes)
+ self.assertEqual(output.get_shape().ndims, None)
+ a.set_shape([None, 2])
+ b.set_shape([2, 3])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], None)
+ self.assertEqual(output_shape[1], 3)
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ a.set_shape([2, 2])
+ b.set_shape([2, None])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], 2)
+ self.assertEqual(output_shape[1], None)
def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
@@ -196,8 +196,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
b_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
- all_axes = [1]
- if a_np.ndim > 1:
+ all_axes = [0, 1]
+ if a_np.ndim > 2:
all_axes.append(a_np.ndim - 1)
for axes in all_axes:
np_ans = np.tensordot(a_np, b_np, axes=axes)