diff options
author | Zongheng Yang <zongheng@google.com> | 2016-09-28 08:26:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-28 09:35:36 -0700 |
commit | ecdf0b202a2bfcff7985e62da727397bd8c67a91 (patch) | |
tree | b81d13a2f910cb0f34794b2e46ba08f6415124e1 /tensorflow/c/checkpoint_reader.h | |
parent | 63a7a30e6bd091f87be1de2305c6d882d68ba6a8 (diff) |
TF Checkpoint V2: make CheckpointReader work with the V2 format.
If the same checkpoint prefix identifies both a V1 checkpoint and a V2
checkpoint on disk, the V2 version takes priority -- which matches the same
behavior as the RestoreV2 op.
Typical usage:
$ bazel run tensorflow/python/tools:inspect_checkpoint -- --file_name=<V2 ckpt prefix>
Other changes: add DebugString() and Contains() to BundleReader.
Change: 134543092
Diffstat (limited to 'tensorflow/c/checkpoint_reader.h')
-rw-r--r-- | tensorflow/c/checkpoint_reader.h | 31 |
1 files changed, 19 insertions, 12 deletions
diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index fb06d6d864..6d27bdc375 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" #include "tensorflow/core/util/tensor_slice_reader.h" namespace tensorflow { @@ -28,10 +29,15 @@ namespace checkpoint { class TensorSliceReader; -// A wrapper around checkpoint::TensorSliceReader that is more easily SWIG -// wrapped for other languages. +// A wrapper around BundleReader (for V2 checkpoints) and +// checkpoint::TensorSliceReader (for V1), that is more easily SWIG wrapped for +// other languages. +// +// The class currently only interacts with single-slice (i.e., non-partitioned) +// variables. class CheckpointReader { public: + CheckpointReader(const string& filepattern, TF_Status* out_status); ~CheckpointReader(); bool HasTensor(const string& name) const; @@ -39,20 +45,21 @@ class CheckpointReader { const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const; + // Attempts to look up the tensor named "name" and stores the found result in + // "out_tensor". void GetTensor(const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor, - TF_Status* out_status) const { - Status status = reader_->GetTensor(name, out_tensor); - if (!status.ok()) { - Set_TF_Status_from_Status(out_status, status); - } - } - - CheckpointReader(const string& filepattern, TF_Status* out_status); + TF_Status* out_status) const; private: - TensorSliceReader* reader_; // Owned - TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned + // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller. + // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()". + TensorSliceReader::VarToShapeMap* BuildV2VarToShapeMap(); + + // Invariant: exactly one of "reader_" and "v2_reader_" is non-nullptr. + TensorSliceReader* reader_; // Owned. + BundleReader* v2_reader_; // Owned. + TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned. TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader); }; |