aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/transform_graph.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2017-01-09 16:46:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-09 17:03:54 -0800
commit80679a6a741d2caa00c9b8bd0fa309f9cbbe2905 (patch)
treef5e60b5e7b9595d3eb0af09745feb32f86928144 /tensorflow/tools/graph_transforms/transform_graph.cc
parentbfa0fa079de45a9ce3dc54c0074680f02434cedb (diff)
Add new transforms and compare_graph tool
Change: 144024968
Diffstat (limited to 'tensorflow/tools/graph_transforms/transform_graph.cc')
-rw-r--r--tensorflow/tools/graph_transforms/transform_graph.cc12
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc
index 5e71b0bd5c..c48be92bb9 100644
--- a/tensorflow/tools/graph_transforms/transform_graph.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph.cc
@@ -129,12 +129,15 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
string inputs_string = "";
string outputs_string = "";
string transforms_string = "";
+ bool output_as_text = false;
std::vector<Flag> flag_list = {
Flag("in_graph", &in_graph, "input graph file name"),
Flag("out_graph", &out_graph, "output graph file name"),
Flag("inputs", &inputs_string, "inputs"),
Flag("outputs", &outputs_string, "outputs"),
Flag("transforms", &transforms_string, "list of transforms"),
+ Flag("output_as_text", &output_as_text,
+ "whether to write the graph in text protobuf format"),
};
string usage = Flags::Usage(argv[0], flag_list);
usage += "\nTransforms are:\n";
@@ -185,7 +188,7 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
}
GraphDef graph_def;
- Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
+ Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
if (!load_status.ok()) {
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
<< load_status.error_message();
@@ -202,7 +205,12 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
return -1;
}
- Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
+ Status save_status;
+ if (output_as_text) {
+ save_status = WriteTextProto(Env::Default(), out_graph, graph_def);
+ } else {
+ save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
+ }
if (!save_status.ok()) {
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
<< save_status.error_message();