aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2017-04-28 17:43:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-28 19:05:08 -0700
commitad3c84b58bb42c87ae8f38b81f75447afcc86d5f (patch)
treeb16d8e024cbcc5e7647777fe1e09114e6a4d4252 /tensorflow
parent736a2eb3de5031e967d88c63f66c812e07fbd26c (diff)
Added graph structure output to summarize_graph
Change: 154606362
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/tools/graph_transforms/summarize_graph_main.cc23
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc
index f45dfbba0c..b893465fd9 100644
--- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc
+++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc
@@ -102,7 +102,18 @@ void PrintBenchmarkUsage(const std::vector<const NodeDef*> placeholders,
std::cout << std::endl;
}
-Status SummarizeGraph(const GraphDef& graph, const string& graph_path) {
+Status PrintStructure(const GraphDef& graph) {
+ GraphDef sorted_graph;
+ TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph));
+ for (const NodeDef& node : sorted_graph.node()) {
+ std::cout << node.name() << " (" << node.op() << "): ["
+ << str_util::Join(node.input(), ", ") << "]" << std::endl;
+ }
+ return Status::OK();
+}
+
+Status SummarizeGraph(const GraphDef& graph, const string& graph_path,
+ bool print_structure) {
std::vector<const NodeDef*> placeholders;
std::vector<const NodeDef*> variables;
for (const NodeDef& node : graph.node()) {
@@ -233,13 +244,20 @@ Status SummarizeGraph(const GraphDef& graph, const string& graph_path) {
PrintBenchmarkUsage(placeholders, variables, outputs, graph_path);
+ if (print_structure) {
+ TF_RETURN_IF_ERROR(PrintStructure(graph));
+ }
+
return Status::OK();
}
int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
string in_graph = "";
+ bool print_structure = false;
std::vector<Flag> flag_list = {
Flag("in_graph", &in_graph, "input graph file name"),
+ Flag("print_structure", &print_structure,
+ "whether to print the network connections of the graph"),
};
string usage = Flags::Usage(argv[0], flag_list);
@@ -269,7 +287,8 @@ int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
return -1;
}
- Status summarize_result = SummarizeGraph(graph_def, in_graph);
+ Status summarize_result =
+ SummarizeGraph(graph_def, in_graph, print_structure);
if (!summarize_result.ok()) {
LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
return -1;