aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 14:37:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 15:04:14 -0700
commitb828f89263e054bfa7c7a808cab1506834ab906d (patch)
treee31816a6850d177306f19ee8670e0836060fcfc9 /tensorflow/contrib/eager
parentacf0ee82092727afc2067316982407cf5e496f75 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212336464
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/evaluator_test.py4
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py4
2 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py
index 7d2274db9b..48d093e075 100644
--- a/tensorflow/contrib/eager/python/evaluator_test.py
+++ b/tensorflow/contrib/eager/python/evaluator_test.py
@@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"].numpy())
def testDatasetGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
init_op, call_op, results_op = e.evaluate_on_dataset(ds)
@@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"])
def testWriteSummariesGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
training_util.get_or_create_global_step()
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index dcc7b71d79..9d2d172752 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -216,7 +216,7 @@ class MetricsTest(test.TestCase):
self.assertEqual(m1.numer.name, "has_space/numer:0")
def testGraphWithPlaceholder(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
@@ -309,7 +309,7 @@ class MetricsTest(test.TestCase):
self.assertTrue(old_numer is m.numer)
def testMetricsChain(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
m1 = metrics.Mean()
m2 = metrics.Mean(name="m2")
update_m2 = m2(3.0)