aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_stitch_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index 49b9569e2b..3a1036e52a 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -252,7 +252,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
# GPU version unit tests
def testScalarGPU(self):
- with self.test_session():
+ with self.cached_session():
indices = [constant_op.constant(0), constant_op.constant(1)]
data = [constant_op.constant(40.0), constant_op.constant(60.0)]
for step in -1, 1:
@@ -263,7 +263,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
self.assertEqual([2], stitched_t.get_shape().as_list())
def testHigherRankGPU(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = [
constant_op.constant(6),
constant_op.constant([4, 1]),