aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/stat_summarizer.cc
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrew.harp@gmail.com>2016-04-18 15:20:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-18 16:31:49 -0700
commit5c9e7d39187321f27cfc1ebb32064f49bb337313 (patch)
treed47533ecfbe8bb6bd5f1f509b03877d5df3fc48d /tensorflow/core/util/stat_summarizer.cc
parent2092fb45f773db7ff0cfc090cc245594999e4999 (diff)
Have StatSummarizer print OP types as well as name.
Change: 120172505
Diffstat (limited to 'tensorflow/core/util/stat_summarizer.cc')
-rw-r--r--tensorflow/core/util/stat_summarizer.cc47
1 files changed, 36 insertions, 11 deletions
diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc
index 4586a06995..8a3c92c92b 100644
--- a/tensorflow/core/util/stat_summarizer.cc
+++ b/tensorflow/core/util/stat_summarizer.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <sstream>
#include <string>
+#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -28,6 +29,15 @@ limitations under the License.
namespace tensorflow {
+StatSummarizer::StatSummarizer(const tensorflow::GraphDef& tensorflow_graph) {
+ LOG(INFO) << "StatSummarizer found " << tensorflow_graph.node_size()
+ << " nodes";
+ for (const auto& node : tensorflow_graph.node()) {
+ nodes_.push_back(node.name());
+ node_types_[node.name()] = node.op();
+ }
+}
+
void StatSummarizer::ProcessStepStats(const StepStats& step_stats) {
++num_runs_;
int64 curr_total = 0;
@@ -48,17 +58,21 @@ void StatSummarizer::ProcessStepStats(const StepStats& step_stats) {
void StatSummarizer::PrintHeaders() {
std::stringstream stream;
stream << std::setw(40) << "[Name]"
- << "\t" << std::fixed << std::setprecision(2) << std::setw(7) << "[ms]"
- << "\t" << std::fixed << std::setprecision(2) << std::setw(6) << "[%]";
+ << "\t" << std::setw(10) << "[Op]"
+ << "\t" << std::fixed << std::setprecision(3) << std::setw(9) << "[ms]"
+ << "\t" << std::fixed << std::setprecision(3) << std::setw(7) << "[%]"
+ << "\t";
LOG(INFO) << stream.str();
}
-void StatSummarizer::PrintColumns(const char* name, const double time_ms,
+void StatSummarizer::PrintColumns(const char* name, const char* op,
+ const double time_ms,
const double percentage) {
std::stringstream stream;
- stream << std::setw(40) << name << "\t" << std::fixed << std::setprecision(2)
- << std::setw(7) << time_ms << "\t" << std::fixed
- << std::setprecision(2) << std::setw(6) << percentage;
+ stream << std::setw(40) << name << "\t" << std::setw(10) << op << "\t"
+ << std::fixed << std::setprecision(3) << std::setw(9) << time_ms
+ << "\t" << std::fixed << std::setprecision(3) << std::setw(7)
+ << percentage << "\t";
LOG(INFO) << stream.str();
}
@@ -69,16 +83,26 @@ void StatSummarizer::PrintStepStats() {
LOG(INFO) << "Total time (us): " << run_total_us_;
std::priority_queue<std::pair<double, string> > timings;
+
+ LOG(INFO) << timing_totals_.size() << " entries";
+
LOG(INFO) << "========== Sorted by run order (ms) ==========";
PrintHeaders();
- for (auto entry : timing_totals_) {
+ for (auto node_name : nodes_) {
+ if (timing_totals_.find(node_name) == timing_totals_.end()) {
+ continue;
+ }
+
+ int64 total_time = timing_totals_[node_name];
+
const double avg_time_ms =
- entry.second / static_cast<double>(num_runs_) / 1000.0;
+ total_time / static_cast<double>(num_runs_) / 1000.0;
const double overall_percentage = 100.0 * avg_time_ms / avg_total_ms;
- PrintColumns(entry.first.c_str(), avg_time_ms, overall_percentage);
- timings.push(std::pair<double, string>(avg_time_ms, entry.first));
+ PrintColumns(node_name.c_str(), node_types_[node_name].c_str(), avg_time_ms,
+ overall_percentage);
+ timings.push(std::pair<double, string>(avg_time_ms, node_name));
}
LOG(INFO);
@@ -90,7 +114,8 @@ void StatSummarizer::PrintStepStats() {
timings.pop();
const double overall_percentage = 100.0 * entry.first / avg_total_ms;
- PrintColumns(entry.second.c_str(), entry.first, overall_percentage);
+ PrintColumns(entry.second.c_str(), node_types_[entry.second].c_str(),
+ entry.first, overall_percentage);
++num_printed;
}
LOG(INFO);