diff options
author | 2018-09-17 13:24:29 -0700 | |
---|---|---|
committer | 2018-09-17 13:34:57 -0700 | |
commit | a768624f1d0ae3629caf5b9784b4b6911b881c18 (patch) | |
tree | f7581648c47b4ad95d10099f4485e5f41463f767 /tensorflow/contrib/data | |
parent | d7b4bf68dc80f1abf90bd6b857f079157028a861 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 213326581
Diffstat (limited to 'tensorflow/contrib/data')
7 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 8e368bf2bc..e2508de9e9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -742,7 +742,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) @@ -813,7 +813,7 @@ class RestructuredDatasetTest(test.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -837,7 +837,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) @@ -879,7 +879,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 83b723710c..25aea0393f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -116,7 +116,7 @@ class MapDefunTest(test.TestCase): elems2 = array_ops.placeholder(dtypes.int32) result = map_defun.map_defun(fn, [elems1, elems2], [dtypes.int32, dtypes.int32], [(), ()]) - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesWithPredicateMatch( errors.InvalidArgumentError, "All inputs must have the same dimension 0."): @@ -225,7 +225,7 @@ class MapDefunTest(test.TestCase): c = constant_op.constant([1, 2, 3, 4, 5]) map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0] - with self.test_session() as sess: + with self.cached_session() as sess: thread = self.checkedThread( self._assert_op_cancelled, args=(sess, map_defun_op)) thread.start() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py index bd7b50b902..d10da80442 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -31,7 +31,7 @@ class AssertNextDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(0, sess.run(get_next)) def testAssertNextInvalid(self): @@ -40,7 +40,7 @@ class AssertNextDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, "Asserted Whoops transformation at offset 0 but encountered " @@ -53,7 +53,7 @@ class AssertNextDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, "Asserted next 2 transformations but encountered only 1."): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index dde115925e..e75edf6086 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -200,7 +200,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): optimization.optimize(["filter_fusion"])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for x in range(5): r = map_function(x) filtered = False diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py index 2b3ac85924..3b62a7e468 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py @@ -40,7 +40,7 @@ class ModelDatasetTest(test.TestCase): get_next = iterator.get_next() deltas = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(5): sess.run(get_next.op) for _ in range(100): @@ -64,7 +64,7 @@ class ModelDatasetTest(test.TestCase): get_next = iterator.get_next() deltas = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(5): sess.run(get_next.op) for _ in range(1000): @@ -92,7 +92,7 @@ class ModelDatasetTest(test.TestCase): get_next = iterator.get_next() deltas = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(5): sess.run(get_next.op) for _ in range(10): @@ -119,7 +119,7 @@ class ModelDatasetTest(test.TestCase): get_next = iterator.get_next() deltas = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(5): sess.run(get_next.op) for _ in range(1000): @@ -164,7 +164,7 @@ class ModelDatasetTest(test.TestCase): get_next = iterator.get_next() deltas = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(5): sess.run(get_next) for _ in range(100): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py index 909da5aee0..a3fb824ce9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py @@ -38,7 +38,7 @@ class OptimizeDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -51,7 +51,7 @@ class OptimizeDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -64,7 +64,7 @@ class OptimizeDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -76,7 +76,7 @@ class OptimizeDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(get_next) def testOptimizationLargeInputFromTensor(self): @@ -87,7 +87,7 @@ class OptimizeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) sess.run(get_next) @@ -99,7 +99,7 @@ class OptimizeDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index e25570c5ad..719ce2e3fe 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -40,7 +40,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) expected_sum = 0.0 for i in range(100): @@ -65,7 +65,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) @@ -84,7 +84,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(100): self.assertAllEqual( @@ -109,7 +109,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: for j in range(5): sess.run(iterator.initializer) for i in range(100): @@ -127,7 +127,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) @@ -144,7 +144,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) @@ -168,7 +168,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) @@ -188,7 +188,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): next_element = iterator_0.get_next() + iterator_1.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([iterator_0.initializer, iterator_1.initializer]) for i in range(100): self.assertEqual(i * 2, sess.run(next_element)) |