aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-08 15:24:05 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:37 -0800
commit2ba34173fad0d5b7d986baeb8171bdc6afdcd7bb (patch)
treed64ae9296db4c2aae776f9785d5e4e73d75e245f
parent6f7cf68cb0cf0728ed3f030ade20c439ceadccdf (diff)
Add a --all_tensor_names option, which is useful if I only want to know all tensor names. It is especially useful in cases whether some of the tensors has huge size. Also update the usage description.
PiperOrigin-RevId: 175074541
-rw-r--r--tensorflow/python/tools/inspect_checkpoint.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 47a74e5abf..8716058e61 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -29,7 +29,8 @@ from tensorflow.python.platform import flags
FLAGS = None
-def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
+def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
+ all_tensor_names):
"""Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes
@@ -41,14 +42,16 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.
+ all_tensor_names: Boolean indicating whether to print all tensor names.
"""
try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
- if all_tensors:
+ if all_tensors or all_tensor_names:
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
print("tensor_name: ", key)
- print(reader.get_tensor(key))
+ if all_tensors:
+ print(reader.get_tensor(key))
elif not tensor_name:
print(reader.debug_string().decode("utf-8"))
else:
@@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str):
def main(unused_argv):
if not FLAGS.file_name:
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
- "[--tensor_name=tensor_to_print]")
+ "[--tensor_name=tensor_to_print] "
+ "[--all_tensors] "
+ "[--all_tensor_names] "
+ "[--printoptions]")
sys.exit(1)
else:
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
- FLAGS.all_tensors)
+ FLAGS.all_tensors, FLAGS.all_tensor_names)
if __name__ == "__main__":
@@ -131,6 +137,13 @@ if __name__ == "__main__":
default=False,
help="If True, print the values of all the tensors.")
parser.add_argument(
+ "--all_tensor_names",
+ nargs="?",
+ const=True,
+ type="bool",
+ default=False,
+ help="If True, print the names of all the tensors.")
+ parser.add_argument(
"--printoptions",
nargs="*",
type=parse_numpy_printoption,