diff options
author | 2018-09-10 14:37:06 -0700 | |
---|---|---|
committer | 2018-09-10 15:04:14 -0700 | |
commit | b828f89263e054bfa7c7a808cab1506834ab906d (patch) | |
tree | e31816a6850d177306f19ee8670e0836060fcfc9 /tensorflow/contrib/eager | |
parent | acf0ee82092727afc2067316982407cf5e496f75 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212336464
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/evaluator_test.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/metrics_test.py | 4 |
2 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 7d2274db9b..48d093e075 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"].numpy()) def testDatasetGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) init_op, call_op, results_op = e.evaluate_on_dataset(ds) @@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"]) def testWriteSummariesGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) training_util.get_or_create_global_step() diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index dcc7b71d79..9d2d172752 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -216,7 +216,7 @@ class MetricsTest(test.TestCase): self.assertEqual(m1.numer.name, "has_space/numer:0") def testGraphWithPlaceholder(self): - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: m = metrics.Mean() p = array_ops.placeholder(dtypes.float32) accumulate = m(p) @@ -309,7 +309,7 @@ class MetricsTest(test.TestCase): self.assertTrue(old_numer is m.numer) def testMetricsChain(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): m1 = metrics.Mean() m2 = metrics.Mean(name="m2") update_m2 = m2(3.0) |