diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_defun_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_defun_op.cc | 98 |
1 file changed, 77 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 3c562fc7f3..b87d61ee44 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/batch_util.h" #include "tensorflow/core/util/reffed_status_callback.h" @@ -60,26 +62,43 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} + Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { + // Validates inputs and gets the size of their leading dimension. + *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. Input ", i, + " has a rank of 0."); + } else if (ctx->input(i).dim_size(0) != *batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", *batch_size); + } + } + return Status::OK(); + } + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - int64 batch_size = ctx->input(0).dim_size(0); + int64 batch_size; + OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done); + // Inputs auto* args = new std::vector<Tensor>; auto* arg_shapes = new std::vector<TensorShape>; + + // Create a copy because every `Compute` may have different output shapes. 
+ auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_); arg_shapes->reserve(ctx->num_inputs()); args->reserve(ctx->num_inputs()); + auto* mu = new mutex; + for (size_t i = 0; i < ctx->num_inputs(); ++i) { args->push_back(ctx->input(i)); arg_shapes->push_back(ctx->input(i).shape()); arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension - OP_REQUIRES_ASYNC( - ctx, batch_size == ctx->input(i).dim_size(0), - errors::InvalidArgument( - "All inputs must have the same dimension 0. Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size, - "."), - done); } // Outputs @@ -87,10 +106,14 @@ class MapDefunOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done); for (size_t i = 0; i < output_types().size(); ++i) { - Tensor* out = nullptr; - TensorShape output_shape = output_shapes_.at(i); - output_shape.InsertDim(0, batch_size); - OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done); + if (output_shapes_.at(i).IsFullyDefined()) { + Tensor* out = nullptr; + TensorShape output_shape; + output_shapes_.at(i).AsTensorShape(&output_shape); + output_shape.InsertDim(0, batch_size); + OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), + done); + } } SetRunOptions(ctx, &opts_, false); @@ -98,15 +121,19 @@ class MapDefunOp : public AsyncOpKernel { // Run loop StatusCallback callback = std::bind( [](OpKernelContext* ctx, std::vector<Tensor>* args, - std::vector<TensorShape>* arg_shapes, OpOutputList* output, - DoneCallback& done, const Status& status) { + std::vector<TensorShape>* arg_shapes, + std::vector<PartialTensorShape>* output_shapes, OpOutputList* output, + mutex* mu, DoneCallback& done, const Status& status) { delete args; delete arg_shapes; delete output; + delete output_shapes; + delete mu; ctx->SetStatus(status); done(); }, - ctx, args, arg_shapes, output, std::move(done), 
std::placeholders::_1); + ctx, args, arg_shapes, output_shapes, output, mu, std::move(done), + std::placeholders::_1); auto* refcounted = new ReffedStatusCallback(std::move(callback)); @@ -114,9 +141,11 @@ class MapDefunOp : public AsyncOpKernel { // Start from i = 1 because refcounted is initialized with refcount = 1 refcounted->Ref(); } + for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) { - auto* call_frame = - new MapFunctionCallFrame(*args, *arg_shapes, output, this, i); + auto* call_frame = new MapFunctionCallFrame( + *args, *arg_shapes, output_shapes, mu, output, this, i, + static_cast<size_t>(batch_size)); CancellationManager* c_mgr = new CancellationManager; opts_.cancellation_manager = c_mgr; ctx->function_library()->Run( @@ -133,18 +162,23 @@ class MapDefunOp : public AsyncOpKernel { private: FunctionLibraryRuntime::Handle func_handle_; FunctionLibraryRuntime::Options opts_; - std::vector<TensorShape> output_shapes_; + std::vector<PartialTensorShape> output_shapes_; class MapFunctionCallFrame : public CallFrameInterface { public: MapFunctionCallFrame(const std::vector<Tensor>& args, const std::vector<TensorShape>& arg_shapes, - OpOutputList* output, OpKernel* kernel, size_t iter) + std::vector<PartialTensorShape>* output_shapes, + mutex* output_shapes_mutex, OpOutputList* output, + OpKernel* kernel, size_t iter, size_t batch_size) : args_(args), arg_shapes_(arg_shapes), + output_shapes_(output_shapes), + output_shapes_mutex_(output_shapes_mutex), output_(output), kernel_(kernel), - iter_(iter) {} + iter_(iter), + batch_size_(batch_size) {} ~MapFunctionCallFrame() override {} @@ -182,15 +216,37 @@ class MapDefunOp : public AsyncOpKernel { "output: ", index); } + { // Locking scope + mutex_lock l(*output_shapes_mutex_); + if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) { + return errors::InvalidArgument( + "Mismatch in function retval shape, ", val.shape(), + ", and expected output shape, ", + 
output_shapes_->at(index).DebugString(), "."); + } + if (!output_shapes_->at(index).IsFullyDefined()) { + // Given val, we have new information about the output shape at + // this index. Store the shape and allocate the output accordingly. + output_shapes_->at(index) = val.shape(); + + Tensor* out = nullptr; + TensorShape actual_shape = val.shape(); + actual_shape.InsertDim(0, batch_size_); + TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out)); + } + } return batch_util::CopyElementToSlice(val, (*output_)[index], iter_); } private: const std::vector<Tensor>& args_; const std::vector<TensorShape>& arg_shapes_; + std::vector<PartialTensorShape>* output_shapes_; + mutex* output_shapes_mutex_; OpOutputList* output_; const OpKernel* kernel_; const size_t iter_; + const size_t batch_size_; }; }; |