author     A. Unique TensorFlower <nobody@tensorflow.org>   2016-06-06 13:12:29 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-06-06 14:17:56 -0700
commit     35e23065d860f82020149544912314f152e42267 (patch)
tree       771a7d0a86bcb82aa4ea7559bd642c982a816c0f /tensorflow
parent     aba8beebab0b363f03492b3d5653ec14d148f3c3 (diff)
Don't assume the default graph in graph_actions.evaluate().
Change: 124176006
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/learn/python/learn/graph_actions.py | 97
1 file changed, 49 insertions(+), 48 deletions(-)
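
Why the wrapper matters: TensorFlow adds every new op to the thread's current
default graph, so a helper that takes an explicit graph argument must enter
graph.as_default() before it creates ops (summaries, savers, init ops);
otherwise those ops silently land in the process-wide default graph rather
than the caller's graph. A minimal sketch of the pitfall, using the 1.x-era
API (this snippet is illustrative and not part of the patch):

    import tensorflow as tf

    g = tf.Graph()

    # Created outside any as_default() scope: the op is added to the
    # process-wide default graph, not to g.
    c = tf.constant(1.0)
    assert c.graph is not g

    # Created inside the scope: the op lands in g, which is what
    # evaluate(graph=g, ...) relies on when it builds its summary and
    # saver ops.
    with g.as_default():
      d = tf.constant(2.0)
    assert d.graph is g
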
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index ef57d7ce36..7c765bc84c 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -192,28 +192,29 @@ def train(graph,
   if not output_dir:
     raise ValueError('Output directory should be non-empty.')
 
-  global_step_tensor = contrib_variables.assert_or_get_global_step(
-      graph, global_step_tensor)
-  if global_step_tensor is None:
-    raise ValueError('No "global_step" was provided or found in the graph.')
-
-  summary_writer = (get_summary_writer(output_dir)
-                    if supervisor_is_chief else None)
-
-  # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors.
-  if not supervisor_is_chief:
-    # monitors should run only on the chief.
-    monitors = []
-  elif not monitors:
-    monitors = monitors_lib.get_default_monitors(
-        loss_op=loss_op,
-        summary_op=logging_ops.get_summary_op(),
-        save_summary_steps=supervisor_save_summaries_steps,
-        summary_writer=summary_writer)
-
-  # Start monitors, can create graph parts.
-  for monitor in monitors:
-    monitor.begin(max_steps=max_steps)
+  with graph.as_default():
+    global_step_tensor = contrib_variables.assert_or_get_global_step(
+        graph, global_step_tensor)
+    if global_step_tensor is None:
+      raise ValueError('No "global_step" was provided or found in the graph.')
+
+    summary_writer = (get_summary_writer(output_dir)
+                      if supervisor_is_chief else None)
+
+    # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors.
+    if not supervisor_is_chief:
+      # monitors should run only on the chief.
+      monitors = []
+    elif not monitors:
+      monitors = monitors_lib.get_default_monitors(
+          loss_op=loss_op,
+          summary_op=logging_ops.get_summary_op(),
+          save_summary_steps=supervisor_save_summaries_steps,
+          summary_writer=summary_writer)
+
+    # Start monitors, can create graph parts.
+    for monitor in monitors:
+      monitor.begin(max_steps=max_steps)
 
   supervisor = tf_supervisor.Supervisor(
       graph,
@@ -424,32 +425,32 @@ def evaluate(graph,
       eval steps were run.
     global_step: The global step this evaluation corresponds to.
   """
-  global_step_tensor = contrib_variables.assert_or_get_global_step(
-      graph, global_step_tensor)
-
-  for key, value in eval_dict.items():
-    if not summaries.is_summary_tag_unique(key):
-      continue
-    if isinstance(value, ops.Tensor):
-      summaries.summarize_tensor(value, tag=key)
-
-  # Create or get summary op, global_step and saver.
-  summary_op = logging_ops.get_summary_op()
-  saver = _get_saver()
-  local_init_op = _get_local_init_op()
-  ready_op = _get_ready_op()
-
-  session_manager = session_manager_lib.SessionManager(
-      local_init_op=local_init_op,
-      ready_op=ready_op)
-  session, initialized = session_manager.recover_session(
-      master=supervisor_master,
-      saver=saver,
-      checkpoint_dir=checkpoint_path)
-
-  # Start queue runners.
-  coord = coordinator.Coordinator()
-  threads = _start_queue_runners(session, coord)
+  with graph.as_default():
+    global_step_tensor = contrib_variables.assert_or_get_global_step(
+        graph, global_step_tensor)
+    for key, value in eval_dict.items():
+      if not summaries.is_summary_tag_unique(key):
+        continue
+      if isinstance(value, ops.Tensor):
+        summaries.summarize_tensor(value, tag=key)
+
+    # Create or get summary op, global_step and saver.
+    summary_op = logging_ops.get_summary_op()
+    saver = _get_saver()
+    local_init_op = _get_local_init_op()
+    ready_op = _get_ready_op()
+
+    session_manager = session_manager_lib.SessionManager(
+        local_init_op=local_init_op,
+        ready_op=ready_op)
+    session, initialized = session_manager.recover_session(
+        master=supervisor_master,
+        saver=saver,
+        checkpoint_dir=checkpoint_path)
+
+    # Start queue runners.
+    coord = coordinator.Coordinator()
+    threads = _start_queue_runners(session, coord)
 
   with session:
     if not initialized: