diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_and_batch_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 38 |
1 files changed, 28 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index f45a239793..bae56828dc 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -324,6 +324,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } private: + // BatchResult encapsulates the output batch, as well as ancillary + // metadata required to execute the fused map-and-batch operation. struct BatchResult { explicit BatchResult(int64 batch_size) { end_of_input = false; @@ -331,11 +333,23 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { num_elements = 0; output_allocated = false; status = Status::OK(); + status_offset = -1; } - void UpdateStatus(const Status& s) { - mutex_lock l(mu); - status.Update(s); + // UpdateStatus updates the batch's aggregate Status. + // + // In order to ensure that exactly the first non-OK status is returned + // (required to make the behavior observably identical to a + // sequential execution of map followed by batch), we must also keep + // track of the offset into the batch that produced `s`. + void UpdateStatus(const Status& s, int64 offset) { + if (TF_PREDICT_FALSE(!s.ok())) { + mutex_lock l(mu); + if (status.ok() || offset < status_offset) { + status = s; + status_offset = offset; + } + } } mutex mu; @@ -344,6 +358,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::vector<Tensor> output; bool output_allocated GUARDED_BY(mu); Status status GUARDED_BY(mu); + int64 status_offset GUARDED_BY(mu); // Counts the number of outstanding calls for this batch. 
int64 num_calls; // access guarded by owner's mutex }; @@ -379,7 +394,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::shared_ptr<std::vector<Tensor>> return_values = std::make_shared<std::vector<Tensor>>(); auto done = [this, ctx, result, return_values, offset](Status status) { - result->UpdateStatus(status); + result->UpdateStatus(status, offset); if (status.ok()) { EnsureOutputAllocated(ctx, result, return_values); for (size_t i = 0; i < return_values->size(); ++i) { @@ -389,11 +404,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { (batch->NumElements() / batch->dim_size(0))) { TensorShape batch_shape = batch->shape(); batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does " - "not match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); + result->UpdateStatus( + errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements " + "does " + "not match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString()), + offset); break; } // TODO(mrry): Add a version of DoParallelConcat that allows us to @@ -402,7 +420,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status copy_status = ::tensorflow::functor::DoParallelConcat( *dataset()->device_, tensor, offset, batch); if (!copy_status.ok()) { - result->UpdateStatus(copy_status); + result->UpdateStatus(copy_status, offset); break; } } |