aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 14:44:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 15:12:57 -0700
commita5752eb9cb266262f3b7a289f12c21e268b3041d (patch)
treeffbf7b539ce8ca6ab4bdb1fa65f4f293ff25f371
parent6d3af1df20f611641665f63e8bb49a875823432b (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212338134
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py54
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py32
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py36
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py28
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py22
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py6
22 files changed, 155 insertions, 155 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 67242fecfe..8e368bf2bc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
@@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
@@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
@@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
@@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
@@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
@@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
@@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
@@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
@@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
@@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
@@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
@@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
@@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
@@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
@@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 2022c1f2bd..293be2bd06 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
@@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase):
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
@@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase):
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
@@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
@@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
@@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -376,7 +376,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
@@ -411,7 +411,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
@@ -482,7 +482,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
@@ -515,7 +515,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
@@ -556,7 +556,7 @@ class BucketBySequenceLength(test.TestCase):
element_len, boundaries, batch_sizes))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(4):
batches.append(sess.run(batch))
@@ -600,7 +600,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
@@ -637,7 +637,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 9020a499c4..eb110324d1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(100):
for i in range(10):
@@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freqs = np.zeros([num_datasets])
for _ in range(num_samples):
freqs[sess.run(next_element)] += 1
@@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in choice_array:
self.assertEqual(words[i], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index e6883d53e0..f3968cdc15 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if error is None:
dense_val, sparse_val = sess.run(
element, feed_dict={
@@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(stop_t)
element = get_single_element.reduce_dataset(dataset, sum_reducer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sess.run(element, feed_dict={stop_t: stop})
self.assertEqual(stop * (stop - 1) / 2, value)
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index db2ab815ee..9c508d686d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -44,14 +44,14 @@ class IndexedDatasetOpsTest(test.TestCase):
get_op = gen_dataset_ops.indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(materialize)
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
def testIdentityIndexedDataset(self):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
materialized = ds.materialize()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(materialized.initializer)
placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
for i in range(16):
@@ -66,7 +66,7 @@ class IndexedDatasetOpsTest(test.TestCase):
ds = indexed_dataset_ops.IdentityIndexedDataset(16)
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(itr.initializer)
for i in range(16):
output = sess.run(n)
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 7a3215f6cc..b9e74dfddb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
def _testEmptyInput(self, sloppy=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Empty input.
self._clear_coordination_events()
sess.run(
@@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
# Non-empty input leading to empty output.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
# Mixture of non-empty and empty interleaved datasets.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(self.next_element)
def testBlockLengthWithContentionSloppy(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_values = []
for _ in range(30):
output_values.append(sess.run(iterator.get_next()))
@@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
@@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(get_next)
def testErrorsInOutputFn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
next_element = iterator.get_next()
results = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2):
elements = []
sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 7bc582ebaa..1cc5ddc9a2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -51,7 +51,7 @@ class LMDBDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_repeats): # Dataset is repeated.
for i in range(10): # 10 records.
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index 55c9ac68dd..e8519381d6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
for filename in filenames:
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
index f6c4a984b8..c4623bca73 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -80,7 +80,7 @@ class ParseExampleTest(test.TestCase):
expected_values=None,
expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 361fe0dd39..0166ba0d44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
@@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = back_to_cpu_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase):
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 592642da0c..db8fe6aa1b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase):
self.assertEqual([tensor_shape.TensorShape([])] * 3,
[t.shape for t in get_next[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
@@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase):
.make_one_shot_iterator())
negative_get_next = negative_iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index fd00cdc5c6..ed75b27a44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -116,7 +116,7 @@ class ReadBatchFeaturesTest(
init_op = iterator.initializer
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index c5cfddb72b..16b1441baa 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
class_func=lambda c, _: c,
seed=27)).make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
while len(returned) < 4000:
returned.append(sess.run(get_next))
@@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
@@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index 42cada0b97..dde678bd54 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase):
start, make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase):
make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
self.assertAllEqual([0] * (2**i), longer_vector_val)
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 077abd6b30..440e48db30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase):
def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
get_next = ds_fn().make_one_shot_iterator().get_next()
outputs = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(num_outputs):
outputs.append(sess.run(get_next))
if verify_exhausted:
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 6b3e8e9f6e..90d18dca2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
window_stride=window_stride_t)).make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Slide: 1st batch.
actual = sess.run(get_next)
@@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index 2c2cfbebff..52823d3fca 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2): # Run twice to verify statelessness of db operations.
sess.run(
init_op,
@@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetJoinQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetNullTerminator(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetReuseSqlDataset(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadEmptyResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidDriverName(self):
init_op = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidColumnName(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfQueryWithSyntaxError(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfInsertQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int8` tensor.
def testReadResultSetInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt8NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int16` tensor.
def testReadResultSetInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt16NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int32` tensor.
def testReadResultSetInt32(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# table and place it in an `int32` tensor.
def testReadResultSetInt32VarCharColumnAsInt(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in an `int64` tensor.
def testReadResultSetInt64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in a `uint8` tensor.
def testReadResultSetUInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint8` tensors.
def testReadResultSetUInt8MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in a `uint16` tensor.
def testReadResultSetUInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint16` tensors.
def testReadResultSetUInt16MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# in `bool` tensors.
def testReadResultSetBool(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# from a SQLite database table and place it as `True` in a `bool` tensor.
def testReadResultSetBoolNotZeroOrOne(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64OverlyPrecise(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
index 1d70b16041..1def07179a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -31,7 +31,7 @@ class DatasetTestBase(test.TestCase):
# TODO(rachelim): support sparse tensor outputs
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
while True:
try:
op1 = sess.run(next1)
@@ -54,7 +54,7 @@ class DatasetTestBase(test.TestCase):
replacements=None):
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
try:
sess.run(next1)
raise ValueError(
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 4b08ec759d..8d335e87d5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -69,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
thread_ids = []
try:
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index d79a842e7a..f994c8563f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index ff4d9b3260..6eaa0b1959 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -92,7 +92,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
@@ -128,7 +128,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredElement(structure, np.concatenate(
([5], shape), axis=0), dtype))
@@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredElement(None, np.concatenate(([5], shape), axis=0),
@@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredSparseElement(structure,
np.concatenate(([5], shape), axis=0),
@@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredSparseElement(None,
@@ -321,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping.window_dataset(len(shapes))).apply(
grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
self._structuredElement(
@@ -352,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
@@ -380,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -458,7 +458,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shapes, dtype).apply(grouping.window_dataset(
len(shapes))).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredRaggedSparseElement(structure, shapes, dtype,
padded_shape))
@@ -489,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected = sess.run(
self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
@@ -516,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index c603ecc5ab..867ee2ba37 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase):
return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
def testWrite(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer, feed_dict={
self.filename: self._createFile(),
@@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteZLIB(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
@@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteGZIP(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={