aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/checkpoint_reader.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-05 06:57:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-05 07:01:19 -0700
commita8c5d5fe011e796593d20c74d8b927c014a27c89 (patch)
tree54d20a415ac25c9d5d8d8f909133b6db2b7a9289 /tensorflow/c/checkpoint_reader.cc
parent220515bffdf1df5379a7f8921f5a12deb2e0dee7 (diff)
Expose data type information in checkpoint reader.
PiperOrigin-RevId: 171147196
Diffstat (limited to 'tensorflow/c/checkpoint_reader.cc')
-rw-r--r--tensorflow/c/checkpoint_reader.cc40
1 files changed, 30 insertions, 10 deletions
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc
index fc86e92f3b..b1f7bdaa54 100644
--- a/tensorflow/c/checkpoint_reader.cc
+++ b/tensorflow/c/checkpoint_reader.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/checkpoint_reader.h"
#include <unordered_set>
+#include <utility>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -30,7 +31,10 @@ class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename,
TF_Status* out_status)
- : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) {
+ : reader_(nullptr),
+ v2_reader_(nullptr),
+ var_to_shape_map_(nullptr),
+ var_to_data_type_map_(nullptr) {
// Depending on whether this is a V2 ckpt, initializes "reader_" or
// "v2_reader_".
std::vector<string> v2_path;
@@ -42,15 +46,19 @@ CheckpointReader::CheckpointReader(const string& filename,
Set_TF_Status_from_Status(out_status, v2_reader_->status());
return;
}
- var_to_shape_map_ptr_ = BuildV2VarToShapeMap();
+ auto result = BuildV2VarMaps();
+ var_to_shape_map_.swap(result.first);
+ var_to_data_type_map_.swap(result.second);
} else {
reader_.reset(new TensorSliceReader(filename));
if (!reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, reader_->status());
return;
}
- var_to_shape_map_ptr_.reset(
+ var_to_shape_map_.reset(
new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
+ var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
+ reader_->GetVariableToDataTypeMap()));
}
}
@@ -63,8 +71,14 @@ bool CheckpointReader::HasTensor(const string& name) const {
const TensorSliceReader::VarToShapeMap&
CheckpointReader::GetVariableToShapeMap() const {
- CHECK(var_to_shape_map_ptr_);
- return *var_to_shape_map_ptr_;
+ CHECK(var_to_shape_map_);
+ return *var_to_shape_map_;
+}
+
+const TensorSliceReader::VarToDataTypeMap&
+CheckpointReader::GetVariableToDataTypeMap() const {
+ CHECK(var_to_data_type_map_);
+ return *var_to_data_type_map_;
}
const string CheckpointReader::DebugString() const {
@@ -93,8 +107,9 @@ void CheckpointReader::GetTensor(
}
}
-std::unique_ptr<TensorSliceReader::VarToShapeMap>
-CheckpointReader::BuildV2VarToShapeMap() {
+std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
+ std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
+CheckpointReader::BuildV2VarMaps() {
CHECK(v2_reader_ != nullptr);
CHECK(v2_reader_->status().ok());
@@ -119,16 +134,21 @@ CheckpointReader::BuildV2VarToShapeMap() {
// Second pass: adds the entries, ignoring the filtered keys.
std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
new TensorSliceReader::VarToShapeMap);
+ std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
+ new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
- (*var_to_shape_map)[v2_reader_->key().ToString()] =
- TensorShape(entry.shape());
+ string key = v2_reader_->key().ToString();
+ (*var_to_shape_map)[key] = TensorShape(entry.shape());
+ (*var_to_data_type_map)[key] = DataType(entry.dtype());
}
- return var_to_shape_map;
+ // The returned pointers are owned by the caller.
+ return std::make_pair(std::move(var_to_shape_map),
+ std::move(var_to_data_type_map));
}
} // namespace checkpoint