diff options
Diffstat (limited to 'tensorflow/python/tools/optimize_for_inference.py')
-rw-r--r-- | tensorflow/python/tools/optimize_for_inference.py | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py index b95ae00cbd..165b84673c 100644 --- a/tensorflow/python/tools/optimize_for_inference.py +++ b/tensorflow/python/tools/optimize_for_inference.py @@ -57,13 +57,17 @@ from __future__ import print_function import os -import tensorflow as tf - from google.protobuf import text_format +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_io +from tensorflow.python.platform import app +from tensorflow.python.platform import flags as flags_lib +from tensorflow.python.platform import gfile from tensorflow.python.tools import optimize_for_inference_lib -flags = tf.app.flags +flags = flags_lib FLAGS = flags.FLAGS flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""") flags.DEFINE_string("output", "", """File to save the output graph to.""") @@ -73,17 +77,17 @@ flags.DEFINE_string("output_names", "", flags.DEFINE_boolean("frozen_graph", True, """If true, the input graph is a binary frozen GraphDef file; if false, it is a text GraphDef proto file.""") -flags.DEFINE_integer("placeholder_type_enum", tf.float32.as_datatype_enum, +flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum, """The AttrValue enum to use for placeholders.""") def main(unused_args): - if not tf.gfile.Exists(FLAGS.input): + if not gfile.Exists(FLAGS.input): print("Input graph file '" + FLAGS.input + "' does not exist!") return -1 - input_graph_def = tf.GraphDef() - with tf.gfile.Open(FLAGS.input, "r") as f: + input_graph_def = graph_pb2.GraphDef() + with gfile.Open(FLAGS.input, "r") as f: data = f.read() if FLAGS.frozen_graph: input_graph_def.ParseFromString(data) @@ -96,14 +100,14 @@ def main(unused_args): FLAGS.output_names.split(","), FLAGS.placeholder_type_enum) if FLAGS.frozen_graph: - f = tf.gfile.FastGFile(FLAGS.output, "w") + f = gfile.FastGFile(FLAGS.output, "w") f.write(output_graph_def.SerializeToString()) else: - tf.train.write_graph(output_graph_def, + graph_io.write_graph(output_graph_def, os.path.dirname(FLAGS.output), os.path.basename(FLAGS.output)) return 0 if __name__ == "__main__": - tf.app.run() + app.run() |