diff options
author | 2017-04-28 17:43:15 -0800 | |
---|---|---|
committer | 2017-04-28 19:05:08 -0700 | |
commit | ad3c84b58bb42c87ae8f38b81f75447afcc86d5f (patch) | |
tree | b16d8e024cbcc5e7647777fe1e09114e6a4d4252 /tensorflow | |
parent | 736a2eb3de5031e967d88c63f66c812e07fbd26c (diff) |
Added graph structure output to summarize_graph
Change: 154606362
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/tools/graph_transforms/summarize_graph_main.cc | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index f45dfbba0c..b893465fd9 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -102,7 +102,18 @@ void PrintBenchmarkUsage(const std::vector<const NodeDef*> placeholders, std::cout << std::endl; } -Status SummarizeGraph(const GraphDef& graph, const string& graph_path) { +Status PrintStructure(const GraphDef& graph) { + GraphDef sorted_graph; + TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph)); + for (const NodeDef& node : sorted_graph.node()) { + std::cout << node.name() << " (" << node.op() << "): [" + << str_util::Join(node.input(), ", ") << "]" << std::endl; + } + return Status::OK(); +} + +Status SummarizeGraph(const GraphDef& graph, const string& graph_path, + bool print_structure) { std::vector<const NodeDef*> placeholders; std::vector<const NodeDef*> variables; for (const NodeDef& node : graph.node()) { @@ -233,13 +244,20 @@ Status SummarizeGraph(const GraphDef& graph, const string& graph_path) { PrintBenchmarkUsage(placeholders, variables, outputs, graph_path); + if (print_structure) { + TF_RETURN_IF_ERROR(PrintStructure(graph)); + } + return Status::OK(); } int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) { string in_graph = ""; + bool print_structure = false; std::vector<Flag> flag_list = { Flag("in_graph", &in_graph, "input graph file name"), + Flag("print_structure", &print_structure, + "whether to print the network connections of the graph"), }; string usage = Flags::Usage(argv[0], flag_list); @@ -269,7 +287,8 @@ int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) { return -1; } - Status summarize_result = SummarizeGraph(graph_def, in_graph); + Status summarize_result = + SummarizeGraph(graph_def, in_graph, print_structure); if (!summarize_result.ok()) { LOG(ERROR) << summarize_result.error_message() << "\n" << usage; return -1; |