aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/scatter_nd_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index f2f3023469..86e063cb36 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -294,7 +294,7 @@ class StatefulScatterNdTest(test.TestCase):
self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
- with self.test_session():
+ with self.cached_session():
ref.initializer.run()
self.assertAllEqual(expected_result, scatter_update.eval())
@@ -409,7 +409,7 @@ class ScatterNdTest(test.TestCase):
expected = np.array([b"", b"one", b"", b"three", b"four",
b"", b"", b"seven"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertAllEqual(expected, result)
@@ -420,7 +420,7 @@ class ScatterNdTest(test.TestCase):
dtype=dtypes.string)
expected = np.array([b"", b"", b"", b"bb", b"a", b"", b"", b"c"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertAllEqual(expected, result)
@@ -432,7 +432,7 @@ class ScatterNdTest(test.TestCase):
expected = [np.array([b"", b"", b"", b"bc", b"a", b"", b"", b"d"]),
np.array([b"", b"", b"", b"cb", b"a", b"", b"", b"d"])]
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertTrue(np.array_equal(result, expected[0]) or
np.array_equal(result, expected[1]))
@@ -451,7 +451,7 @@ class ScatterNdTest(test.TestCase):
scatter = self.scatter_nd(indices, updates, shape)
self.assertAllEqual(scatter.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_result, scatter.eval())
def testUndefinedIndicesShape(self):
@@ -486,7 +486,7 @@ class ScatterNdTest(test.TestCase):
updates = array_ops.placeholder(dtypes.int32, shape=None)
shape = constant_op.constant([0, 3, 2], dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(
"Indices and updates specified for empty output"):
self.scatter_nd(indices, updates, shape).eval(feed_dict={
@@ -500,7 +500,7 @@ class ScatterNdTest(test.TestCase):
shape = constant_op.constant([0], dtypes.int32)
scatter = self.scatter_nd(indices, updates, shape)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(scatter.eval().size, 0)
def testRank3InvalidShape1(self):
@@ -531,7 +531,7 @@ class ScatterNdTest(test.TestCase):
[outputs], [updates, input_], [grad_vals])
expected_updates_grad = np.array([1, 4], dtype=np.float64)
expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -548,7 +548,7 @@ class ScatterNdTest(test.TestCase):
[outputs], [updates, input_], [grad_vals])
expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -570,7 +570,7 @@ class ScatterNdTest(test.TestCase):
[[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64)
expected_input_grad = np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -607,7 +607,7 @@ class ScatterNdTest(test.TestCase):
[[[[1, 2], [3, 4]]]],
[[[[5, 6], [7, 8]]]]
]]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -616,33 +616,33 @@ class ScatterNdTest(test.TestCase):
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
shape = [1]
- with self.test_session():
+ with self.cached_session():
val = self.scatter_nd(indices, values, shape).eval()
self.assertAllClose([np.sum(values)], val)
def testSmokeScatterNdBatch2DSliceDim2(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32)
values = array_ops.zeros([3, 5, 7])
shape = [4, 6, 7]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim2(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([0, 2], dtype=dtypes.int32)
values = array_ops.zeros([0, 7])
shape = [4, 6, 7]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([1, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 2, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]