aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/checkpoint_reader.cc
diff options
context:
space:
mode:
authorGravatar Zongheng Yang <zongheng@google.com>2016-10-14 21:10:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 22:18:59 -0700
commit1e324a0f2a67cfa651677bd381bf1bf2adc3e2f8 (patch)
tree28e83ed5affb60d00426540458d930e79b5e67dc /tensorflow/c/checkpoint_reader.cc
parentc53f1a9930b85110bb3164695f853722ce293d29 (diff)
checkpoint_reader: fix VarToShapeMap V2 impl.
This change makes it adhere to the original semantics: all slices of a partitioned tensor are grouped under one entry. Change: 136229541
Diffstat (limited to 'tensorflow/c/checkpoint_reader.cc')
-rw-r--r--tensorflow/c/checkpoint_reader.cc29
1 files changed, 26 insertions, 3 deletions
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc
index 2ac3f75e5b..17b3f93193 100644
--- a/tensorflow/c/checkpoint_reader.cc
+++ b/tensorflow/c/checkpoint_reader.cc
@@ -14,10 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/checkpoint_reader.h"
+
+#include <unordered_set>
+
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
namespace tensorflow {
@@ -98,15 +102,34 @@ void CheckpointReader::GetTensor(
TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() {
CHECK(v2_reader_ != nullptr);
CHECK(v2_reader_->status().ok());
+
+ // First pass: filters out the entries of the slices.
+ std::unordered_set<string> filtered_keys;
+ BundleEntryProto entry;
v2_reader_->Seek(kHeaderEntryKey);
+ for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
+ CHECK(entry.ParseFromArray(v2_reader_->value().data(),
+ v2_reader_->value().size()))
+ << entry.InitializationErrorString();
+ for (int i = 0; i < entry.slices_size(); ++i) {
+ const auto& slice_proto = entry.slices(i);
+ CHECK(filtered_keys
+ .insert(EncodeTensorNameSlice(
+ v2_reader_->key().ToString() /* full var's name */,
+ TensorSlice(slice_proto)))
+ .second);
+ }
+ }
+ // Second pass: adds the entries, ignoring the filtered keys.
TensorSliceReader::VarToShapeMap* var_to_shape_map =
new TensorSliceReader::VarToShapeMap;
- BundleEntryProto entry;
+ v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
+ if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
- v2_reader_->value().size()));
- if (entry.slices_size() > 0) continue; // Slice of some partitioned var.
+ v2_reader_->value().size()))
+ << entry.InitializationErrorString();
(*var_to_shape_map)[v2_reader_->key().ToString()] =
TensorShape(entry.shape());
}