aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/inspect_checkpoint.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/inspect_checkpoint.py')
-rw-r--r--tensorflow/python/tools/inspect_checkpoint.py23
1 files changed, 5 insertions, 18 deletions
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 8716058e61..47a74e5abf 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -29,8 +29,7 @@ from tensorflow.python.platform import flags
FLAGS = None
-def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
- all_tensor_names):
+def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
"""Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes
@@ -42,16 +41,14 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.
- all_tensor_names: Boolean indicating whether to print all tensor names.
"""
try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
- if all_tensors or all_tensor_names:
+ if all_tensors:
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
print("tensor_name: ", key)
- if all_tensors:
- print(reader.get_tensor(key))
+ print(reader.get_tensor(key))
elif not tensor_name:
print(reader.debug_string().decode("utf-8"))
else:
@@ -107,14 +104,11 @@ def parse_numpy_printoption(kv_str):
def main(unused_argv):
if not FLAGS.file_name:
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
- "[--tensor_name=tensor_to_print] "
- "[--all_tensors] "
- "[--all_tensor_names] "
- "[--printoptions]")
+ "[--tensor_name=tensor_to_print]")
sys.exit(1)
else:
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
- FLAGS.all_tensors, FLAGS.all_tensor_names)
+ FLAGS.all_tensors)
if __name__ == "__main__":
@@ -137,13 +131,6 @@ if __name__ == "__main__":
default=False,
help="If True, print the values of all the tensors.")
parser.add_argument(
- "--all_tensor_names",
- nargs="?",
- const=True,
- type="bool",
- default=False,
- help="If True, print the names of all the tensors.")
- parser.add_argument(
"--printoptions",
nargs="*",
type=parse_numpy_printoption,