diff options
author | 2016-02-11 20:55:33 -0800 | |
---|---|---|
committer | 2016-02-12 08:03:36 -0800 | |
commit | 6b884f0bc592b146f64b408a684937b62bc6b6d3 (patch) | |
tree | a1eea57beac973833b28efc9c74554b003079fad /tensorflow/contrib/util | |
parent | 01806e89e5dd141ae456bfee2ca0e799fb1aa32d (diff) |
Added tool to inspect checkpoints.
Change: 114506562
Diffstat (limited to 'tensorflow/contrib/util')
-rw-r--r-- | tensorflow/contrib/util/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/contrib/util/inspect_checkpoint.cc | 50 |
2 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index 6e517e117b..c0be2b9c14 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -7,6 +7,16 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +cc_binary( + name = "inspect_checkpoint", + srcs = ["inspect_checkpoint.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow", + ], +) + py_library( name = "util_py", srcs = glob(["**/*.py"]), diff --git a/tensorflow/contrib/util/inspect_checkpoint.cc b/tensorflow/contrib/util/inspect_checkpoint.cc new file mode 100644 index 0000000000..001f39a30f --- /dev/null +++ b/tensorflow/contrib/util/inspect_checkpoint.cc @@ -0,0 +1,50 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/tensor_slice_reader.h" + +namespace tensorflow { +namespace { + +int InspectCheckpoint(const string& in) { + tensorflow::checkpoint::TensorSliceReader reader( + in, tensorflow::checkpoint::OpenTableTensorSliceReader); + Status s = reader.status(); + if (!s.ok()) { + fprintf(stderr, "Unable to open the checkpoint file\n"); + return -1; + } + for (auto e : reader.Tensors()) { + fprintf(stdout, "%s %s\n", e.first.c_str(), + e.second->shape().DebugString().c_str()); + } + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc != 2) { + fprintf(stderr, "Usage: %s checkpoint_file\n", argv[0]); + exit(1); + } + return tensorflow::InspectCheckpoint(argv[1]); +} |