diff options
Diffstat (limited to 'tensorflow/python/tools/inspect_checkpoint.py')
-rw-r--r-- | tensorflow/python/tools/inspect_checkpoint.py | 23 |
1 files changed, 5 insertions, 18 deletions
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index 8716058e61..47a74e5abf 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -29,8 +29,7 @@ from tensorflow.python.platform import flags FLAGS = None -def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, - all_tensor_names): +def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): """Prints tensors in a checkpoint file. If no `tensor_name` is provided, prints the tensor names and shapes @@ -42,16 +41,14 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, file_name: Name of the checkpoint file. tensor_name: Name of the tensor in the checkpoint file to print. all_tensors: Boolean indicating whether to print all tensors. - all_tensor_names: Boolean indicating whether to print all tensor names. """ try: reader = pywrap_tensorflow.NewCheckpointReader(file_name) - if all_tensors or all_tensor_names: + if all_tensors: var_to_shape_map = reader.get_variable_to_shape_map() for key in sorted(var_to_shape_map): print("tensor_name: ", key) - if all_tensors: - print(reader.get_tensor(key)) + print(reader.get_tensor(key)) elif not tensor_name: print(reader.debug_string().decode("utf-8")) else: @@ -107,14 +104,11 @@ def parse_numpy_printoption(kv_str): def main(unused_argv): if not FLAGS.file_name: print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " - "[--tensor_name=tensor_to_print] " - "[--all_tensors] " - "[--all_tensor_names] " - "[--printoptions]") + "[--tensor_name=tensor_to_print]") sys.exit(1) else: print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name, - FLAGS.all_tensors, FLAGS.all_tensor_names) + FLAGS.all_tensors) if __name__ == "__main__": @@ -137,13 +131,6 @@ if __name__ == "__main__": default=False, help="If True, print the values of all the tensors.") parser.add_argument( - "--all_tensor_names", - nargs="?", - const=True, - type="bool", - default=False, - help="If True, print the names of all the tensors.") - parser.add_argument( "--printoptions", nargs="*", type=parse_numpy_printoption, |