diff options
author | Andrew Harp <andrew.harp@gmail.com> | 2016-04-18 15:20:57 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-18 16:31:49 -0700 |
commit | 5c9e7d39187321f27cfc1ebb32064f49bb337313 (patch) | |
tree | d47533ecfbe8bb6bd5f1f509b03877d5df3fc48d /tensorflow/core/util/stat_summarizer.cc | |
parent | 2092fb45f773db7ff0cfc090cc245594999e4999 (diff) |
Have StatSummarizer print OP types as well as name.
Change: 120172505
Diffstat (limited to 'tensorflow/core/util/stat_summarizer.cc')
-rw-r--r-- | tensorflow/core/util/stat_summarizer.cc | 47 |
1 files changed, 36 insertions, 11 deletions
diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc index 4586a06995..8a3c92c92b 100644 --- a/tensorflow/core/util/stat_summarizer.cc +++ b/tensorflow/core/util/stat_summarizer.cc @@ -21,6 +21,7 @@ limitations under the License. #include <sstream> #include <string> +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -28,6 +29,15 @@ limitations under the License. namespace tensorflow { +StatSummarizer::StatSummarizer(const tensorflow::GraphDef& tensorflow_graph) { + LOG(INFO) << "StatSummarizer found " << tensorflow_graph.node_size() + << " nodes"; + for (const auto& node : tensorflow_graph.node()) { + nodes_.push_back(node.name()); + node_types_[node.name()] = node.op(); + } +} + void StatSummarizer::ProcessStepStats(const StepStats& step_stats) { ++num_runs_; int64 curr_total = 0; @@ -48,17 +58,21 @@ void StatSummarizer::ProcessStepStats(const StepStats& step_stats) { void StatSummarizer::PrintHeaders() { std::stringstream stream; stream << std::setw(40) << "[Name]" - << "\t" << std::fixed << std::setprecision(2) << std::setw(7) << "[ms]" - << "\t" << std::fixed << std::setprecision(2) << std::setw(6) << "[%]"; + << "\t" << std::setw(10) << "[Op]" + << "\t" << std::fixed << std::setprecision(3) << std::setw(9) << "[ms]" + << "\t" << std::fixed << std::setprecision(3) << std::setw(7) << "[%]" + << "\t"; LOG(INFO) << stream.str(); } -void StatSummarizer::PrintColumns(const char* name, const double time_ms, +void StatSummarizer::PrintColumns(const char* name, const char* op, + const double time_ms, const double percentage) { std::stringstream stream; - stream << std::setw(40) << name << "\t" << std::fixed << std::setprecision(2) - << std::setw(7) << time_ms << "\t" << std::fixed - << std::setprecision(2) << std::setw(6) << percentage; + stream << std::setw(40) << name << "\t" << std::setw(10) << op << "\t" + << std::fixed << std::setprecision(3) << std::setw(9) << time_ms + << "\t" << std::fixed << std::setprecision(3) << std::setw(7) + << percentage << "\t"; LOG(INFO) << stream.str(); } @@ -69,16 +83,26 @@ void StatSummarizer::PrintStepStats() { LOG(INFO) << "Total time (us): " << run_total_us_; std::priority_queue<std::pair<double, string> > timings; + + LOG(INFO) << timing_totals_.size() << " entries"; + LOG(INFO) << "========== Sorted by run order (ms) =========="; PrintHeaders(); - for (auto entry : timing_totals_) { + for (auto node_name : nodes_) { + if (timing_totals_.find(node_name) == timing_totals_.end()) { + continue; + } + + int64 total_time = timing_totals_[node_name]; + const double avg_time_ms = - entry.second / static_cast<double>(num_runs_) / 1000.0; + total_time / static_cast<double>(num_runs_) / 1000.0; const double overall_percentage = 100.0 * avg_time_ms / avg_total_ms; - PrintColumns(entry.first.c_str(), avg_time_ms, overall_percentage); - timings.push(std::pair<double, string>(avg_time_ms, entry.first)); + PrintColumns(node_name.c_str(), node_types_[node_name].c_str(), avg_time_ms, + overall_percentage); + timings.push(std::pair<double, string>(avg_time_ms, node_name)); } LOG(INFO); @@ -90,7 +114,8 @@ void StatSummarizer::PrintStepStats() { timings.pop(); const double overall_percentage = 100.0 * entry.first / avg_total_ms; - PrintColumns(entry.second.c_str(), entry.first, overall_percentage); + PrintColumns(entry.second.c_str(), node_types_[entry.second].c_str(), + entry.first, overall_percentage); ++num_printed; } LOG(INFO); |