aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/gather_nd_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_nd_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/gather_nd_op_test.py10
1 files changed, 2 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index 5109ed98c9..af5e23c926 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -25,7 +25,6 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variables
@@ -186,9 +185,6 @@ class GatherNdTest(test.TestCase):
self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
self.assertEqual([10, 10, 20], gather_nd_t.get_shape())
- def assertIndexedSlices(self, t):
- self.assertIsInstance(t, ops.IndexedSlices)
-
def testUnknownIndices(self):
params = constant_op.constant([[0, 1, 2]])
indices = array_ops.placeholder(dtypes.int32)
@@ -237,8 +233,7 @@ class GatherNdTest(test.TestCase):
grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
with self.test_session(use_gpu=True):
- self.assertIndexedSlices(grads)
- self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
+ self.assertAllEqual(expected_grads, grads.eval())
def testGradientsRank3Elements(self):
indices = constant_op.constant(
@@ -289,8 +284,7 @@ class GatherNdTest(test.TestCase):
[0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]],
dtype=np.float64)
with self.test_session(use_gpu=True):
- self.assertIndexedSlices(grads)
- self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
+ self.assertAllEqual(expected_grads, grads.eval())
class GatherNdOpBenchmark(test.Benchmark):