aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/map_and_batch_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc38
1 files changed, 28 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index f45a239793..bae56828dc 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -324,6 +324,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
+ // BatchResult encapsulates the output batch, as well as ancillary
+ // metadata required to execute the fused map-and-batch operation.
struct BatchResult {
explicit BatchResult(int64 batch_size) {
end_of_input = false;
@@ -331,11 +333,23 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
num_elements = 0;
output_allocated = false;
status = Status::OK();
+ status_offset = -1;
}
- void UpdateStatus(const Status& s) {
- mutex_lock l(mu);
- status.Update(s);
+ // UpdateStatus updates the batch's aggregate Status.
+ //
+ // In order to ensure that exactly the first non-OK status is returned
+ // (required to make the behavior observably identical to a
+ // sequential execution of map followed by batch), we must also keep
+ // track of the offset into the batch that produced `s`.
+ void UpdateStatus(const Status& s, int64 offset) {
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ mutex_lock l(mu);
+ if (status.ok() || offset < status_offset) {
+ status = s;
+ status_offset = offset;
+ }
+ }
}
mutex mu;
@@ -344,6 +358,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> output;
bool output_allocated GUARDED_BY(mu);
Status status GUARDED_BY(mu);
+ int64 status_offset GUARDED_BY(mu);
// Counts the number of outstanding calls for this batch.
int64 num_calls; // access guarded by owner's mutex
};
@@ -379,7 +394,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::shared_ptr<std::vector<Tensor>> return_values =
std::make_shared<std::vector<Tensor>>();
auto done = [this, ctx, result, return_values, offset](Status status) {
- result->UpdateStatus(status);
+ result->UpdateStatus(status, offset);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
for (size_t i = 0; i < return_values->size(); ++i) {
@@ -389,11 +404,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
(batch->NumElements() / batch->dim_size(0))) {
TensorShape batch_shape = batch->shape();
batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does "
- "not match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
+ result->UpdateStatus(
+ errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements "
+ "does "
+ "not match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()),
+ offset);
break;
}
// TODO(mrry): Add a version of DoParallelConcat that allows us to
@@ -402,7 +420,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status copy_status = ::tensorflow::functor::DoParallelConcat(
*dataset()->device_, tensor, offset, batch);
if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
+ result->UpdateStatus(copy_status, offset);
break;
}
}