diff options
author | Michael Case <mikecase@google.com> | 2018-02-07 14:36:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-07 14:39:49 -0800 |
commit | d90054e7c0f41f4bab81df0548577a73b939a87a (patch) | |
tree | a15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/python/kernel_tests/tensordot_op_test.py | |
parent | 8461760f9f6cde8ed97507484d2a879140141032 (diff) |
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/python/kernel_tests/tensordot_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/tensordot_op_test.py | 54 |
1 files changed, 27 insertions, 27 deletions
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index f1670a47f5..8ad29afd0a 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -66,7 +66,7 @@ class TensordotTest(test_lib.TestCase): a = [[1, 2], [3, 4]] b = [[1, 2], [3, 4]] # Invalid static axes. - for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]: + for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]: with self.assertRaises(ValueError): math_ops.tensordot(a, b, axes_value) @@ -91,7 +91,7 @@ class TensordotTest(test_lib.TestCase): # Test case for 11950 def test_valid_axis(self): - for axes_value in [1, 2], [[1], [2]]: + for axes_value in [1, 2], [[1], [2]], [[], []], 0: with self.test_session() as sess: np_a = np.ones((3, 3)) np_b = np.array([2, 3, 1])[None, None] @@ -105,29 +105,29 @@ class TensordotTest(test_lib.TestCase): self.assertAllEqual(tf_ans, np_ans) def test_partial_shape_inference(self): - a = array_ops.placeholder(dtypes.float32) - b = array_ops.placeholder(dtypes.float32) - axes = ([1], [0]) - output = math_ops.tensordot(a, b, axes) - self.assertEqual(output.get_shape().ndims, None) - a.set_shape([None, 2]) - b.set_shape([2, 3]) - output = math_ops.tensordot(a, b, axes) - output_shape = output.get_shape() - self.assertEqual(output_shape.ndims, 2) - output_shape = output_shape.as_list() - self.assertEqual(output_shape[0], None) - self.assertEqual(output_shape[1], 3) - a = array_ops.placeholder(dtypes.float32) - b = array_ops.placeholder(dtypes.float32) - a.set_shape([2, 2]) - b.set_shape([2, None]) - output = math_ops.tensordot(a, b, axes) - output_shape = output.get_shape() - self.assertEqual(output_shape.ndims, 2) - output_shape = output_shape.as_list() - self.assertEqual(output_shape[0], 2) - self.assertEqual(output_shape[1], None) + for axes in ([1], [0]), 1: + a = array_ops.placeholder(dtypes.float32) + b = array_ops.placeholder(dtypes.float32) + output = math_ops.tensordot(a, b, axes) + self.assertEqual(output.get_shape().ndims, None) + a.set_shape([None, 2]) + b.set_shape([2, 3]) + output = math_ops.tensordot(a, b, axes) + output_shape = output.get_shape() + self.assertEqual(output_shape.ndims, 2) + output_shape = output_shape.as_list() + self.assertEqual(output_shape[0], None) + self.assertEqual(output_shape[1], 3) + a = array_ops.placeholder(dtypes.float32) + b = array_ops.placeholder(dtypes.float32) + a.set_shape([2, 2]) + b.set_shape([2, None]) + output = math_ops.tensordot(a, b, axes) + output_shape = output.get_shape() + self.assertEqual(output_shape.ndims, 2) + output_shape = output_shape.as_list() + self.assertEqual(output_shape[0], 2) + self.assertEqual(output_shape[1], None) def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): @@ -196,8 +196,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) b_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) - all_axes = [1] - if a_np.ndim > 1: + all_axes = [0, 1] + if a_np.ndim > 2: all_axes.append(a_np.ndim - 1) for axes in all_axes: np_ans = np.tensordot(a_np, b_np, axes=axes) |