diff options
Diffstat (limited to 'tensorflow/tools/tfprof/internal/tfprof_stats.cc')
-rw-r--r-- | tensorflow/tools/tfprof/internal/tfprof_stats.cc | 77 |
1 files changed, 27 insertions, 50 deletions
diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats.cc b/tensorflow/tools/tfprof/internal/tfprof_stats.cc index 64da7ae7cf..f5b8dad4e2 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_stats.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_stats.cc @@ -29,17 +29,16 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph, std::unique_ptr<RunMetadata> run_meta, std::unique_ptr<OpLog> op_log, std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader) - : has_code_traces_(false), - graph_(std::move(graph)), + : graph_(std::move(graph)), ckpt_reader_(std::move(ckpt_reader)) { CHECK(graph_) << "Must at least have GraphDef"; printf("Parsing Inputs...\n"); ParseGraph(); if (run_meta && run_meta->has_step_stats()) { - AddRunMeta(0, std::move(run_meta)); + ParseRunMeta(0, std::move(run_meta)); } - AddOpLog(std::move(op_log)); + ParseOpLog(std::move(op_log)); if (ckpt_reader_) { for (const auto& v : ckpt_reader_->GetVariableToShapeMap()) { @@ -49,48 +48,27 @@ TFStats::TFStats(std::unique_ptr<GraphDef> graph, } } } -} -void TFStats::BuildView(const string& cmd) { - if (cmd == kCmds[0] && !scope_view_) { - scope_view_.reset(new TFScope(ckpt_reader_.get())); - for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { - scope_view_->AddNode(it->second.get()); - } - scope_view_->Build(); - } - if (cmd == kCmds[1] && !graph_view_) { - graph_view_.reset(new TFGraph(ckpt_reader_.get())); - for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { - graph_view_->AddNode(it->second.get()); - } - graph_view_->Build(); - } - if (cmd == kCmds[2] && !code_view_) { - code_view_.reset(new TFCode()); - for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { - code_view_->AddNode(it->second.get()); - } - code_view_->Build(); - } - if (cmd == kCmds[3] && !op_view_) { - op_view_.reset(new TFOp()); - for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { - op_view_->AddNode(it->second.get()); - } - op_view_->Build(); - } -} + printf("Preparing Views...\n"); + scope_view_ = std::unique_ptr<TFScope>(new TFScope(ckpt_reader_.get())); + graph_view_ = std::unique_ptr<TFGraph>(new TFGraph(ckpt_reader_.get())); + code_view_ = std::unique_ptr<TFCode>(new TFCode()); + op_view_ = std::unique_ptr<TFOp>(new TFOp()); -void TFStats::BuildAllViews() { - std::vector<string> cmds_str(kCmds, kCmds + sizeof(kCmds) / sizeof(*kCmds)); - for (const string& cmd : cmds_str) { - BuildView(cmd); - } + for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) { + scope_view_->AddNode(it->second.get()); + graph_view_->AddNode(it->second.get()); + code_view_->AddNode(it->second.get()); + op_view_->AddNode(it->second.get()); + } + scope_view_->Build(); + graph_view_->Build(); + code_view_->Build(); + op_view_->Build(); } const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd, - const Options& opts) const { + const Options& opts) { if (!Validate(opts)) { return empty_graph_node_; } @@ -104,8 +82,8 @@ const TFGraphNodeProto& TFStats::ShowGraphNode(const string& cmd, } } -const TFMultiGraphNodeProto& TFStats::ShowMultiGraphNode( - const string& cmd, const Options& opts) const { +const TFMultiGraphNodeProto& TFStats::ShowMultiGraphNode(const string& cmd, + const Options& opts) { if (!Validate(opts)) { return empty_multi_graph_node_; } @@ -152,7 +130,7 @@ void TFStats::ParseGraph() { } } -void TFStats::AddOpLog(std::unique_ptr<OpLog> op_log) { +void TFStats::ParseOpLog(std::unique_ptr<OpLog> op_log) { if (!op_log) { return; } @@ -166,13 +144,12 @@ void TFStats::AddOpLog(std::unique_ptr<OpLog> op_log) { node->second->AddFloatOps(entry.float_ops()); } if (entry.has_code_def()) { - has_code_traces_ = true; node->second->AddCode(entry.code_def()); } } } -void TFStats::AddRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) { +void TFStats::ParseRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) { if (!run_meta || !run_meta->has_step_stats()) { fprintf(stderr, "Invalid RunMetadata for step %lld\n", step); return; @@ -199,7 +176,7 @@ void TFStats::AddRunMeta(int64 step, std::unique_ptr<RunMetadata> run_meta) { } } -bool TFStats::Validate(const Options& opts) const { +bool TFStats::Validate(const Options& opts) { if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) { fprintf(stderr, "Options -step=%lld not found\n", opts.step); return false; @@ -207,9 +184,9 @@ bool TFStats::Validate(const Options& opts) const { return true; } -void TFStats::AddNodeForTest(int64 step, std::unique_ptr<TFGraphNode> node) { - steps_.insert(step); - nodes_map_[node->name()] = std::move(node); +void TFStats::AddNodeForTest(const string& name, + std::unique_ptr<TFGraphNode> node) { + nodes_map_[name] = std::move(node); } } // namespace tfprof } // namespace tensorflow |