aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/filter_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 4f2216f0a3..19944d389f 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -59,7 +59,7 @@ class FilterDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test that we can dynamically feed a different modulus value for each
# iterator.
def do_test(count_val, modulus_val):
@@ -84,7 +84,7 @@ class FilterDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(get_next))
self.assertEqual(1, sess.run(get_next))
self.assertEqual(3, sess.run(get_next))
@@ -98,7 +98,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
if (i ** 2) % 2 == 0:
@@ -123,7 +123,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual(input_data[0], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -151,7 +151,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(5):
actual = sess.run(get_next)
@@ -169,7 +169,7 @@ class FilterDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
self.assertEqual((i, True), sess.run(get_next))
@@ -181,7 +181,7 @@ class FilterDatasetTest(test.TestCase):
lambda x: math_ops.equal(x % 2, 0))
iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
next_elements = [iterator.get_next() for iterator in iterators]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([0 for _ in range(10)], sess.run(next_elements))