aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/kernel_tests/batch_dataset_op_test.py')
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index 89de55dd4f..c48708a2b9 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -82,7 +82,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[dim0] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -111,7 +111,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -131,7 +131,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -158,7 +158,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -188,7 +188,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
@@ -214,7 +214,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -262,7 +262,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=4, padded_shapes=[5]).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.DataLossError):
sess.run(get_next)
@@ -318,7 +318,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=4, padded_shapes=[-1]).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
@@ -342,7 +342,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
sess.run(
@@ -381,7 +381,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
(tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
padded_dataset = dataset.padded_batch(
2, padded_shapes=([None], [None]), padding_values=('', 0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
next_element = padded_dataset.make_one_shot_iterator().get_next()
sess.run(next_element)