diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-06-06 13:12:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-06 14:17:56 -0700 |
commit | 35e23065d860f82020149544912314f152e42267 (patch) | |
tree | 771a7d0a86bcb82aa4ea7559bd642c982a816c0f /tensorflow | |
parent | aba8beebab0b363f03492b3d5653ec14d148f3c3 (diff) |
Don't assume the default graph in graph_actions.evaluate().
Change: 124176006
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/graph_actions.py | 97 |
1 files changed, 49 insertions, 48 deletions
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index ef57d7ce36..7c765bc84c 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -192,28 +192,29 @@ def train(graph, if not output_dir: raise ValueError('Output directory should be non-empty.') - global_step_tensor = contrib_variables.assert_or_get_global_step( - graph, global_step_tensor) - if global_step_tensor is None: - raise ValueError('No "global_step" was provided or found in the graph.') - - summary_writer = (get_summary_writer(output_dir) - if supervisor_is_chief else None) - - # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors. - if not supervisor_is_chief: - # monitors should run only on the chief. - monitors = [] - elif not monitors: - monitors = monitors_lib.get_default_monitors( - loss_op=loss_op, - summary_op=logging_ops.get_summary_op(), - save_summary_steps=supervisor_save_summaries_steps, - summary_writer=summary_writer) - - # Start monitors, can create graph parts. - for monitor in monitors: - monitor.begin(max_steps=max_steps) + with graph.as_default(): + global_step_tensor = contrib_variables.assert_or_get_global_step( + graph, global_step_tensor) + if global_step_tensor is None: + raise ValueError('No "global_step" was provided or found in the graph.') + + summary_writer = (get_summary_writer(output_dir) + if supervisor_is_chief else None) + + # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors. + if not supervisor_is_chief: + # monitors should run only on the chief. + monitors = [] + elif not monitors: + monitors = monitors_lib.get_default_monitors( + loss_op=loss_op, + summary_op=logging_ops.get_summary_op(), + save_summary_steps=supervisor_save_summaries_steps, + summary_writer=summary_writer) + + # Start monitors, can create graph parts. + for monitor in monitors: + monitor.begin(max_steps=max_steps) supervisor = tf_supervisor.Supervisor( graph, @@ -424,32 +425,32 @@ def evaluate(graph, eval steps were run. global_step: The global step this evaluation corresponds to. """ - global_step_tensor = contrib_variables.assert_or_get_global_step( - graph, global_step_tensor) - - for key, value in eval_dict.items(): - if not summaries.is_summary_tag_unique(key): - continue - if isinstance(value, ops.Tensor): - summaries.summarize_tensor(value, tag=key) - - # Create or get summary op, global_step and saver. - summary_op = logging_ops.get_summary_op() - saver = _get_saver() - local_init_op = _get_local_init_op() - ready_op = _get_ready_op() - - session_manager = session_manager_lib.SessionManager( - local_init_op=local_init_op, - ready_op=ready_op) - session, initialized = session_manager.recover_session( - master=supervisor_master, - saver=saver, - checkpoint_dir=checkpoint_path) - - # Start queue runners. - coord = coordinator.Coordinator() - threads = _start_queue_runners(session, coord) + with graph.as_default(): + global_step_tensor = contrib_variables.assert_or_get_global_step( + graph, global_step_tensor) + for key, value in eval_dict.items(): + if not summaries.is_summary_tag_unique(key): + continue + if isinstance(value, ops.Tensor): + summaries.summarize_tensor(value, tag=key) + + # Create or get summary op, global_step and saver. + summary_op = logging_ops.get_summary_op() + saver = _get_saver() + local_init_op = _get_local_init_op() + ready_op = _get_ready_op() + + session_manager = session_manager_lib.SessionManager( + local_init_op=local_init_op, + ready_op=ready_op) + session, initialized = session_manager.recover_session( + master=supervisor_master, + saver=saver, + checkpoint_dir=checkpoint_path) + + # Start queue runners. + coord = coordinator.Coordinator() + threads = _start_queue_runners(session, coord) with session: if not initialized: |