aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/checkpoint_reader.cc
diff options
context:
space:
mode:
authorGravatar Zongheng Yang <zongheng@google.com>2016-09-28 08:26:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-28 09:35:36 -0700
commitecdf0b202a2bfcff7985e62da727397bd8c67a91 (patch)
treeb81d13a2f910cb0f34794b2e46ba08f6415124e1 /tensorflow/c/checkpoint_reader.cc
parent63a7a30e6bd091f87be1de2305c6d882d68ba6a8 (diff)
TF Checkpoint V2: make CheckpointReader work with the V2 format.
If the same checkpoint prefix identifies both a V1 checkpoint and a V2 checkpoint on disk, the V2 version takes priority -- which matches the same behavior as the RestoreV2 op. Typical usage: $ bazel run tensorflow/python/tools:inspect_checkpoint -- --file_name=<V2 ckpt prefix> Other changes: add DebugString() and Contains() to BundleReader. Change: 134543092
Diffstat (limited to 'tensorflow/c/checkpoint_reader.cc')
-rw-r--r--tensorflow/c/checkpoint_reader.cc64
1 files changed, 58 insertions, 6 deletions
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc
index dd9cb22559..23a42cffbb 100644
--- a/tensorflow/c/checkpoint_reader.cc
+++ b/tensorflow/c/checkpoint_reader.cc
@@ -27,11 +27,25 @@ class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename,
TF_Status* out_status)
- : reader_(nullptr), var_to_shape_map_ptr_(nullptr) {
- reader_ = new TensorSliceReader(filename);
- if (!reader_->status().ok()) {
- Set_TF_Status_from_Status(out_status, reader_->status());
+ : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) {
+ // Depending on whether this is a V2 ckpt, initializes "reader_" or
+ // "v2_reader_".
+ std::vector<string> v2_path;
+ if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
+ !v2_path.empty()) {
+ v2_reader_ =
+ new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */);
+ if (!v2_reader_->status().ok()) {
+ Set_TF_Status_from_Status(out_status, v2_reader_->status());
+ return;
+ }
+ var_to_shape_map_ptr_ = BuildV2VarToShapeMap();
} else {
+ reader_ = new TensorSliceReader(filename);
+ if (!reader_->status().ok()) {
+ Set_TF_Status_from_Status(out_status, reader_->status());
+ return;
+ }
var_to_shape_map_ptr_ =
new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap());
}
@@ -43,7 +57,10 @@ CheckpointReader::~CheckpointReader() {
}
bool CheckpointReader::HasTensor(const string& name) const {
- return reader_->HasTensor(name, nullptr, nullptr);
+ if (reader_ != nullptr) {
+ return reader_->HasTensor(name, nullptr, nullptr);
+ }
+ return v2_reader_->Contains(name);
}
const TensorSliceReader::VarToShapeMap&
@@ -53,7 +70,42 @@ CheckpointReader::GetVariableToShapeMap() const {
}
const string CheckpointReader::DebugString() const {
- return reader_->DebugString();
+ if (reader_ != nullptr) return reader_->DebugString();
+ return v2_reader_->DebugString();
+}
+
+void CheckpointReader::GetTensor(
+ const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
+ TF_Status* out_status) const {
+ Status status;
+ if (reader_ != nullptr) {
+ status = reader_->GetTensor(name, out_tensor);
+ } else {
+ std::unique_ptr<Tensor> tensor(new Tensor);
+ status = v2_reader_->Lookup(name, tensor.get());
+ if (status.ok()) std::swap(*out_tensor, tensor);
+ }
+ if (!status.ok()) {
+ Set_TF_Status_from_Status(out_status, status);
+ }
+}
+
+TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() {
+ CHECK(v2_reader_ != nullptr);
+ CHECK(v2_reader_->status().ok());
+ v2_reader_->Seek(kHeaderEntryKey);
+
+ TensorSliceReader::VarToShapeMap* var_to_shape_map =
+ new TensorSliceReader::VarToShapeMap;
+ BundleEntryProto entry;
+ for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
+ CHECK(entry.ParseFromArray(v2_reader_->value().data(),
+ v2_reader_->value().size()));
+ if (entry.slices_size() > 0) continue; // Slice of some partitioned var.
+ (*var_to_shape_map)[v2_reader_->key().ToString()] =
+ TensorShape(entry.shape());
+ }
+ return var_to_shape_map; // Owned by caller.
}
} // namespace checkpoint