aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/optimize_for_inference.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/optimize_for_inference.py')
-rw-r--r--tensorflow/python/tools/optimize_for_inference.py24
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py
index b95ae00cbd..165b84673c 100644
--- a/tensorflow/python/tools/optimize_for_inference.py
+++ b/tensorflow/python/tools/optimize_for_inference.py
@@ -57,13 +57,17 @@ from __future__ import print_function
import os
-import tensorflow as tf
-
from google.protobuf import text_format
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_io
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags as flags_lib
+from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib
-flags = tf.app.flags
+flags = flags_lib
FLAGS = flags.FLAGS
flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("output", "", """File to save the output graph to.""")
@@ -73,17 +77,17 @@ flags.DEFINE_string("output_names", "",
flags.DEFINE_boolean("frozen_graph", True,
"""If true, the input graph is a binary frozen GraphDef
file; if false, it is a text GraphDef proto file.""")
-flags.DEFINE_integer("placeholder_type_enum", tf.float32.as_datatype_enum,
+flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
"""The AttrValue enum to use for placeholders.""")
def main(unused_args):
- if not tf.gfile.Exists(FLAGS.input):
+ if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!")
return -1
- input_graph_def = tf.GraphDef()
- with tf.gfile.Open(FLAGS.input, "r") as f:
+ input_graph_def = graph_pb2.GraphDef()
+ with gfile.Open(FLAGS.input, "r") as f:
data = f.read()
if FLAGS.frozen_graph:
input_graph_def.ParseFromString(data)
@@ -96,14 +100,14 @@ def main(unused_args):
FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)
if FLAGS.frozen_graph:
- f = tf.gfile.FastGFile(FLAGS.output, "w")
+ f = gfile.FastGFile(FLAGS.output, "w")
f.write(output_graph_def.SerializeToString())
else:
- tf.train.write_graph(output_graph_def,
+ graph_io.write_graph(output_graph_def,
os.path.dirname(FLAGS.output),
os.path.basename(FLAGS.output))
return 0
if __name__ == "__main__":
- tf.app.run()
+ app.run()