aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-09-28 17:04:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 17:11:28 -0700
commit6d354f6bd686d748d02039f26197f590b817b8c3 (patch)
treee224fd065c7a92644956b547376118e2bf194954 /tensorflow/core/kernels
parent3c01aa2b00ee4c3fda412b23da39fd0894c04cf7 (diff)
[tf.data] Use `std::make_shared` as appropriate in `ParallelMapIterator`.
PiperOrigin-RevId: 215019058
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc40
1 files changed, 19 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 8393024c51..da067a4e6f 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -106,18 +106,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
- std::shared_ptr<InvocationResult> result = invocation_results_[i];
- TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ const auto& result = *(invocation_results_[i]);
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
- result->return_values.size()));
- for (size_t j = 0; j < result->return_values.size(); j++) {
- TF_RETURN_IF_ERROR(
- writer->WriteTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- result->return_values[j]));
+ result.return_values.size()));
+ for (size_t j = 0; j < result.return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result.return_values[j]));
}
- if (result->end_of_input) {
+ if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")),
@@ -135,9 +134,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
- std::shared_ptr<InvocationResult> result(new InvocationResult());
- invocation_results_.push_back(result);
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
+ auto& result = *invocation_results_.back();
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
@@ -153,17 +152,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
": ", size, " is not a valid value of type size_t."));
}
}
- result->return_values.reserve(num_return_values);
+ result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(
- reader->ReadTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ result.return_values.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ &result.return_values.back()));
}
- result->end_of_input = reader->Contains(full_name(
+ result.end_of_input = reader->Contains(full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")));
- result->notification.Notify();
+ result.notification.Notify();
}
return Status::OK();
}
@@ -259,7 +257,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
while (!busy()) {
- invocation_results_.emplace_back(new InvocationResult());
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}