diff options
author | Sherry Moore <sherrym@google.com> | 2016-04-01 09:40:26 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-01 10:42:50 -0700 |
commit | 8baea485d6c3fafff8fd2d14cb0367318efcbd14 (patch) | |
tree | 4530533fdb22aa5255edf4832e92773611786771 /tensorflow/core/util/tensor_slice_reader.h | |
parent | 72debff6d683412dd564123ec3095fbb33b1cfca (diff) |
Added GetTensor() to return tensor content for a single variable from
checkpoint file. To use:
reader = tf.train.NewPyCheckpointReader(your-checkpoint-path)
print(reader.GetTensor(your-tensor-name))
Change: 118792812
Diffstat (limited to 'tensorflow/core/util/tensor_slice_reader.h')
-rw-r--r-- | tensorflow/core/util/tensor_slice_reader.h | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h index 5dd0b2f919..7db75223e1 100644 --- a/tensorflow/core/util/tensor_slice_reader.h +++ b/tensorflow/core/util/tensor_slice_reader.h @@ -23,6 +23,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/framework/types.pb.h" @@ -96,6 +97,11 @@ class TensorSliceReader { return tensors_; } + // Returns value for one tensor. Only single slice checkpoints are supported + // at the moment. + Status GetTensor(const string& name, + std::unique_ptr<tensorflow::Tensor>* out_tensor) const; + typedef std::unordered_map<string, TensorShape> VarToShapeMap; // Returns a map from tensor name to shape. VarToShapeMap GetVariableToShapeMap() const; |