aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-01-06 12:07:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-06 12:28:00 -0800
commit2b351f224df81121cdcf8131d84be0e3f43d407c (patch)
tree1d79981f2f4104b1a171baad22aaa6ff796a4c2a
parent4b3d59a771252506cc34e66ebf2cd93be2564229 (diff)
Convert tf.flags usage to argparse. Move use of FLAGS globals into main() only.
Change: 143799731
-rw-r--r--tensorflow/python/client/notebook.py33
-rw-r--r--tensorflow/python/debug/cli/offline_analyzer.py41
-rw-r--r--tensorflow/python/debug/examples/debug_errors.py37
-rw-r--r--tensorflow/python/debug/examples/debug_fibonacci.py29
-rw-r--r--tensorflow/python/debug/examples/debug_mnist.py51
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py60
-rw-r--r--tensorflow/python/framework/gen_docs_combined.py25
-rw-r--r--tensorflow/python/saved_model/example/saved_model_half_plus_two.py25
-rw-r--r--tensorflow/python/tools/freeze_graph.py118
-rw-r--r--tensorflow/python/tools/inspect_checkpoint.py35
-rw-r--r--tensorflow/python/tools/optimize_for_inference.py56
-rw-r--r--tensorflow/python/tools/strip_unused.py65
12 files changed, 416 insertions, 159 deletions
diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py
index b18fb6f889..8babe35b32 100644
--- a/tensorflow/python/client/notebook.py
+++ b/tensorflow/python/client/notebook.py
@@ -31,27 +31,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import os
import socket
import sys
+from tensorflow.python.platform import app
+
# pylint: disable=g-import-not-at-top
# Official recommended way of turning on fast protocol buffers as of 10/21/14
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string(
- "password", None,
- "Password to require. If set, the server will allow public access."
- " Only used if notebook config file does not exist.")
+FLAGS = None
-flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks",
- "root location where to store notebooks")
ORIG_ARGV = sys.argv
# Main notebook process calls itself with argv[1]="kernel" to start kernel
@@ -108,6 +102,21 @@ def main(unused_argv):
if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--password",
+ type=str,
+ default=None,
+ help="""\
+ Password to require. If set, the server will allow public access. Only
+ used if notebook config file does not exist.\
+ """)
+ parser.add_argument(
+ "--notebook_dir",
+ type=str,
+ default="experimental/brain/notebooks",
+ help="root location where to store notebooks")
+
# When the user starts the main notebook process, we don't touch sys.argv.
# When the main process launches kernel subprocesses, it writes all flags
# to a tmpfile and sets --flagfile to that tmpfile, so for kernel
@@ -118,4 +127,6 @@ if __name__ == "__main__":
# Drop everything except --flagfile.
sys.argv = ([sys.argv[0]] +
[x for x in sys.argv[1:] if x.startswith("--flagfile")])
- app.run()
+
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/cli/offline_analyzer.py b/tensorflow/python/debug/cli/offline_analyzer.py
index d91d82ff95..3f78c54732 100644
--- a/tensorflow/python/debug/cli/offline_analyzer.py
+++ b/tensorflow/python/debug/cli/offline_analyzer.py
@@ -17,23 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import sys
# Google-internal import(s).
+
from tensorflow.python.debug import debug_data
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("dump_dir", "", "tfdbg dump directory to load")
-flags.DEFINE_string("ui_type", "curses",
- "Command-line user interface type (curses | readline)")
-flags.DEFINE_boolean(
- "log_usage", True, "Whether the usage of this tool is to be logged")
-flags.DEFINE_boolean(
- "validate_graph", True,
- "Whether the dumped tensors will be validated against the GraphDefs")
def main(_):
@@ -58,4 +49,30 @@ def main(_):
if __name__ == "__main__":
- app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--dump_dir", type=str, default="", help="tfdbg dump directory to load")
+ parser.add_argument(
+ "--log_usage",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=True,
+ help="Whether the usage of this tool is to be logged")
+ parser.add_argument(
+ "--ui_type",
+ type=str,
+ default="curses"
+ help="Command-line user interface type (curses | readline)")
+ parser.add_argument(
+ "--validate_graph",
+ nargs="?",
+ const=True,
+ type="bool",
+ default=True,
+ help="""\
+ Whether the dumped tensors will be validated against the GraphDefs\
+ """)
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/debug_errors.py b/tensorflow/python/debug/examples/debug_errors.py
index 6369853658..855c7cb684 100644
--- a/tensorflow/python/debug/examples/debug_errors.py
+++ b/tensorflow/python/debug/examples/debug_errors.py
@@ -17,20 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
+import sys
+
import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-flags.DEFINE_string("error", "shape_mismatch", "Type of the error to generate "
- "(shape_mismatch | uninitialized_variable | no_error).")
-flags.DEFINE_string("ui_type", "curses",
- "Command-line user interface type (curses | readline)")
-flags.DEFINE_boolean("debug", False,
- "Use debugger to track down bad values during training")
-
def main(_):
sess = tf.Session()
@@ -60,4 +54,27 @@ def main(_):
if __name__ == "__main__":
- tf.app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--error",
+ type=str,
+ default="shape_mismatch",
+ help="""\
+ Type of the error to generate (shape_mismatch | uninitialized_variable |
+ no_error).\
+ """)
+ parser.add_argument(
+ "--ui_type",
+ type=str,
+ default="curses"
+ help="Command-line user interface type (curses | readline)")
+ parser.add_argument(
+ "--debug",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use debugger to track down bad values during training")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/debug_fibonacci.py b/tensorflow/python/debug/examples/debug_fibonacci.py
index 99643316be..14722ecd08 100644
--- a/tensorflow/python/debug/examples/debug_fibonacci.py
+++ b/tensorflow/python/debug/examples/debug_fibonacci.py
@@ -17,19 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
+import sys
+
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python import debug as tf_debug
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-flags.DEFINE_integer("tensor_size", 30,
- "Size of tensor. E.g., if the value is 30, the tensors "
- "will have shape [30, 30].")
-flags.DEFINE_integer("length", 20,
- "Length of the fibonacci sequence to compute.")
+FLAGS = None
def main(_):
@@ -54,4 +51,20 @@ def main(_):
if __name__ == "__main__":
- tf.app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--tensor_size",
+ type=int,
+ default=30,
+ help="""\
+ Size of tensor. E.g., if the value is 30, the tensors will have shape
+ [30, 30].\
+ """)
+ parser.add_argument(
+ "--length",
+ type=int,
+ default=20,
+ help="Length of the fibonacci sequence to compute.")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/debug_mnist.py b/tensorflow/python/debug/examples/debug_mnist.py
index 526b85762d..1b999aff7c 100644
--- a/tensorflow/python/debug/examples/debug_mnist.py
+++ b/tensorflow/python/debug/examples/debug_mnist.py
@@ -24,22 +24,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
+import sys
+
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python import debug as tf_debug
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-flags.DEFINE_integer("max_steps", 10, "Number of steps to run trainer.")
-flags.DEFINE_integer("train_batch_size", 100,
- "Batch size used during training.")
-flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
-flags.DEFINE_string("data_dir", "/tmp/mnist_data", "Directory for storing data")
-flags.DEFINE_string("ui_type", "curses",
- "Command-line user interface type (curses | readline)")
-flags.DEFINE_boolean("debug", False,
- "Use debugger to track down bad values during training")
IMAGE_SIZE = 28
HIDDEN_SIZE = 500
@@ -137,4 +129,39 @@ def main(_):
if __name__ == "__main__":
- tf.app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--max_steps",
+ type=int,
+ default=10,
+ help="Number of steps to run trainer.")
+ parser.add_argument(
+ "--train_batch_size",
+ type=int,
+ default=100,
+ help="Batch size used during training.")
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=0.025,
+ help="Initial learning rate.")
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ default="/tmp/mnist_data",
+ help="Directory for storing data")
+ parser.add_argument(
+ "--ui_type",
+ type=str,
+ default="curses"
+ help="Command-line user interface type (curses | readline)")
+ parser.add_argument(
+ "--debug",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use debugger to track down bad values during training")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
index 9c61ed5bff..ea1f588f57 100644
--- a/tensorflow/python/debug/examples/debug_tflearn_iris.py
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import os
+import sys
import tempfile
import numpy as np
@@ -26,33 +28,26 @@ import tensorflow as tf
from tensorflow.python import debug as tf_debug
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/iris_data",
- "Directory to save the training and test data in.")
-flags.DEFINE_string("model_dir", "", "Directory to save the trained model in.")
-flags.DEFINE_integer("train_steps", 10, "Number of steps to run trainer.")
-flags.DEFINE_string("ui_type", "curses",
- "Command-line user interface type (curses | readline)")
-flags.DEFINE_boolean("debug", False,
- "Use debugger to track down bad values during training")
# URLs to download data sets from, if necessary.
IRIS_TRAINING_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_training.csv"
IRIS_TEST_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_test.csv"
-def maybe_download_data():
+def maybe_download_data(data_dir):
"""Download data sets if necessary.
+ Args:
+ data_dir: Path to where data should be downloaded.
+
Returns:
Paths to the training and test data files.
"""
- if not os.path.isdir(FLAGS.data_dir):
- os.makedirs(FLAGS.data_dir)
+ if not os.path.isdir(data_dir):
+ os.makedirs(data_dir)
- training_data_path = os.path.join(FLAGS.data_dir,
+ training_data_path = os.path.join(data_dir,
os.path.basename(IRIS_TRAINING_DATA_URL))
if not os.path.isfile(training_data_path):
train_file = open(training_data_path, "wt")
@@ -61,8 +56,7 @@ def maybe_download_data():
print("Training data are downloaded to %s" % train_file.name)
- test_data_path = os.path.join(FLAGS.data_dir,
- os.path.basename(IRIS_TEST_DATA_URL))
+ test_data_path = os.path.join(data_dir, os.path.basename(IRIS_TEST_DATA_URL))
if not os.path.isfile(test_data_path):
test_file = open(test_data_path, "wt")
urllib.request.urlretrieve(IRIS_TEST_DATA_URL, test_file.name)
@@ -74,7 +68,7 @@ def maybe_download_data():
def main(_):
- training_data_path, test_data_path = maybe_download_data()
+ training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir)
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
@@ -115,4 +109,34 @@ def main(_):
if __name__ == "__main__":
- tf.app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ default="/tmp/iris_data",
+ help="Directory to save the training and test data in.")
+ parser.add_argument(
+ "--model_dir",
+ type=str,
+ default="",
+ help="Directory to save the trained model in.")
+ parser.add_argument(
+ "--train_steps",
+ type=int,
+ default=10,
+ help="Number of steps to run trainer.")
+ parser.add_argument(
+ "--ui_type",
+ type=str,
+ default="curses"
+ help="Command-line user interface type (curses | readline)")
+ parser.add_argument(
+ "--debug",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use debugger to track down bad values during training")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 377ccb7c9b..031c9668ab 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import collections
import os.path
import sys
@@ -31,12 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import docs
from tensorflow.python.framework import framework_lib
-
-tf.flags.DEFINE_string("out_dir", None,
- "Directory to which docs should be written.")
-tf.flags.DEFINE_boolean("print_hidden_regex", False,
- "Dump a regular expression matching any hidden symbol")
-FLAGS = tf.flags.FLAGS
+FLAGS = None
PREFIX_TEXT = """
@@ -309,4 +305,19 @@ def main(unused_argv):
if __name__ == "__main__":
- tf.app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--out_dir",
+ type=str,
+ default=None,
+ help="Directory to which docs should be written.")
+ parser.add_argument(
+ "--print_hidden_regex",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Dump a regular expression matching any hidden symbol")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
index d65bc0dfa8..a71747da10 100644
--- a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
+++ b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
@@ -31,7 +31,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import os
+import sys
+
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
@@ -42,13 +45,7 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.util import compat
-tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two",
- "Directory where to ouput SavedModel.")
-tf.app.flags.DEFINE_string("output_dir_pbtxt",
- "/tmp/saved_model_half_plus_two_pbtxt",
- "Directory where to ouput the text format of "
- "SavedModel.")
-FLAGS = tf.flags.FLAGS
+FLAGS = None
def _write_assets(assets_directory, assets_filename):
@@ -172,4 +169,16 @@ def main(_):
if __name__ == "__main__":
- tf.app.run()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="/tmp/saved_model_half_plus_two",
+ help="Directory where to ouput SavedModel.")
+ parser.add_argument(
+ "--output_dir_pbtxt",
+ type=str,
+ default="/tmp/saved_model_half_plus_two_pbtxt",
+ help="Directory where to ouput the text format of SavedModel.")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 3ccb1b782c..bdd59eeb6b 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -37,6 +37,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
+import sys
+
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
@@ -45,37 +48,23 @@ from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("input_graph", "",
- """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_string("input_saver", "", """TensorFlow saver file to load.""")
-flags.DEFINE_string("input_checkpoint", "",
- """TensorFlow variables file to load.""")
-flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
-flags.DEFINE_boolean("input_binary", False,
- """Whether the input files are in binary format.""")
-flags.DEFINE_string("output_node_names", "",
- """The name of the output nodes, comma separated.""")
-flags.DEFINE_string("restore_op_name", "save/restore_all",
- """The name of the master restore operator.""")
-flags.DEFINE_string("filename_tensor_name", "save/Const:0",
- """The name of the tensor holding the save path.""")
-flags.DEFINE_boolean("clear_devices", True,
- """Whether to remove device specifications.""")
-flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
- "initializer nodes to run before freezing.")
-flags.DEFINE_string("variable_names_blacklist", "", "comma separated "
- "list of variables to skip converting to constants ")
-
-
-def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
- output_node_names, restore_op_name, filename_tensor_name,
- output_graph, clear_devices, initializer_nodes):
+FLAGS = None
+
+
+def freeze_graph(input_graph,
+ input_saver,
+ input_binary,
+ input_checkpoint,
+ output_node_names,
+ restore_op_name,
+ filename_tensor_name,
+ output_graph,
+ clear_devices,
+ initializer_nodes,
+ variable_names_blacklist=""):
"""Converts all variables in a graph and checkpoint into constants."""
if not gfile.Exists(input_graph):
@@ -124,8 +113,8 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
if initializer_nodes:
sess.run(initializer_nodes)
- variable_names_blacklist = (FLAGS.variable_names_blacklist.split(",") if
- FLAGS.variable_names_blacklist else None)
+ variable_names_blacklist = (variable_names_blacklist.split(",") if
+ variable_names_blacklist else None)
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
@@ -141,8 +130,73 @@ def main(unused_args):
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
FLAGS.input_checkpoint, FLAGS.output_node_names,
FLAGS.restore_op_name, FLAGS.filename_tensor_name,
- FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
+ FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
+ FLAGS.variable_names_blacklist)
if __name__ == "__main__":
- app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--input_graph",
+ type=str,
+ default="",
+ help="TensorFlow \'GraphDef\' file to load.")
+ parser.add_argument(
+ "--input_saver",
+ type=str,
+ default="",
+ help="TensorFlow saver file to load.")
+ parser.add_argument(
+ "--input_checkpoint",
+ type=str,
+ default="",
+ help="TensorFlow variables file to load.")
+ parser.add_argument(
+ "--output_graph",
+ type=str,
+ default="",
+ help="Output \'GraphDef\' file name.")
+ parser.add_argument(
+ "--input_binary",
+ nargs="?",
+ const=True,
+ type="bool",
+ default=False,
+ help="Whether the input files are in binary format.")
+ parser.add_argument(
+ "--output_node_names",
+ type=str,
+ default="",
+ help="The name of the output nodes, comma separated.")
+ parser.add_argument(
+ "--restore_op_name",
+ type=str,
+ default="save/restore_all",
+ help="The name of the master restore operator.")
+ parser.add_argument(
+ "--filename_tensor_name",
+ type=str,
+ default="save/Const:0",
+ help="The name of the tensor holding the save path.")
+ parser.add_argument(
+ "--clear_devices",
+ nargs="?",
+ const=True,
+ type="bool",
+ default=True,
+ help="Whether to remove device specifications.")
+ parser.add_argument(
+ "--initializer_nodes",
+ type=str,
+ default="",
+ help="comma separated list of initializer nodes to run before freezing.")
+ parser.add_argument(
+ "--variable_names_blacklist",
+ type=str,
+ default="",
+ help="""\
+ comma separated list of variables to skip converting to constants\
+ """)
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 285f586a7d..8f934eedbd 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -17,20 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import sys
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
-FLAGS = flags.FLAGS
-flags.DEFINE_string("file_name", "", "Checkpoint filename")
-flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
-flags.DEFINE_bool("all_tensors", "False",
- "If True, print the values of all the tensors.")
+FLAGS = None
-def print_tensors_in_checkpoint_file(file_name, tensor_name):
+def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
"""Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes
@@ -41,10 +37,11 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name):
Args:
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
+ all_tensors: Boolean indicating whether to print all tensors.
"""
try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
- if FLAGS.all_tensors:
+ if all_tensors:
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
@@ -67,8 +64,26 @@ def main(unused_argv):
"[--tensor_name=tensor_to_print]")
sys.exit(1)
else:
- print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)
+ print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
+ FLAGS.all_tensors)
if __name__ == "__main__":
- app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--file_name", type=str, default="", help="Checkpoint filename")
+ parser.add_argument(
+ "--tensor_name",
+ type=str,
+ default="",
+ help="Name of the tensor to inspect")
+ parser.add_argument(
+ "--all_tensors",
+ nargs="?",
+ const=True,
+ type="bool",
+ default=False,
+ help="If True, print the values of all the tensors.")
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py
index 165b84673c..9657e85077 100644
--- a/tensorflow/python/tools/optimize_for_inference.py
+++ b/tensorflow/python/tools/optimize_for_inference.py
@@ -55,7 +55,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import os
+import sys
from google.protobuf import text_format
@@ -63,22 +65,10 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.platform import app
-from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib
-flags = flags_lib
-FLAGS = flags.FLAGS
-flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_string("output", "", """File to save the output graph to.""")
-flags.DEFINE_string("input_names", "", """Input node names, comma separated.""")
-flags.DEFINE_string("output_names", "",
- """Output node names, comma separated.""")
-flags.DEFINE_boolean("frozen_graph", True,
- """If true, the input graph is a binary frozen GraphDef
- file; if false, it is a text GraphDef proto file.""")
-flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
- """The AttrValue enum to use for placeholders.""")
+FLAGS = None
def main(unused_args):
@@ -110,4 +100,42 @@ def main(unused_args):
if __name__ == "__main__":
- app.run()
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--input",
+ type=str,
+ default="",
+ help="TensorFlow \'GraphDef\' file to load.")
+ parser.add_argument(
+ "--output",
+ type=str,
+ default="",
+ help="File to save the output graph to.")
+ parser.add_argument(
+ "--input_names",
+ type=str,
+ default="",
+ help="Input node names, comma separated.")
+ parser.add_argument(
+ "--output_names",
+ type=str,
+ default="",
+ help="Output node names, comma separated.")
+ parser.add_argument(
+ "--frozen_graph",
+ nargs="?",
+ const=True,
+ type="bool",
+ default=True,
+ help="""\
+ If true, the input graph is a binary frozen GraphDef
+ file; if false, it is a text GraphDef proto file.\
+ """)
+ parser.add_argument(
+ "--placeholder_type_enum",
+ type=int,
+ default=dtypes.float32.as_datatype_enum,
+ help="The AttrValue enum to use for placeholders.")
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/tools/strip_unused.py b/tensorflow/python/tools/strip_unused.py
index d6088e7c68..1788a1eca1 100644
--- a/tensorflow/python/tools/strip_unused.py
+++ b/tensorflow/python/tools/strip_unused.py
@@ -41,25 +41,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
+import sys
+
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
from tensorflow.python.tools import strip_unused_lib
-FLAGS = flags.FLAGS
-flags.DEFINE_string("input_graph", "",
- """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_boolean("input_binary", False,
- """Whether the input files are in binary format.""")
-flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
-flags.DEFINE_boolean("output_binary", True,
- """Whether to write a binary format graph.""")
-flags.DEFINE_string("input_node_names", "",
- """The name of the input nodes, comma separated.""")
-flags.DEFINE_string("output_node_names", "",
- """The name of the output nodes, comma separated.""")
-flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
- """The AttrValue enum to use for placeholders.""")
+FLAGS = None
def main(unused_args):
@@ -72,5 +61,47 @@ def main(unused_args):
FLAGS.placeholder_type_enum)
-if __name__ == "__main__":
- app.run()
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--input_graph',
+ type=str,
+ default='',
+ help='TensorFlow \'GraphDef\' file to load.')
+ parser.add_argument(
+ '--input_binary',
+ nargs='?',
+ const=True,
+ type='bool',
+ default=False,
+ help='Whether the input files are in binary format.')
+ parser.add_argument(
+ '--output_graph',
+ type=str,
+ default='',
+ help='Output \'GraphDef\' file name.')
+ parser.add_argument(
+ '--output_binary',
+ nargs='?',
+ const=True,
+ type='bool',
+ default=True,
+ help='Whether to write a binary format graph.')
+ parser.add_argument(
+ '--input_node_names',
+ type=str,
+ default='',
+ help='The name of the input nodes, comma separated.')
+ parser.add_argument(
+ '--output_node_names',
+ type=str,
+ default='',
+ help='The name of the output nodes, comma separated.')
+ parser.add_argument(
+ '--placeholder_type_enum',
+ type=int,
+ default=dtypes.float32.as_datatype_enum,
+ help='The AttrValue enum to use for placeholders.')
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)