diff options
Diffstat (limited to 'tensorflow/python/data/kernel_tests/filter_dataset_op_test.py')
-rw-r--r-- | tensorflow/python/data/kernel_tests/filter_dataset_op_test.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index 4f2216f0a3..19944d389f 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -59,7 +59,7 @@ class FilterDatasetTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Test that we can dynamically feed a different modulus value for each # iterator. def do_test(count_val, modulus_val): @@ -84,7 +84,7 @@ class FilterDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(get_next)) self.assertEqual(1, sess.run(get_next)) self.assertEqual(3, sess.run(get_next)) @@ -98,7 +98,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): if (i ** 2) % 2 == 0: @@ -123,7 +123,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual(input_data[0], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): @@ -151,7 +151,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(5): actual = sess.run(get_next) @@ -169,7 +169,7 @@ class FilterDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): self.assertEqual((i, True), sess.run(get_next)) @@ -181,7 +181,7 @@ class FilterDatasetTest(test.TestCase): lambda x: math_ops.equal(x % 2, 0)) iterators = [dataset.make_one_shot_iterator() for _ in range(10)] next_elements = [iterator.get_next() for iterator in iterators] - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual([0 for _ in range(10)], sess.run(next_elements)) |