aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 1d27b036eb..37e2333560 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -44,7 +44,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test a finite repetition.
sess.run(init_op, feed_dict={count_placeholder: 3})
for _ in range(3):
@@ -90,7 +90,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Take fewer than input size
sess.run(init_op, feed_dict={count_placeholder: 4})
for i in range(4):
@@ -136,7 +136,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Skip fewer than input size, we should skip
# the first 4 elements and then read the rest.
sess.run(init_op, feed_dict={count_placeholder: 4})
@@ -183,7 +183,7 @@ class SequenceDatasetTest(test.TestCase):
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
for _ in range(7 * 14):
results = sess.run(get_next)
@@ -199,7 +199,7 @@ class SequenceDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)