aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py54
1 files changed, 27 insertions, 27 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 67242fecfe..8e368bf2bc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
@@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
@@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
@@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
@@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
@@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
@@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
@@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
@@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
@@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
@@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
@@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
@@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
@@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
@@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
@@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)