diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-05 06:57:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-05 07:01:19 -0700 |
commit | a8c5d5fe011e796593d20c74d8b927c014a27c89 (patch) | |
tree | 54d20a415ac25c9d5d8d8f909133b6db2b7a9289 /tensorflow/c/checkpoint_reader.h | |
parent | 220515bffdf1df5379a7f8921f5a12deb2e0dee7 (diff) |
Expose data type information in checkpoint reader.
PiperOrigin-RevId: 171147196
Diffstat (limited to 'tensorflow/c/checkpoint_reader.h')
-rw-r--r-- | tensorflow/c/checkpoint_reader.h | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 470c8d1e10..4de1300a7f 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -44,10 +44,14 @@ class CheckpointReader { bool HasTensor(const string& name) const; const string DebugString() const; - // Returns a map from variable names to its shape. Slices of a partitioned + // Returns a map from variable names to their shapes. Slices of a partitioned // tensor are combined into a single entry. const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const; + // Returns a map from variable names to their data types. Slices of a + // partitioned tensor are combined into a single entry. + const TensorSliceReader::VarToDataTypeMap& GetVariableToDataTypeMap() const; + // Attempts to look up the tensor named "name" and stores the found result in // "out_tensor". void GetTensor(const string& name, @@ -55,14 +59,19 @@ class CheckpointReader { TF_Status* out_status) const; private: - // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller. + // Uses "v2_reader_" to build "var name -> shape" and "var name -> data type" + // maps; both owned by caller. // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()". - std::unique_ptr<TensorSliceReader::VarToShapeMap> BuildV2VarToShapeMap(); + std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>, + std::unique_ptr<TensorSliceReader::VarToDataTypeMap> > + BuildV2VarMaps(); // Invariant: exactly one of "reader_" and "v2_reader_" is non-null. std::unique_ptr<TensorSliceReader> reader_; std::unique_ptr<BundleReader> v2_reader_; - std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map_ptr_; + + std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map_; + std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map_; TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader); }; |