aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_slice_reader.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-05 06:57:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-05 07:01:19 -0700
commita8c5d5fe011e796593d20c74d8b927c014a27c89 (patch)
tree54d20a415ac25c9d5d8d8f909133b6db2b7a9289 /tensorflow/core/util/tensor_slice_reader.h
parent220515bffdf1df5379a7f8921f5a12deb2e0dee7 (diff)
Expose data type information in checkpoint reader.
PiperOrigin-RevId: 171147196
Diffstat (limited to 'tensorflow/core/util/tensor_slice_reader.h')
-rw-r--r--tensorflow/core/util/tensor_slice_reader.h5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h
index 5932d59a15..4bb2b24615 100644
--- a/tensorflow/core/util/tensor_slice_reader.h
+++ b/tensorflow/core/util/tensor_slice_reader.h
@@ -103,9 +103,14 @@ class TensorSliceReader {
std::unique_ptr<tensorflow::Tensor>* out_tensor) const;
typedef std::unordered_map<string, TensorShape> VarToShapeMap;
+ typedef std::unordered_map<string, DataType> VarToDataTypeMap;
+
// Returns a map from tensor name to shape.
VarToShapeMap GetVariableToShapeMap() const;
+ // Returns a map from tensor name to data type.
+ VarToDataTypeMap GetVariableToDataTypeMap() const;
+
// Returns a string containing names and shapes of all the tensors.
const string DebugString() const;