diff options
author | Zongheng Yang <zongheng@google.com> | 2016-09-28 08:26:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-28 09:35:36 -0700 |
commit | ecdf0b202a2bfcff7985e62da727397bd8c67a91 (patch) | |
tree | b81d13a2f910cb0f34794b2e46ba08f6415124e1 /tensorflow/c/checkpoint_reader.cc | |
parent | 63a7a30e6bd091f87be1de2305c6d882d68ba6a8 (diff) |
TF Checkpoint V2: make CheckpointReader work with the V2 format.
If the same checkpoint prefix identifies both a V1 checkpoint and a V2
checkpoint on disk, the V2 version takes priority -- which matches the same
behavior as the RestoreV2 op.
Typical usage:
$ bazel run tensorflow/python/tools:inspect_checkpoint -- --file_name=<V2 ckpt prefix>
Other changes: add DebugString() and Contains() to BundleReader.
Change: 134543092
Diffstat (limited to 'tensorflow/c/checkpoint_reader.cc')
-rw-r--r-- | tensorflow/c/checkpoint_reader.cc | 64 |
1 files changed, 58 insertions, 6 deletions
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index dd9cb22559..23a42cffbb 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -27,11 +27,25 @@ class TensorSliceReader; CheckpointReader::CheckpointReader(const string& filename, TF_Status* out_status) - : reader_(nullptr), var_to_shape_map_ptr_(nullptr) { - reader_ = new TensorSliceReader(filename); - if (!reader_->status().ok()) { - Set_TF_Status_from_Status(out_status, reader_->status()); + : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) { + // Depending on whether this is a V2 ckpt, initializes "reader_" or + // "v2_reader_". + std::vector<string> v2_path; + if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() && + !v2_path.empty()) { + v2_reader_ = + new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */); + if (!v2_reader_->status().ok()) { + Set_TF_Status_from_Status(out_status, v2_reader_->status()); + return; + } + var_to_shape_map_ptr_ = BuildV2VarToShapeMap(); } else { + reader_ = new TensorSliceReader(filename); + if (!reader_->status().ok()) { + Set_TF_Status_from_Status(out_status, reader_->status()); + return; + } var_to_shape_map_ptr_ = new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()); } @@ -43,7 +57,10 @@ CheckpointReader::~CheckpointReader() { } bool CheckpointReader::HasTensor(const string& name) const { - return reader_->HasTensor(name, nullptr, nullptr); + if (reader_ != nullptr) { + return reader_->HasTensor(name, nullptr, nullptr); + } + return v2_reader_->Contains(name); } const TensorSliceReader::VarToShapeMap& @@ -53,7 +70,42 @@ CheckpointReader::GetVariableToShapeMap() const { } const string CheckpointReader::DebugString() const { - return reader_->DebugString(); + if (reader_ != nullptr) return reader_->DebugString(); + return v2_reader_->DebugString(); +} + +void CheckpointReader::GetTensor( + const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor, + TF_Status* out_status) const { + Status status; + if (reader_ != nullptr) { + status = reader_->GetTensor(name, out_tensor); + } else { + std::unique_ptr<Tensor> tensor(new Tensor); + status = v2_reader_->Lookup(name, tensor.get()); + if (status.ok()) std::swap(*out_tensor, tensor); + } + if (!status.ok()) { + Set_TF_Status_from_Status(out_status, status); + } +} + +TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { + CHECK(v2_reader_ != nullptr); + CHECK(v2_reader_->status().ok()); + v2_reader_->Seek(kHeaderEntryKey); + + TensorSliceReader::VarToShapeMap* var_to_shape_map = + new TensorSliceReader::VarToShapeMap; + BundleEntryProto entry; + for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { + CHECK(entry.ParseFromArray(v2_reader_->value().data(), + v2_reader_->value().size())); + if (entry.slices_size() > 0) continue; // Slice of some partitioned var. + (*var_to_shape_map)[v2_reader_->key().ToString()] = + TensorShape(entry.shape()); + } + return var_to_shape_map; // Owned by caller. } } // namespace checkpoint |