diff options
author | 2017-11-08 15:24:05 -0800 | |
---|---|---|
committer | 2017-11-10 16:14:37 -0800 | |
commit | 2ba34173fad0d5b7d986baeb8171bdc6afdcd7bb (patch) | |
tree | d64ae9296db4c2aae776f9785d5e4e73d75e245f | |
parent | 6f7cf68cb0cf0728ed3f030ade20c439ceadccdf (diff) |
Add a --all_tensor_names option, which is useful if I only want to know all tensor names. It is especially useful in cases whether some of the tensors has huge size. Also update the usage description.
PiperOrigin-RevId: 175074541
-rw-r--r-- | tensorflow/python/tools/inspect_checkpoint.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index 47a74e5abf..8716058e61 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -29,7 +29,8 @@ from tensorflow.python.platform import flags FLAGS = None -def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): +def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, + all_tensor_names): """Prints tensors in a checkpoint file. If no `tensor_name` is provided, prints the tensor names and shapes @@ -41,14 +42,16 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): file_name: Name of the checkpoint file. tensor_name: Name of the tensor in the checkpoint file to print. all_tensors: Boolean indicating whether to print all tensors. + all_tensor_names: Boolean indicating whether to print all tensor names. """ try: reader = pywrap_tensorflow.NewCheckpointReader(file_name) - if all_tensors: + if all_tensors or all_tensor_names: var_to_shape_map = reader.get_variable_to_shape_map() for key in sorted(var_to_shape_map): print("tensor_name: ", key) - print(reader.get_tensor(key)) + if all_tensors: + print(reader.get_tensor(key)) elif not tensor_name: print(reader.debug_string().decode("utf-8")) else: @@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str): def main(unused_argv): if not FLAGS.file_name: print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " - "[--tensor_name=tensor_to_print]") + "[--tensor_name=tensor_to_print] " + "[--all_tensors] " + "[--all_tensor_names] " + "[--printoptions]") sys.exit(1) else: print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name, - FLAGS.all_tensors) + FLAGS.all_tensors, FLAGS.all_tensor_names) if __name__ == "__main__": @@ -131,6 +137,13 @@ if __name__ == "__main__": default=False, help="If True, print the values of all the tensors.") parser.add_argument( + "--all_tensor_names", + nargs="?", + const=True, + type="bool", + default=False, + help="If True, print the names of all the tensors.") + parser.add_argument( "--printoptions", nargs="*", type=parse_numpy_printoption, |