diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-10 14:44:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-10 15:12:57 -0700 |
commit | a5752eb9cb266262f3b7a289f12c21e268b3041d (patch) | |
tree | ffbf7b539ce8ca6ab4bdb1fa65f4f293ff25f371 | |
parent | 6d3af1df20f611641665f63e8bb49a875823432b (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212338134
22 files changed, 155 insertions, 155 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 67242fecfe..8e368bf2bc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize with an input tensor of incompatible rank. sess.run(init_op, feed_dict={input_tensor: [[1]]}) with self.assertRaisesRegexp(errors.InvalidArgumentError, @@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i,) * 3, sess.run(op)) @@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) @@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): st_row = sess.run(next_element) self.assertEqual([i], st_row.indices) @@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): dense_elem, st_row = sess.run(next_element) self.assertEqual(i, dense_elem) @@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i,),) * 3, sess.run(op)) @@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) @@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) @@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Mismatch in the 0th dimension. sess.run( iterator.initializer, @@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Batch of a finite input, where the batch_size divides the # total number of elements. sess.run(init_op, feed_dict={count: 28, batch_size: 14}) @@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) if not drop_remainder: @@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_one_shot_iterator()) self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) self.assertAllEqual([[64], [81]], sess.run(next_element)) @@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(4): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) .make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) @@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp(errors.InvalidArgumentError, "number of elements does not match"): @@ -659,7 +659,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(3): sess.run(get_next) @@ -686,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): batch_size=10)).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(threshold // 10): self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) if threshold % 10 != 0: @@ -718,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) @@ -784,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) @@ -908,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 2022c1f2bd..293be2bd06 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase): def checkResults(self, dataset, shapes, values): self.assertEqual(shapes, dataset.output_shapes) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for expected in values: got = sess.run(get_next) self.assertEqual(got, expected) @@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase): self.assertIs(None, dataset.output_shapes[1].ndims) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual([0] * (2**i), x) self.assertAllEqual(np.array(1, ndmin=i), y) @@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase): (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual(x, np.asarray([x for x in range(10)])) self.assertEqual(y, 45) @@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # The input is infinite, so this test demonstrates that: # 1. We produce output without having to consume the entire input, @@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) @@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -376,7 +376,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) which_bucket, bucketed_values = sess.run(get_next) @@ -411,7 +411,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches (one containing even values, one containing odds) @@ -482,7 +482,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches ([0, 2, ...] and [64, 66, ...]) @@ -515,7 +515,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.OutOfRangeError): batches = 0 @@ -556,7 +556,7 @@ class BucketBySequenceLength(test.TestCase): element_len, boundaries, batch_sizes)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(4): batches.append(sess.run(batch)) @@ -600,7 +600,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(3): batches.append(sess.run(batch)) @@ -637,7 +637,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(5): batches.append(sess.run(batch)) diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 9020a499c4..eb110324d1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for _ in range(100): for i in range(10): @@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: freqs = np.zeros([num_datasets]) for _ in range(num_samples): freqs[sess.run(next_element)] += 1 @@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in choice_array: self.assertEqual(words[i], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index e6883d53e0..f3968cdc15 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): lambda x: (x * x, make_sparse(x))).take(take_t) element = get_single_element.get_single_element(dataset) - with self.test_session() as sess: + with self.cached_session() as sess: if error is None: dense_val, sparse_val = sess.run( element, feed_dict={ @@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): dataset = dataset_ops.Dataset.range(stop_t) element = get_single_element.reduce_dataset(dataset, sum_reducer) - with self.test_session() as sess: + with self.cached_session() as sess: value = sess.run(element, feed_dict={stop_t: stop}) self.assertEqual(stop * (stop - 1) / 2, value) diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py index db2ab815ee..9c508d686d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py @@ -44,14 +44,14 @@ class IndexedDatasetOpsTest(test.TestCase): get_op = gen_dataset_ops.indexed_dataset_get( handle, index, output_types=[dtypes.uint64], output_shapes=[[]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialize) self.assertEqual([3], sess.run(get_op, feed_dict={index: 3})) def testIdentityIndexedDataset(self): ds = indexed_dataset_ops.IdentityIndexedDataset(16) materialized = ds.materialize() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialized.initializer) placeholder = array_ops.placeholder(dtypes.uint64, shape=[]) for i in range(16): @@ -66,7 +66,7 @@ class IndexedDatasetOpsTest(test.TestCase): ds = indexed_dataset_ops.IdentityIndexedDataset(16) itr = ds.make_initializable_iterator() n = itr.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(itr.initializer) for i in range(16): output = sess.run(n) diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 7a3215f6cc..b9e74dfddb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testSingleThreadedRagged(self): # Tests a sequence with wildly different elements per iterator. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) def _testEmptyInput(self, sloppy=False): - with self.test_session() as sess: + with self.cached_session() as sess: # Empty input. self._clear_coordination_events() sess.run( @@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): # Non-empty input leading to empty output. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) def testBlockLengthWithContentionSloppy(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testEarlyExit(self, sloppy=False): # Exiting without consuming all input should not block - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: output_values = [] for _ in range(30): output_values.append(sess.run(iterator.get_next())) @@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): for j in range(2): @@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(get_next) def testErrorsInOutputFn(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): next_element = iterator.get_next() results = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): elements = [] sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 7bc582ebaa..1cc5ddc9a2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -51,7 +51,7 @@ class LMDBDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(num_repeats): # Dataset is repeated. for i in range(10): # 10 records. diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 55c9ac68dd..e8519381d6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # All of the files are present. sess.run(init_op) for filename in filenames: diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py index f6c4a984b8..c4623bca73 100644 --- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py @@ -80,7 +80,7 @@ class ParseExampleTest(test.TestCase): expected_values=None, expected_err=None): - with self.test_session() as sess: + with self.cached_session() as sess: if expected_err: with self.assertRaisesWithPredicateMatch(expected_err[0], expected_err[1]): diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 361fe0dd39..0166ba0d44 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): destroy_op = resource_variable_ops.destroy_resource_op( buffer_resource_handle, ignore_lookup_error=True) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual([b"a"], sess.run(prefetch_op)) self.assertEqual([b"b"], sess.run(prefetch_op)) self.assertEqual([b"c"], sess.run(prefetch_op)) @@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase): iterator = back_to_cpu_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase): elem_has_value_t = next_elem.has_value() elem_value_t = next_elem.get_value() - with self.test_session() as sess: + with self.cached_session() as sess: # Before initializing the iterator, evaluating the optional fails with # a FailedPreconditionError. with self.assertRaises(errors.FailedPreconditionError): diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 592642da0c..db8fe6aa1b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase): self.assertEqual([tensor_shape.TensorShape([])] * 3, [t.shape for t in get_next[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next)) self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next)) @@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase): .make_one_shot_iterator()) negative_get_next = negative_iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(get_next)) self.assertEqual(3 + 4, sess.run(get_next)) self.assertEqual(3 + 2 * 4, sess.run(get_next)) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index fd00cdc5c6..ed75b27a44 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -116,7 +116,7 @@ class ReadBatchFeaturesTest( init_op = iterator.initializer next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for file_batch, _, _, _, record_batch, _ in self._next_expected_batch( range(self._num_files), 2, 10): diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index c5cfddb72b..16b1441baa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): class_func=lambda c, _: c, seed=27)).make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] while len(returned) < 4000: returned.append(sess.run(get_next)) @@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: @@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 42cada0b97..dde678bd54 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase): start, make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase): make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): (longer_vector_val, larger_rank_val), _ = sess.run(next_element) self.assertAllEqual([0] * (2**i), longer_vector_val) diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 077abd6b30..440e48db30 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase): def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): get_next = ds_fn().make_one_shot_iterator().get_next() outputs = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(num_outputs): outputs.append(sess.run(get_next)) if verify_exhausted: diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 6b3e8e9f6e..90d18dca2a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): window_stride=window_stride_t)).make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Slide: 1st batch. actual = sess.run(get_next) @@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) with self.assertRaisesRegexp( errors.InvalidArgumentError, diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 2c2cfbebff..52823d3fca 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string), 2) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): # Run twice to verify statelessness of db operations. sess.run( init_op, @@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetJoinQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetNullTerminator(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetReuseSqlDataset(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadEmptyResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidDriverName(self): init_op = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string))[0] - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidColumnName(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfQueryWithSyntaxError(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfInsertQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int8` tensor. def testReadResultSetInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt8NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int8` tensor. def testReadResultSetInt8MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int16` tensor. def testReadResultSetInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt16NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int16` tensor. def testReadResultSetInt16MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int32` tensor. def testReadResultSetInt32(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int32` tensor. def testReadResultSetInt32NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int32` tensor. def testReadResultSetInt32MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # table and place it in an `int32` tensor. def testReadResultSetInt32VarCharColumnAsInt(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in an `int64` tensor. def testReadResultSetInt64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int64` tensor. def testReadResultSetInt64NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int64` tensor. def testReadResultSetInt64MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in a `uint8` tensor. def testReadResultSetUInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint8` tensors. def testReadResultSetUInt8MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in a `uint16` tensor. def testReadResultSetUInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint16` tensors. def testReadResultSetUInt16MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # in `bool` tensors. def testReadResultSetBool(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # from a SQLite database table and place it as `True` in a `bool` tensor. def testReadResultSetBoolNotZeroOrOne(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64OverlyPrecise(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py index 1d70b16041..1def07179a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -31,7 +31,7 @@ class DatasetTestBase(test.TestCase): # TODO(rachelim): support sparse tensor outputs next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: while True: try: op1 = sess.run(next1) @@ -54,7 +54,7 @@ class DatasetTestBase(test.TestCase): replacements=None): next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: try: sess.run(next1) raise ValueError( diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 4b08ec759d..8d335e87d5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -69,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) thread_ids = [] try: diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index d79a842e7a..f994c8563f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_case, expected in test_cases: current_test_case = test_case sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index ff4d9b3260..6eaa0b1959 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -92,7 +92,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).apply( grouping.window_dataset(5)).flat_map(fn) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run(self._structuredElement(structure, shape, dtype)) actual = sess.run(get_next) self._assertEqual(expected, actual) @@ -128,7 +128,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredElement(structure, np.concatenate( ([5], shape), axis=0), dtype)) @@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredElement(None, np.concatenate(([5], shape), axis=0), @@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredSparseElement(structure, np.concatenate(([5], shape), axis=0), @@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredSparseElement(None, @@ -321,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping.window_dataset(len(shapes))).apply( grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( self._structuredElement( @@ -352,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( @@ -380,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -458,7 +458,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shapes, dtype).apply(grouping.window_dataset( len(shapes))).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredRaggedSparseElement(structure, shapes, dtype, padded_shape)) @@ -489,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected = sess.run( self._structuredRaggedSparseElement(None, shapes, dtypes.int32, @@ -516,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py index c603ecc5ab..867ee2ba37 100644 --- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase): return os.path.join(self.get_temp_dir(), "tf_record.out.txt") def testWrite(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ self.filename: self._createFile(), @@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteZLIB(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ @@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteGZIP(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ |