aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_slice_reader.h
diff options
context:
space:
mode:
authorGravatar Sherry Moore <sherrym@google.com>2016-04-01 09:40:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-01 10:42:50 -0700
commit8baea485d6c3fafff8fd2d14cb0367318efcbd14 (patch)
tree4530533fdb22aa5255edf4832e92773611786771 /tensorflow/core/util/tensor_slice_reader.h
parent72debff6d683412dd564123ec3095fbb33b1cfca (diff)
Added GetTensor() to return tensor content for a single variable from
checkpoint file. To use: reader = tf.train.NewPyCheckpointReader(your-checkpoint-path) print(reader.GetTensor(your-tensor-name)) Change: 118792812
Diffstat (limited to 'tensorflow/core/util/tensor_slice_reader.h')
-rw-r--r--tensorflow/core/util/tensor_slice_reader.h6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h
index 5dd0b2f919..7db75223e1 100644
--- a/tensorflow/core/util/tensor_slice_reader.h
+++ b/tensorflow/core/util/tensor_slice_reader.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -96,6 +97,11 @@ class TensorSliceReader {
return tensors_;
}
+ // Returns value for one tensor. Only single slice checkpoints are supported
+ // at the moment.
+ Status GetTensor(const string& name,
+ std::unique_ptr<tensorflow::Tensor>* out_tensor) const;
+
typedef std::unordered_map<string, TensorShape> VarToShapeMap;
// Returns a map from tensor name to shape.
VarToShapeMap GetVariableToShapeMap() const;