diff options
author | 2017-01-30 21:10:38 -0800 | |
---|---|---|
committer | 2017-01-30 21:42:35 -0800 | |
commit | 2490fcade664f5dc5f6ef9b88fc7d5bb30cf738a (patch) | |
tree | 871d290b457756fccc6cda3f92ad2b3bbf517c1b /tensorflow/tools/graph_transforms/summarize_graph_main.cc | |
parent | 8691453b900aef43bf4a3966a5a96b14fbb7bfc3 (diff) |
Add parameter counts to graph summarization
Change: 146076665
Diffstat (limited to 'tensorflow/tools/graph_transforms/summarize_graph_main.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/summarize_graph_main.cc | 34 |
1 files changed, 21 insertions, 13 deletions
diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index 55b55e0a15..8c1c007b1d 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -82,19 +82,11 @@ Status SummarizeGraph(const GraphDef& graph) { std::cout << std::endl; } - int const_count = 0; - int variable_count = 0; - int identity_count = 0; + int64 const_parameter_count = 0; + int64 variable_parameter_count = 0; int control_edge_count = 0; std::map<string, int> device_counts; for (const NodeDef& node : graph.node()) { - if (node.op() == "Const") { - ++const_count; - } else if (node.op() == "Variable") { - ++variable_count; - } else if (node.op() == "Identity") { - ++identity_count; - } for (const string& input : node.input()) { if (input.substr(0, 1) == "^") { ++control_edge_count; @@ -103,11 +95,27 @@ Status SummarizeGraph(const GraphDef& graph) { if (node.device() != "") { ++device_counts[node.device()]; } + if ((node.op() == "Const") || (node.op() == "Variable")) { + Tensor tensor; + if (tensor.FromProto(node.attr().at("value").tensor())) { + const size_t num_elements = tensor.NumElements(); + if (node.op() == "Const") { + const_parameter_count += num_elements; + } else { + variable_parameter_count += num_elements; + } + } else { + LOG(WARNING) << "Decoding Tensor failed for node" << node.name(); + } + } } - std::cout << "Found " << const_count << " consts, " << variable_count - << " variables, " << identity_count << " identities, and " - << control_edge_count << " control_edges" << std::endl; + std::cout << "Found " << const_parameter_count << " (" + << strings::HumanReadableNum(const_parameter_count) + << ") const parameters, " << variable_parameter_count << " (" + << strings::HumanReadableNum(variable_parameter_count) + << ") variable parameters, and " << control_edge_count + << " control_edges" << std::endl; if (!device_counts.empty()) { for (const auto& device_info : device_counts) { std::cout << device_info.second << " nodes assigned to device '" |