diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-05 03:46:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-05 03:49:51 -0700 |
commit | 220515bffdf1df5379a7f8921f5a12deb2e0dee7 (patch) | |
tree | 7bace952b5ff55d3eaa62ec4153fc28d3d28bc1d /tensorflow | |
parent | f6b15b08bbedc500549b0793b236bc90289d07dc (diff) |
Replace owning raw pointers with unique pointers
PiperOrigin-RevId: 171132628
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/c/checkpoint_reader.cc | 26 | ||||
-rw-r--r-- | tensorflow/c/checkpoint_reader.h | 15 |
2 files changed, 18 insertions, 23 deletions
diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index e7b9bca5b5..fc86e92f3b 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; @@ -37,30 +36,24 @@ CheckpointReader::CheckpointReader(const string& filename, std::vector<string> v2_path; if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() && !v2_path.empty()) { - v2_reader_ = - new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */); + v2_reader_.reset( + new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); if (!v2_reader_->status().ok()) { Set_TF_Status_from_Status(out_status, v2_reader_->status()); return; } var_to_shape_map_ptr_ = BuildV2VarToShapeMap(); } else { - reader_ = new TensorSliceReader(filename); + reader_.reset(new TensorSliceReader(filename)); if (!reader_->status().ok()) { Set_TF_Status_from_Status(out_status, reader_->status()); return; } - var_to_shape_map_ptr_ = - new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()); + var_to_shape_map_ptr_.reset( + new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap())); } } -CheckpointReader::~CheckpointReader() { - delete var_to_shape_map_ptr_; - delete reader_; - delete v2_reader_; -} - bool CheckpointReader::HasTensor(const string& name) const { if (reader_ != nullptr) { return reader_->HasTensor(name, nullptr, nullptr); @@ -100,7 +93,8 @@ void CheckpointReader::GetTensor( } } -TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { +std::unique_ptr<TensorSliceReader::VarToShapeMap> +CheckpointReader::BuildV2VarToShapeMap() { CHECK(v2_reader_ != nullptr); CHECK(v2_reader_->status().ok()); @@ -123,8 +117,8 @@ TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { } // Second pass: adds the entries, ignoring the filtered keys. - TensorSliceReader::VarToShapeMap* var_to_shape_map = - new TensorSliceReader::VarToShapeMap; + std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map( + new TensorSliceReader::VarToShapeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; @@ -134,7 +128,7 @@ TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { (*var_to_shape_map)[v2_reader_->key().ToString()] = TensorShape(entry.shape()); } - return var_to_shape_map; // Owned by caller. + return var_to_shape_map; } } // namespace checkpoint diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 1124416380..470c8d1e10 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_C_CHECKPOINT_READER_H #define TENSORFLOW_C_CHECKPOINT_READER_H +#include <memory> +#include <string> + #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" @@ -24,7 +27,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_slice_reader.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; @@ -38,7 +40,6 @@ class TensorSliceReader; class CheckpointReader { public: CheckpointReader(const string& filepattern, TF_Status* out_status); - ~CheckpointReader(); bool HasTensor(const string& name) const; const string DebugString() const; @@ -56,12 +57,12 @@ class CheckpointReader { private: // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller. // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()". - TensorSliceReader::VarToShapeMap* BuildV2VarToShapeMap(); + std::unique_ptr<TensorSliceReader::VarToShapeMap> BuildV2VarToShapeMap(); - // Invariant: exactly one of "reader_" and "v2_reader_" is non-nullptr. - TensorSliceReader* reader_; // Owned. - BundleReader* v2_reader_; // Owned. - TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned. + // Invariant: exactly one of "reader_" and "v2_reader_" is non-null. + std::unique_ptr<TensorSliceReader> reader_; + std::unique_ptr<BundleReader> v2_reader_; + std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map_ptr_; TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader); }; |