aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/checkpoint_reader.h
diff options
context:
space:
mode:
authorGravatar Zongheng Yang <zongheng@google.com>2016-09-28 08:26:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-28 09:35:36 -0700
commitecdf0b202a2bfcff7985e62da727397bd8c67a91 (patch)
treeb81d13a2f910cb0f34794b2e46ba08f6415124e1 /tensorflow/c/checkpoint_reader.h
parent63a7a30e6bd091f87be1de2305c6d882d68ba6a8 (diff)
TF Checkpoint V2: make CheckpointReader work with the V2 format.
If the same checkpoint prefix identifies both a V1 checkpoint and a V2 checkpoint on disk, the V2 version takes priority -- which matches the same behavior as the RestoreV2 op. Typical usage: $ bazel run tensorflow/python/tools:inspect_checkpoint -- --file_name=<V2 ckpt prefix> Other changes: add DebugString() and Contains() to BundleReader. Change: 134543092
Diffstat (limited to 'tensorflow/c/checkpoint_reader.h')
-rw-r--r--tensorflow/c/checkpoint_reader.h31
1 files changed, 19 insertions, 12 deletions
diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h
index fb06d6d864..6d27bdc375 100644
--- a/tensorflow/c/checkpoint_reader.h
+++ b/tensorflow/c/checkpoint_reader.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
namespace tensorflow {
@@ -28,10 +29,15 @@ namespace checkpoint {
class TensorSliceReader;
-// A wrapper around checkpoint::TensorSliceReader that is more easily SWIG
-// wrapped for other languages.
+// A wrapper around BundleReader (for V2 checkpoints) and
+// checkpoint::TensorSliceReader (for V1), that is more easily SWIG wrapped for
+// other languages.
+//
+// The class currently only interacts with single-slice (i.e., non-partitioned)
+// variables.
class CheckpointReader {
public:
+ CheckpointReader(const string& filepattern, TF_Status* out_status);
~CheckpointReader();
bool HasTensor(const string& name) const;
@@ -39,20 +45,21 @@ class CheckpointReader {
const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const;
+ // Attempts to look up the tensor named "name" and stores the found result in
+ // "out_tensor".
void GetTensor(const string& name,
std::unique_ptr<tensorflow::Tensor>* out_tensor,
- TF_Status* out_status) const {
- Status status = reader_->GetTensor(name, out_tensor);
- if (!status.ok()) {
- Set_TF_Status_from_Status(out_status, status);
- }
- }
-
- CheckpointReader(const string& filepattern, TF_Status* out_status);
+ TF_Status* out_status) const;
private:
- TensorSliceReader* reader_; // Owned
- TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned
+ // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller.
+ // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()".
+ TensorSliceReader::VarToShapeMap* BuildV2VarToShapeMap();
+
+ // Invariant: exactly one of "reader_" and "v2_reader_" is non-nullptr.
+ TensorSliceReader* reader_; // Owned.
+ BundleReader* v2_reader_; // Owned.
+ TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned.
TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader);
};