diff options
Diffstat (limited to 'tensorflow/core/common_runtime/direct_session.cc')
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 953c4180fd..f00f5ffd8f 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -397,6 +397,9 @@ Status DirectSession::Run(const RunOptions& run_options, ExecutorsAndKeys* executors_and_keys; RunStateArgs run_state_args; + Executor::Args args; + args.step_id = step_id_counter_.fetch_add(1); + // EXPERIMENTAL: Options that allow the client to insert nodes into partition // graphs for debugging. if (!run_options.debug_options().debug_tensor_watch_opts().empty()) { @@ -407,10 +410,15 @@ Status DirectSession::Run(const RunOptions& run_options, TF_RETURN_IF_ERROR( GetOrCreateExecutors(pool, input_tensor_names, output_names, target_nodes, &executors_and_keys, &run_state_args)); + const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); + + if (run_state_args.debugger_state) { + TF_RETURN_IF_ERROR(run_state_args.debugger_state->PublishDebugMetadata( + run_options.debug_options().global_step(), args.step_id, + executor_step_count, input_tensor_names, output_names, target_nodes)); + } // Create a run state and start execution. - Executor::Args args; - args.step_id = step_id_counter_.fetch_add(1); RunState run_state(args.step_id, &devices_); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); CancellationManager step_cancellation_manager; @@ -450,8 +458,7 @@ Status DirectSession::Run(const RunOptions& run_options, options_.config.graph_options().build_cost_model(); const int64 build_cost_model_after = options_.config.graph_options().build_cost_model_after(); - int measure_step_count = - executors_and_keys->step_count - build_cost_model_after; + int measure_step_count = executor_step_count - build_cost_model_after; if (measure_step_count >= 0) { update_cost_model = ((measure_step_count + 1) % build_cost_model_every == 0); @@ -527,7 +534,6 @@ Status DirectSession::Run(const RunOptions& run_options, // Build and return the cost model as instructed. mutex_lock l(executor_lock_); - ++executors_and_keys->step_count; if (update_cost_model) { // Build the cost model std::unordered_map<string, const Graph*> device_to_graph; |