From 2b351f224df81121cdcf8131d84be0e3f43d407c Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Fri, 6 Jan 2017 12:07:23 -0800 Subject: Convert tf.flags usage to argparse. Move use of FLAGS globals into main() only. Change: 143799731 --- tensorflow/python/client/notebook.py | 33 ++++-- tensorflow/python/debug/cli/offline_analyzer.py | 41 ++++--- tensorflow/python/debug/examples/debug_errors.py | 37 +++++-- .../python/debug/examples/debug_fibonacci.py | 29 +++-- tensorflow/python/debug/examples/debug_mnist.py | 51 ++++++--- .../python/debug/examples/debug_tflearn_iris.py | 60 +++++++---- tensorflow/python/framework/gen_docs_combined.py | 25 +++-- .../example/saved_model_half_plus_two.py | 25 +++-- tensorflow/python/tools/freeze_graph.py | 118 +++++++++++++++------ tensorflow/python/tools/inspect_checkpoint.py | 35 ++++-- tensorflow/python/tools/optimize_for_inference.py | 56 +++++++--- tensorflow/python/tools/strip_unused.py | 65 +++++++++--- 12 files changed, 416 insertions(+), 159 deletions(-) diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py index b18fb6f889..8babe35b32 100644 --- a/tensorflow/python/client/notebook.py +++ b/tensorflow/python/client/notebook.py @@ -31,27 +31,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import os import socket import sys +from tensorflow.python.platform import app + # pylint: disable=g-import-not-at-top # Official recommended way of turning on fast protocol buffers as of 10/21/14 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2" -from tensorflow.python.platform import app -from tensorflow.python.platform import flags - -FLAGS = flags.FLAGS -flags.DEFINE_string( - "password", None, - "Password to require. If set, the server will allow public access." - " Only used if notebook config file does not exist.") +FLAGS = None -flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks", - "root location where to store notebooks") ORIG_ARGV = sys.argv # Main notebook process calls itself with argv[1]="kernel" to start kernel @@ -108,6 +102,21 @@ def main(unused_argv): if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--password", + type=str, + default=None, + help="""\ + Password to require. If set, the server will allow public access. Only + used if notebook config file does not exist.\ + """) + parser.add_argument( + "--notebook_dir", + type=str, + default="experimental/brain/notebooks", + help="root location where to store notebooks") + # When the user starts the main notebook process, we don't touch sys.argv. # When the main process launches kernel subprocesses, it writes all flags # to a tmpfile and sets --flagfile to that tmpfile, so for kernel @@ -118,4 +127,6 @@ if __name__ == "__main__": # Drop everything except --flagfile. sys.argv = ([sys.argv[0]] + [x for x in sys.argv[1:] if x.startswith("--flagfile")]) - app.run() + + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/cli/offline_analyzer.py b/tensorflow/python/debug/cli/offline_analyzer.py index d91d82ff95..3f78c54732 100644 --- a/tensorflow/python/debug/cli/offline_analyzer.py +++ b/tensorflow/python/debug/cli/offline_analyzer.py @@ -17,23 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import sys # Google-internal import(s). + from tensorflow.python.debug import debug_data from tensorflow.python.debug.cli import analyzer_cli from tensorflow.python.platform import app -from tensorflow.python.platform import flags - -FLAGS = flags.FLAGS -flags.DEFINE_string("dump_dir", "", "tfdbg dump directory to load") -flags.DEFINE_string("ui_type", "curses", - "Command-line user interface type (curses | readline)") -flags.DEFINE_boolean( - "log_usage", True, "Whether the usage of this tool is to be logged") -flags.DEFINE_boolean( - "validate_graph", True, - "Whether the dumped tensors will be validated against the GraphDefs") def main(_): @@ -58,4 +49,30 @@ def main(_): if __name__ == "__main__": - app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--dump_dir", type=str, default="", help="tfdbg dump directory to load") + parser.add_argument( + "--log_usage", + type="bool", + nargs="?", + const=True, + default=True, + help="Whether the usage of this tool is to be logged") + parser.add_argument( + "--ui_type", + type=str, + default="curses" + help="Command-line user interface type (curses | readline)") + parser.add_argument( + "--validate_graph", + nargs="?", + const=True, + type="bool", + default=True, + help="""\ + Whether the dumped tensors will be validated against the GraphDefs\ + """) + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/debug_errors.py b/tensorflow/python/debug/examples/debug_errors.py index 6369853658..855c7cb684 100644 --- a/tensorflow/python/debug/examples/debug_errors.py +++ b/tensorflow/python/debug/examples/debug_errors.py @@ -17,20 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys + import numpy as np import tensorflow as tf from tensorflow.python import debug as tf_debug -flags = tf.app.flags -FLAGS = flags.FLAGS -flags.DEFINE_string("error", "shape_mismatch", "Type of the error to generate " - "(shape_mismatch | uninitialized_variable | no_error).") -flags.DEFINE_string("ui_type", "curses", - "Command-line user interface type (curses | readline)") -flags.DEFINE_boolean("debug", False, - "Use debugger to track down bad values during training") - def main(_): sess = tf.Session() @@ -60,4 +54,27 @@ def main(_): if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--error", + type=str, + default="shape_mismatch", + help="""\ + Type of the error to generate (shape_mismatch | uninitialized_variable | + no_error).\ + """) + parser.add_argument( + "--ui_type", + type=str, + default="curses" + help="Command-line user interface type (curses | readline)") + parser.add_argument( + "--debug", + type="bool", + nargs="?", + const=True, + default=False, + help="Use debugger to track down bad values during training") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/debug_fibonacci.py b/tensorflow/python/debug/examples/debug_fibonacci.py index 99643316be..14722ecd08 100644 --- a/tensorflow/python/debug/examples/debug_fibonacci.py +++ b/tensorflow/python/debug/examples/debug_fibonacci.py @@ -17,19 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys + import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.python import debug as tf_debug -flags = tf.app.flags -FLAGS = flags.FLAGS -flags.DEFINE_integer("tensor_size", 30, - "Size of tensor. E.g., if the value is 30, the tensors " - "will have shape [30, 30].") -flags.DEFINE_integer("length", 20, - "Length of the fibonacci sequence to compute.") +FLAGS = None def main(_): @@ -54,4 +51,20 @@ def main(_): if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--tensor_size", + type=int, + default=30, + help="""\ + Size of tensor. E.g., if the value is 30, the tensors will have shape + [30, 30].\ + """) + parser.add_argument( + "--length", + type=int, + default=20, + help="Length of the fibonacci sequence to compute.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/debug_mnist.py b/tensorflow/python/debug/examples/debug_mnist.py index 526b85762d..1b999aff7c 100644 --- a/tensorflow/python/debug/examples/debug_mnist.py +++ b/tensorflow/python/debug/examples/debug_mnist.py @@ -24,22 +24,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys + import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow.python import debug as tf_debug -flags = tf.app.flags -FLAGS = flags.FLAGS -flags.DEFINE_integer("max_steps", 10, "Number of steps to run trainer.") -flags.DEFINE_integer("train_batch_size", 100, - "Batch size used during training.") -flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.") -flags.DEFINE_string("data_dir", "/tmp/mnist_data", "Directory for storing data") -flags.DEFINE_string("ui_type", "curses", - "Command-line user interface type (curses | readline)") -flags.DEFINE_boolean("debug", False, - "Use debugger to track down bad values during training") IMAGE_SIZE = 28 HIDDEN_SIZE = 500 @@ -137,4 +129,39 @@ def main(_): if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--max_steps", + type=int, + default=10, + help="Number of steps to run trainer.") + parser.add_argument( + "--train_batch_size", + type=int, + default=100, + help="Batch size used during training.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.025, + help="Initial learning rate.") + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/mnist_data", + help="Directory for storing data") + parser.add_argument( + "--ui_type", + type=str, + default="curses" + help="Command-line user interface type (curses | readline)") + parser.add_argument( + "--debug", + type="bool", + nargs="?", + const=True, + default=False, + help="Use debugger to track down bad values during training") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py index 9c61ed5bff..ea1f588f57 100644 --- a/tensorflow/python/debug/examples/debug_tflearn_iris.py +++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import os +import sys import tempfile import numpy as np @@ -26,33 +28,26 @@ import tensorflow as tf from tensorflow.python import debug as tf_debug -flags = tf.app.flags -FLAGS = flags.FLAGS -flags.DEFINE_string("data_dir", "/tmp/iris_data", - "Directory to save the training and test data in.") -flags.DEFINE_string("model_dir", "", "Directory to save the trained model in.") -flags.DEFINE_integer("train_steps", 10, "Number of steps to run trainer.") -flags.DEFINE_string("ui_type", "curses", - "Command-line user interface type (curses | readline)") -flags.DEFINE_boolean("debug", False, - "Use debugger to track down bad values during training") # URLs to download data sets from, if necessary. IRIS_TRAINING_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_training.csv" IRIS_TEST_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_test.csv" -def maybe_download_data(): +def maybe_download_data(data_dir): """Download data sets if necessary. + Args: + data_dir: Path to where data should be downloaded. + Returns: Paths to the training and test data files. """ - if not os.path.isdir(FLAGS.data_dir): - os.makedirs(FLAGS.data_dir) + if not os.path.isdir(data_dir): + os.makedirs(data_dir) - training_data_path = os.path.join(FLAGS.data_dir, + training_data_path = os.path.join(data_dir, os.path.basename(IRIS_TRAINING_DATA_URL)) if not os.path.isfile(training_data_path): train_file = open(training_data_path, "wt") @@ -61,8 +56,7 @@ def maybe_download_data(): print("Training data are downloaded to %s" % train_file.name) - test_data_path = os.path.join(FLAGS.data_dir, - os.path.basename(IRIS_TEST_DATA_URL)) + test_data_path = os.path.join(data_dir, os.path.basename(IRIS_TEST_DATA_URL)) if not os.path.isfile(test_data_path): test_file = open(test_data_path, "wt") urllib.request.urlretrieve(IRIS_TEST_DATA_URL, test_file.name) @@ -74,7 +68,7 @@ def maybe_download_data(): def main(_): - training_data_path, test_data_path = maybe_download_data() + training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir) # Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header( @@ -115,4 +109,34 @@ def main(_): if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/iris_data", + help="Directory to save the training and test data in.") + parser.add_argument( + "--model_dir", + type=str, + default="", + help="Directory to save the trained model in.") + parser.add_argument( + "--train_steps", + type=int, + default=10, + help="Number of steps to run trainer.") + parser.add_argument( + "--ui_type", + type=str, + default="curses" + help="Command-line user interface type (curses | readline)") + parser.add_argument( + "--debug", + type="bool", + nargs="?", + const=True, + default=False, + help="Use debugger to track down bad values during training") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 377ccb7c9b..031c9668ab 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import collections import os.path import sys @@ -31,12 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import docs from tensorflow.python.framework import framework_lib - -tf.flags.DEFINE_string("out_dir", None, - "Directory to which docs should be written.") -tf.flags.DEFINE_boolean("print_hidden_regex", False, - "Dump a regular expression matching any hidden symbol") -FLAGS = tf.flags.FLAGS +FLAGS = None PREFIX_TEXT = """ @@ -309,4 +305,19 @@ def main(unused_argv): if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--out_dir", + type=str, + default=None, + help="Directory to which docs should be written.") + parser.add_argument( + "--print_hidden_regex", + type="bool", + nargs="?", + const=True, + default=False, + help="Dump a regular expression matching any hidden symbol") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py index d65bc0dfa8..a71747da10 100644 --- a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py +++ b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py @@ -31,7 +31,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import os +import sys + import tensorflow as tf from tensorflow.core.protobuf import meta_graph_pb2 @@ -42,13 +45,7 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.util import compat -tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two", - "Directory where to ouput SavedModel.") -tf.app.flags.DEFINE_string("output_dir_pbtxt", - "/tmp/saved_model_half_plus_two_pbtxt", - "Directory where to ouput the text format of " - "SavedModel.") -FLAGS = tf.flags.FLAGS +FLAGS = None def _write_assets(assets_directory, assets_filename): @@ -172,4 +169,16 @@ def main(_): if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + type=str, + default="/tmp/saved_model_half_plus_two", + help="Directory where to ouput SavedModel.") + parser.add_argument( + "--output_dir_pbtxt", + type=str, + default="/tmp/saved_model_half_plus_two_pbtxt", + help="Directory where to ouput the text format of SavedModel.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index 3ccb1b782c..bdd59eeb6b 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -37,6 +37,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys + from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 @@ -45,37 +48,23 @@ from tensorflow.python.client import session from tensorflow.python.framework import graph_util from tensorflow.python.framework import importer from tensorflow.python.platform import app -from tensorflow.python.platform import flags from tensorflow.python.platform import gfile from tensorflow.python.training import saver as saver_lib -FLAGS = flags.FLAGS - -flags.DEFINE_string("input_graph", "", - """TensorFlow 'GraphDef' file to load.""") -flags.DEFINE_string("input_saver", "", """TensorFlow saver file to load.""") -flags.DEFINE_string("input_checkpoint", "", - """TensorFlow variables file to load.""") -flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""") -flags.DEFINE_boolean("input_binary", False, - """Whether the input files are in binary format.""") -flags.DEFINE_string("output_node_names", "", - """The name of the output nodes, comma separated.""") -flags.DEFINE_string("restore_op_name", "save/restore_all", - """The name of the master restore operator.""") -flags.DEFINE_string("filename_tensor_name", "save/Const:0", - """The name of the tensor holding the save path.""") -flags.DEFINE_boolean("clear_devices", True, - """Whether to remove device specifications.""") -flags.DEFINE_string("initializer_nodes", "", "comma separated list of " - "initializer nodes to run before freezing.") -flags.DEFINE_string("variable_names_blacklist", "", "comma separated " - "list of variables to skip converting to constants ") - - -def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, - output_node_names, restore_op_name, filename_tensor_name, - output_graph, clear_devices, initializer_nodes): +FLAGS = None + + +def freeze_graph(input_graph, + input_saver, + input_binary, + input_checkpoint, + output_node_names, + restore_op_name, + filename_tensor_name, + output_graph, + clear_devices, + initializer_nodes, + variable_names_blacklist=""): """Converts all variables in a graph and checkpoint into constants.""" if not gfile.Exists(input_graph): @@ -124,8 +113,8 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, if initializer_nodes: sess.run(initializer_nodes) - variable_names_blacklist = (FLAGS.variable_names_blacklist.split(",") if - FLAGS.variable_names_blacklist else None) + variable_names_blacklist = (variable_names_blacklist.split(",") if + variable_names_blacklist else None) output_graph_def = graph_util.convert_variables_to_constants( sess, input_graph_def, @@ -141,8 +130,73 @@ def main(unused_args): freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, FLAGS.input_checkpoint, FLAGS.output_node_names, FLAGS.restore_op_name, FLAGS.filename_tensor_name, - FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes) + FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes, + FLAGS.variable_names_blacklist) if __name__ == "__main__": - app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--input_graph", + type=str, + default="", + help="TensorFlow \'GraphDef\' file to load.") + parser.add_argument( + "--input_saver", + type=str, + default="", + help="TensorFlow saver file to load.") + parser.add_argument( + "--input_checkpoint", + type=str, + default="", + help="TensorFlow variables file to load.") + parser.add_argument( + "--output_graph", + type=str, + default="", + help="Output \'GraphDef\' file name.") + parser.add_argument( + "--input_binary", + nargs="?", + const=True, + type="bool", + default=False, + help="Whether the input files are in binary format.") + parser.add_argument( + "--output_node_names", + type=str, + default="", + help="The name of the output nodes, comma separated.") + parser.add_argument( + "--restore_op_name", + type=str, + default="save/restore_all", + help="The name of the master restore operator.") + parser.add_argument( + "--filename_tensor_name", + type=str, + default="save/Const:0", + help="The name of the tensor holding the save path.") + parser.add_argument( + "--clear_devices", + nargs="?", + const=True, + type="bool", + default=True, + help="Whether to remove device specifications.") + parser.add_argument( + "--initializer_nodes", + type=str, + default="", + help="comma separated list of initializer nodes to run before freezing.") + parser.add_argument( + "--variable_names_blacklist", + type=str, + default="", + help="""\ + comma separated list of variables to skip converting to constants\ + """) + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index 285f586a7d..8f934eedbd 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -17,20 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import sys from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import app -from tensorflow.python.platform import flags -FLAGS = flags.FLAGS -flags.DEFINE_string("file_name", "", "Checkpoint filename") -flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect") -flags.DEFINE_bool("all_tensors", "False", - "If True, print the values of all the tensors.") +FLAGS = None -def print_tensors_in_checkpoint_file(file_name, tensor_name): +def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): """Prints tensors in a checkpoint file. If no `tensor_name` is provided, prints the tensor names and shapes @@ -41,10 +37,11 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name): Args: file_name: Name of the checkpoint file. tensor_name: Name of the tensor in the checkpoint file to print. + all_tensors: Boolean indicating whether to print all tensors. """ try: reader = pywrap_tensorflow.NewCheckpointReader(file_name) - if FLAGS.all_tensors: + if all_tensors: var_to_shape_map = reader.get_variable_to_shape_map() for key in var_to_shape_map: print("tensor_name: ", key) @@ -67,8 +64,26 @@ def main(unused_argv): "[--tensor_name=tensor_to_print]") sys.exit(1) else: - print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name) + print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name, + FLAGS.all_tensors) if __name__ == "__main__": - app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--file_name", type=str, default="", help="Checkpoint filename") + parser.add_argument( + "--tensor_name", + type=str, + default="", + help="Name of the tensor to inspect") + parser.add_argument( + "--all_tensors", + nargs="?", + const=True, + type="bool", + default=False, + help="If True, print the values of all the tensors.") + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py index 165b84673c..9657e85077 100644 --- a/tensorflow/python/tools/optimize_for_inference.py +++ b/tensorflow/python/tools/optimize_for_inference.py @@ -55,7 +55,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import os +import sys from google.protobuf import text_format @@ -63,22 +65,10 @@ from tensorflow.core.framework import graph_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.platform import app -from tensorflow.python.platform import flags as flags_lib from tensorflow.python.platform import gfile from tensorflow.python.tools import optimize_for_inference_lib -flags = flags_lib -FLAGS = flags.FLAGS -flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""") -flags.DEFINE_string("output", "", """File to save the output graph to.""") -flags.DEFINE_string("input_names", "", """Input node names, comma separated.""") -flags.DEFINE_string("output_names", "", - """Output node names, comma separated.""") -flags.DEFINE_boolean("frozen_graph", True, - """If true, the input graph is a binary frozen GraphDef - file; if false, it is a text GraphDef proto file.""") -flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum, - """The AttrValue enum to use for placeholders.""") +FLAGS = None def main(unused_args): @@ -110,4 +100,42 @@ def main(unused_args): if __name__ == "__main__": - app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--input", + type=str, + default="", + help="TensorFlow \'GraphDef\' file to load.") + parser.add_argument( + "--output", + type=str, + default="", + help="File to save the output graph to.") + parser.add_argument( + "--input_names", + type=str, + default="", + help="Input node names, comma separated.") + parser.add_argument( + "--output_names", + type=str, + default="", + help="Output node names, comma separated.") + parser.add_argument( + "--frozen_graph", + nargs="?", + const=True, + type="bool", + default=True, + help="""\ + If true, the input graph is a binary frozen GraphDef + file; if false, it is a text GraphDef proto file.\ + """) + parser.add_argument( + "--placeholder_type_enum", + type=int, + default=dtypes.float32.as_datatype_enum, + help="The AttrValue enum to use for placeholders.") + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/tools/strip_unused.py b/tensorflow/python/tools/strip_unused.py index d6088e7c68..1788a1eca1 100644 --- a/tensorflow/python/tools/strip_unused.py +++ b/tensorflow/python/tools/strip_unused.py @@ -41,25 +41,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys + from tensorflow.python.framework import dtypes from tensorflow.python.platform import app -from tensorflow.python.platform import flags from tensorflow.python.tools import strip_unused_lib -FLAGS = flags.FLAGS -flags.DEFINE_string("input_graph", "", - """TensorFlow 'GraphDef' file to load.""") -flags.DEFINE_boolean("input_binary", False, - """Whether the input files are in binary format.""") -flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""") -flags.DEFINE_boolean("output_binary", True, - """Whether to write a binary format graph.""") -flags.DEFINE_string("input_node_names", "", - """The name of the input nodes, comma separated.""") -flags.DEFINE_string("output_node_names", "", - """The name of the output nodes, comma separated.""") -flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum, - """The AttrValue enum to use for placeholders.""") +FLAGS = None def main(unused_args): @@ -72,5 +61,47 @@ def main(unused_args): FLAGS.placeholder_type_enum) -if __name__ == "__main__": - app.run() +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--input_graph', + type=str, + default='', + help='TensorFlow \'GraphDef\' file to load.') + parser.add_argument( + '--input_binary', + nargs='?', + const=True, + type='bool', + default=False, + help='Whether the input files are in binary format.') + parser.add_argument( + '--output_graph', + type=str, + default='', + help='Output \'GraphDef\' file name.') + parser.add_argument( + '--output_binary', + nargs='?', + const=True, + type='bool', + default=True, + help='Whether to write a binary format graph.') + parser.add_argument( + '--input_node_names', + type=str, + default='', + help='The name of the input nodes, comma separated.') + parser.add_argument( + '--output_node_names', + type=str, + default='', + help='The name of the output nodes, comma separated.') + parser.add_argument( + '--placeholder_type_enum', + type=int, + default=dtypes.float32.as_datatype_enum, + help='The AttrValue enum to use for placeholders.') + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) -- cgit v1.2.3