diff options
author | 2016-05-18 19:25:50 -0800 | |
---|---|---|
committer | 2016-05-18 20:33:03 -0700 | |
commit | be8a3e2cf7738442f2b03c35b2ef9113b178855c (patch) | |
tree | 40e3825bd309df0971e76bbe6d6dee7fda4dfba8 /tensorflow/core/common_runtime/step_stats_collector.cc | |
parent | c4119befd020d41be1d753267ed238fd309931df (diff) |
Add the ability to return the cost model to the client as part of run metadata.
Change: 122696039
Diffstat (limited to 'tensorflow/core/common_runtime/step_stats_collector.cc')
-rw-r--r-- | tensorflow/core/common_runtime/step_stats_collector.cc | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 41ffcd1e1f..846814a1b3 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -24,24 +24,24 @@ StepStatsCollector::StepStatsCollector(StepStats* ss, CostModelManager* cost_model_manager) : step_stats_(ss), cost_model_manager_(cost_model_manager) {} -void StepStatsCollector::UpdateCostModel(const NodeExecStats* nt, - const Graph* graph, const Node* node) { +void StepStatsCollector::UpdateCostModelNode(const NodeExecStats* nt, + const Graph* graph, + const Node* node) { mutex_lock l(mu_); - if (cost_model_manager_ == nullptr) { - return; - } - CostModel* cm = cost_model_manager_->FindOrCreateCostModel(graph); - cm->RecordMaxExecutionTime(node, Microseconds(nt->op_end_rel_micros())); + if (cost_model_manager_ != nullptr) { + CostModel* cm = cost_model_manager_->FindOrCreateCostModel(graph); + cm->RecordMaxExecutionTime(node, Microseconds(nt->op_end_rel_micros())); - for (int i = 0; i < nt->output_size(); ++i) { - cm->RecordMaxMemSize(node, i, Bytes(nt->output(i) - .tensor_description() - .allocation_description() - .allocated_bytes())); - cm->RecordAliases(node, i, nt->output(i) - .tensor_description() - .allocation_description() - .allocation_id()); + for (int i = 0; i < nt->output_size(); ++i) { + cm->RecordMaxMemorySize(node, i, Bytes(nt->output(i) + .tensor_description() + .allocation_description() + .allocated_bytes())); + cm->RecordAllocationId(node, i, nt->output(i) + .tensor_description() + .allocation_description() + .allocation_id()); + } } } |