aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/summarize_graph_main.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2017-01-30 21:10:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-30 21:42:35 -0800
commit2490fcade664f5dc5f6ef9b88fc7d5bb30cf738a (patch)
tree871d290b457756fccc6cda3f92ad2b3bbf517c1b /tensorflow/tools/graph_transforms/summarize_graph_main.cc
parent8691453b900aef43bf4a3966a5a96b14fbb7bfc3 (diff)
Add parameter counts to graph summarization
Change: 146076665
Diffstat (limited to 'tensorflow/tools/graph_transforms/summarize_graph_main.cc')
-rw-r--r--tensorflow/tools/graph_transforms/summarize_graph_main.cc34
1 files changed, 21 insertions, 13 deletions
diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc
index 55b55e0a15..8c1c007b1d 100644
--- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc
+++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc
@@ -82,19 +82,11 @@ Status SummarizeGraph(const GraphDef& graph) {
std::cout << std::endl;
}
- int const_count = 0;
- int variable_count = 0;
- int identity_count = 0;
+ int64 const_parameter_count = 0;
+ int64 variable_parameter_count = 0;
int control_edge_count = 0;
std::map<string, int> device_counts;
for (const NodeDef& node : graph.node()) {
- if (node.op() == "Const") {
- ++const_count;
- } else if (node.op() == "Variable") {
- ++variable_count;
- } else if (node.op() == "Identity") {
- ++identity_count;
- }
for (const string& input : node.input()) {
if (input.substr(0, 1) == "^") {
++control_edge_count;
@@ -103,11 +95,27 @@ Status SummarizeGraph(const GraphDef& graph) {
if (node.device() != "") {
++device_counts[node.device()];
}
+ if ((node.op() == "Const") || (node.op() == "Variable")) {
+ Tensor tensor;
+ if (tensor.FromProto(node.attr().at("value").tensor())) {
+ const size_t num_elements = tensor.NumElements();
+ if (node.op() == "Const") {
+ const_parameter_count += num_elements;
+ } else {
+ variable_parameter_count += num_elements;
+ }
+ } else {
+ LOG(WARNING) << "Decoding Tensor failed for node" << node.name();
+ }
+ }
}
- std::cout << "Found " << const_count << " consts, " << variable_count
- << " variables, " << identity_count << " identities, and "
- << control_edge_count << " control_edges" << std::endl;
+ std::cout << "Found " << const_parameter_count << " ("
+ << strings::HumanReadableNum(const_parameter_count)
+ << ") const parameters, " << variable_parameter_count << " ("
+ << strings::HumanReadableNum(variable_parameter_count)
+ << ") variable parameters, and " << control_edge_count
+ << " control_edges" << std::endl;
if (!device_counts.empty()) {
for (const auto& device_info : device_counts) {
std::cout << device_info.second << " nodes assigned to device '"