aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 18:22:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 18:25:59 -0700
commit708b30f4cb82271bb28cb70a1e0c89a1933f5b64 (patch)
tree22470a9314f7f4225b6d08170a3d7ea91b0216a1 /tensorflow/contrib/optimizer_v2
parentd0cac47a767dd972516f75ce57f0d6185e3b6514 (diff)
Move from deprecated self.test_session() to self.session() when a graph is set.
self.test_session() has been deprecated in cl/208545396 as its behavior confuses readers of the test. Moving to self.session() instead. PiperOrigin-RevId: 209696110
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/adam_test.py2
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py6
2 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/adam_test.py b/tensorflow/contrib/optimizer_v2/adam_test.py
index d9ad58b0a6..1f079d9afc 100644
--- a/tensorflow/contrib/optimizer_v2/adam_test.py
+++ b/tensorflow/contrib/optimizer_v2/adam_test.py
@@ -152,7 +152,7 @@ class AdamOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 28a531dfec..e13b82d1d2 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -310,7 +310,7 @@ class CheckpointingTests(test.TestCase):
global_step=root.global_step)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
- with self.test_session(graph=ops.get_default_graph()) as session:
+ with self.session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
if checkpoint_path is None:
@@ -504,7 +504,7 @@ class CheckpointingTests(test.TestCase):
"""Saves after the first should not modify the graph."""
with context.graph_mode():
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph):
+ with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
obj = tracking.Checkpointable()
@@ -522,7 +522,7 @@ class CheckpointingTests(test.TestCase):
"""Restores after the first should not modify the graph."""
with context.graph_mode():
graph = ops.Graph()
- with graph.as_default(), self.test_session(graph):
+ with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
obj = tracking.Checkpointable()