diff options
author | Zongheng Yang <zongheng@google.com> | 2016-10-14 21:10:19 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-14 22:18:59 -0700 |
commit | 1e324a0f2a67cfa651677bd381bf1bf2adc3e2f8 (patch) | |
tree | 28e83ed5affb60d00426540458d930e79b5e67dc /tensorflow/c/checkpoint_reader.cc | |
parent | c53f1a9930b85110bb3164695f853722ce293d29 (diff) |
checkpoint_reader: fix VarToShapeMap V2 impl.
This change makes it adhere to the original semantics: all slices of a
partitioned tensor are grouped under one entry.
Change: 136229541
Diffstat (limited to 'tensorflow/c/checkpoint_reader.cc')
-rw-r--r-- | tensorflow/c/checkpoint_reader.cc | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 2ac3f75e5b..17b3f93193 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -14,10 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/checkpoint_reader.h" + +#include <unordered_set> + #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { @@ -98,15 +102,34 @@ void CheckpointReader::GetTensor( TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { CHECK(v2_reader_ != nullptr); CHECK(v2_reader_->status().ok()); + + // First pass: filters out the entries of the slices. + std::unordered_set<string> filtered_keys; + BundleEntryProto entry; v2_reader_->Seek(kHeaderEntryKey); + for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { + CHECK(entry.ParseFromArray(v2_reader_->value().data(), + v2_reader_->value().size())) + << entry.InitializationErrorString(); + for (int i = 0; i < entry.slices_size(); ++i) { + const auto& slice_proto = entry.slices(i); + CHECK(filtered_keys + .insert(EncodeTensorNameSlice( + v2_reader_->key().ToString() /* full var's name */, + TensorSlice(slice_proto))) + .second); + } + } + // Second pass: adds the entries, ignoring the filtered keys. TensorSliceReader::VarToShapeMap* var_to_shape_map = new TensorSliceReader::VarToShapeMap; - BundleEntryProto entry; + v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { + if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), - v2_reader_->value().size())); - if (entry.slices_size() > 0) continue; // Slice of some partitioned var. + v2_reader_->value().size())) + << entry.InitializationErrorString(); (*var_to_shape_map)[v2_reader_->key().ToString()] = TensorShape(entry.shape()); } |