diff options
author | 2018-09-28 17:04:41 -0700 | |
---|---|---|
committer | 2018-09-28 17:11:28 -0700 | |
commit | 6d354f6bd686d748d02039f26197f590b817b8c3 (patch) | |
tree | e224fd065c7a92644956b547376118e2bf194954 /tensorflow/core/kernels | |
parent | 3c01aa2b00ee4c3fda412b23da39fd0894c04cf7 (diff) |
[tf.data] Use `std::make_shared` as appropriate in `ParallelMapIterator`.
PiperOrigin-RevId: 215019058
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/data/parallel_map_iterator.cc | 40 |
1 files changed, 19 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 8393024c51..da067a4e6f 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -106,18 +106,17 @@ class ParallelMapIterator : public DatasetBaseIterator { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"), invocation_results_.size())); for (size_t i = 0; i < invocation_results_.size(); i++) { - std::shared_ptr<InvocationResult> result = invocation_results_[i]; - TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + const auto& result = *(invocation_results_[i]); + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status)); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat("invocation_results[", i, "].size")), - result->return_values.size())); - for (size_t j = 0; j < result->return_values.size(); j++) { - TF_RETURN_IF_ERROR( - writer->WriteTensor(full_name(strings::StrCat( - "invocation_results[", i, "][", j, "]")), - result->return_values[j])); + result.return_values.size())); + for (size_t j = 0; j < result.return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("invocation_results[", i, "][", j, "]")), + result.return_values[j])); } - if (result->end_of_input) { + if (result.end_of_input) { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name( strings::StrCat("invocation_results[", i, "].end_of_input")), @@ -135,9 +134,9 @@ class ParallelMapIterator : public DatasetBaseIterator { TF_RETURN_IF_ERROR(reader->ReadScalar( full_name("invocation_results.size"), &invocation_results_size)); for (size_t i = 0; i < invocation_results_size; i++) { - std::shared_ptr<InvocationResult> result(new InvocationResult()); - invocation_results_.push_back(result); - TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + invocation_results_.push_back(std::make_shared<InvocationResult>()); + auto& result = *invocation_results_.back(); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status)); size_t num_return_values; { int64 size; @@ -153,17 +152,16 @@ class ParallelMapIterator : public DatasetBaseIterator { ": ", size, " is not a valid value of type size_t.")); } } - result->return_values.reserve(num_return_values); + result.return_values.reserve(num_return_values); for (size_t j = 0; j < num_return_values; j++) { - result->return_values.emplace_back(); - TF_RETURN_IF_ERROR( - reader->ReadTensor(full_name(strings::StrCat( - "invocation_results[", i, "][", j, "]")), - &result->return_values.back())); + result.return_values.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("invocation_results[", i, "][", j, "]")), + &result.return_values.back())); } - result->end_of_input = reader->Contains(full_name( + result.end_of_input = reader->Contains(full_name( strings::StrCat("invocation_results[", i, "].end_of_input"))); - result->notification.Notify(); + result.notification.Notify(); } return Status::OK(); } @@ -259,7 +257,7 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } while (!busy()) { - invocation_results_.emplace_back(new InvocationResult()); + invocation_results_.push_back(std::make_shared<InvocationResult>()); new_calls.push_back(invocation_results_.back()); num_calls_++; } |