aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/checkpoint_reader.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-05 06:57:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-05 07:01:19 -0700
commita8c5d5fe011e796593d20c74d8b927c014a27c89 (patch)
tree54d20a415ac25c9d5d8d8f909133b6db2b7a9289 /tensorflow/c/checkpoint_reader.h
parent220515bffdf1df5379a7f8921f5a12deb2e0dee7 (diff)
Expose data type information in checkpoint reader.
PiperOrigin-RevId: 171147196
Diffstat (limited to 'tensorflow/c/checkpoint_reader.h')
-rw-r--r--tensorflow/c/checkpoint_reader.h17
1 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h
index 470c8d1e10..4de1300a7f 100644
--- a/tensorflow/c/checkpoint_reader.h
+++ b/tensorflow/c/checkpoint_reader.h
@@ -44,10 +44,14 @@ class CheckpointReader {
bool HasTensor(const string& name) const;
const string DebugString() const;
- // Returns a map from variable names to its shape. Slices of a partitioned
+ // Returns a map from variable names to their shapes. Slices of a partitioned
// tensor are combined into a single entry.
const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const;
+ // Returns a map from variable names to their data types. Slices of a
+ // partitioned tensor are combined into a single entry.
+ const TensorSliceReader::VarToDataTypeMap& GetVariableToDataTypeMap() const;
+
// Attempts to look up the tensor named "name" and stores the found result in
// "out_tensor".
void GetTensor(const string& name,
@@ -55,14 +59,19 @@ class CheckpointReader {
TF_Status* out_status) const;
private:
- // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller.
+ // Uses "v2_reader_" to build "var name -> shape" and "var name -> data type"
+ // maps; both owned by caller.
// REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()".
- std::unique_ptr<TensorSliceReader::VarToShapeMap> BuildV2VarToShapeMap();
+ std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
+ std::unique_ptr<TensorSliceReader::VarToDataTypeMap> >
+ BuildV2VarMaps();
// Invariant: exactly one of "reader_" and "v2_reader_" is non-null.
std::unique_ptr<TensorSliceReader> reader_;
std::unique_ptr<BundleReader> v2_reader_;
- std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map_ptr_;
+
+ std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map_;
+ std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map_;
TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader);
};