diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-05 06:57:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-05 07:01:19 -0700 |
commit | a8c5d5fe011e796593d20c74d8b927c014a27c89 (patch) | |
tree | 54d20a415ac25c9d5d8d8f909133b6db2b7a9289 /tensorflow/c/checkpoint_reader.cc | |
parent | 220515bffdf1df5379a7f8921f5a12deb2e0dee7 (diff) |
Expose data type information in checkpoint reader.
PiperOrigin-RevId: 171147196
Diffstat (limited to 'tensorflow/c/checkpoint_reader.cc')
-rw-r--r-- | tensorflow/c/checkpoint_reader.cc | 40 |
1 files changed, 30 insertions, 10 deletions
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index fc86e92f3b..b1f7bdaa54 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/checkpoint_reader.h" #include <unordered_set> +#include <utility> #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,7 +31,10 @@ class TensorSliceReader; CheckpointReader::CheckpointReader(const string& filename, TF_Status* out_status) - : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) { + : reader_(nullptr), + v2_reader_(nullptr), + var_to_shape_map_(nullptr), + var_to_data_type_map_(nullptr) { // Depending on whether this is a V2 ckpt, initializes "reader_" or // "v2_reader_". std::vector<string> v2_path; @@ -42,15 +46,19 @@ CheckpointReader::CheckpointReader(const string& filename, Set_TF_Status_from_Status(out_status, v2_reader_->status()); return; } - var_to_shape_map_ptr_ = BuildV2VarToShapeMap(); + auto result = BuildV2VarMaps(); + var_to_shape_map_.swap(result.first); + var_to_data_type_map_.swap(result.second); } else { reader_.reset(new TensorSliceReader(filename)); if (!reader_->status().ok()) { Set_TF_Status_from_Status(out_status, reader_->status()); return; } - var_to_shape_map_ptr_.reset( + var_to_shape_map_.reset( new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap())); + var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap( + reader_->GetVariableToDataTypeMap())); } } @@ -63,8 +71,14 @@ bool CheckpointReader::HasTensor(const string& name) const { const TensorSliceReader::VarToShapeMap& CheckpointReader::GetVariableToShapeMap() const { - CHECK(var_to_shape_map_ptr_); - return *var_to_shape_map_ptr_; + CHECK(var_to_shape_map_); + return *var_to_shape_map_; +} + +const TensorSliceReader::VarToDataTypeMap& +CheckpointReader::GetVariableToDataTypeMap() const { + CHECK(var_to_data_type_map_); + return *var_to_data_type_map_; } const string CheckpointReader::DebugString() const { @@ -93,8 +107,9 @@ void CheckpointReader::GetTensor( } } -std::unique_ptr<TensorSliceReader::VarToShapeMap> -CheckpointReader::BuildV2VarToShapeMap() { +std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>, + std::unique_ptr<TensorSliceReader::VarToDataTypeMap>> +CheckpointReader::BuildV2VarMaps() { CHECK(v2_reader_ != nullptr); CHECK(v2_reader_->status().ok()); @@ -119,16 +134,21 @@ CheckpointReader::BuildV2VarToShapeMap() { // Second pass: adds the entries, ignoring the filtered keys. std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map( new TensorSliceReader::VarToShapeMap); + std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map( + new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - (*var_to_shape_map)[v2_reader_->key().ToString()] = - TensorShape(entry.shape()); + string key = v2_reader_->key().ToString(); + (*var_to_shape_map)[key] = TensorShape(entry.shape()); + (*var_to_data_type_map)[key] = DataType(entry.dtype()); } - return var_to_shape_map; + // The returned pointers are owned by the caller. + return std::make_pair(std::move(var_to_shape_map), + std::move(var_to_data_type_map)); } } // namespace checkpoint |