aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2018-08-27 14:29:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 14:35:58 -0700
commit3b4df1b62c5b2c1302ebf23a9040cc749f0dd23d (patch)
tree7230dd6862fbfb498dfc446f3645405491570e5e /tensorflow/python/tools
parentfb9a2fbfe461020b7ae167f97832c8a2f060319d (diff)
adding args and returns docstrings to freeze_graph public functions to make them more user friendly
PiperOrigin-RevId: 210432021
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/freeze_graph.py64
1 files changed, 62 insertions, 2 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index c7f414c5dc..893309f35a 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -89,7 +89,37 @@ def freeze_graph_with_def_protos(input_graph_def,
input_saved_model_dir=None,
saved_model_tags=None,
checkpoint_version=saver_pb2.SaverDef.V2):
- """Converts all variables in a graph and checkpoint into constants."""
+ """Converts all variables in a graph and checkpoint into constants.
+
+ Args:
+ input_graph_def: A `GraphDef`.
+ input_saver_def: A `SaverDef` (optional).
+ input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ output_node_names: The name(s) of the output nodes, comma separated.
+ restore_op_name: Unused.
+ filename_tensor_name: Unused.
+ output_graph: String where to write the frozen `GraphDef`.
+ clear_devices: A Bool whether to remove device specifications.
+ initializer_nodes: Comma separated string of initializer nodes to run before
+ freezing.
+ variable_names_whitelist: The set of variable names to convert (optional, by
+ default, all variables are converted).
+ variable_names_blacklist: The set of variable names to omit converting
+ to constants (optional).
+ input_meta_graph_def: A `MetaGraphDef` (optional),
+ input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
+ and variables (optional).
+ saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
+ load, in string format (optional).
+ checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
+ or saver_pb2.SaverDef.V2)
+
+ Returns:
+ Location of the output_graph_def.
+ """
del restore_op_name, filename_tensor_name # Unused by updated loading code.
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
@@ -271,7 +301,37 @@ def freeze_graph(input_graph,
input_saved_model_dir=None,
saved_model_tags=tag_constants.SERVING,
checkpoint_version=saver_pb2.SaverDef.V2):
- """Converts all variables in a graph and checkpoint into constants."""
+ """Converts all variables in a graph and checkpoint into constants.
+
+ Args:
+ input_graph: A `GraphDef` file to load.
+ input_saver: A TensorFlow Saver file.
+ input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
+ input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ output_node_names: The name(s) of the output nodes, comma separated.
+ restore_op_name: Unused.
+ filename_tensor_name: Unused.
+ output_graph: String where to write the frozen `GraphDef`.
+ clear_devices: A Bool whether to remove device specifications.
+ initializer_nodes: Comma separated list of initializer nodes to run before
+ freezing.
+ variable_names_whitelist: The set of variable names to convert (optional, by
+ default, all variables are converted),
+ variable_names_blacklist: The set of variable names to omit converting
+ to constants (optional).
+ input_meta_graph: A `MetaGraphDef` file to load (optional).
+ input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
+ variables (optional).
+ saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
+ load, in string format.
+ checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
+ or saver_pb2.SaverDef.V2).
+ Returns:
+ String that is the location of frozen GraphDef.
+ """
input_graph_def = None
if input_saved_model_dir:
input_graph_def = saved_model_utils.get_meta_graph_def(