diff options
Diffstat (limited to 'tensorflow/core/kernels/data/map_defun_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/map_defun_op.cc | 234 |
1 files changed, 140 insertions, 94 deletions
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index b87d61ee44..6657f2b2b3 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -81,119 +81,167 @@ class MapDefunOp : public AsyncOpKernel { } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - int64 batch_size; - OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done); + ComputeOptions* compute_opts = nullptr; - // Inputs - auto* args = new std::vector<Tensor>; - auto* arg_shapes = new std::vector<TensorShape>; + OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done); - // Create a copy because every `Compute` may have different output shapes. - auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_); - arg_shapes->reserve(ctx->num_inputs()); - args->reserve(ctx->num_inputs()); + Status s = SetupOutputs(ctx, compute_opts); + if (!s.ok()) delete compute_opts; + OP_REQUIRES_OK_ASYNC(ctx, s, done); - auto* mu = new mutex; - - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - args->push_back(ctx->input(i)); - arg_shapes->push_back(ctx->input(i).shape()); - arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension - } - - // Outputs - auto* output = new OpOutputList; - OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done); - - for (size_t i = 0; i < output_types().size(); ++i) { - if (output_shapes_.at(i).IsFullyDefined()) { - Tensor* out = nullptr; - TensorShape output_shape; - output_shapes_.at(i).AsTensorShape(&output_shape); - output_shape.InsertDim(0, batch_size); - OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), - done); - } - } - - SetRunOptions(ctx, &opts_, false); + FunctionLibraryRuntime::Options opts; + SetRunOptions(ctx, &opts, false); // Run loop StatusCallback callback = std::bind( - [](OpKernelContext* ctx, std::vector<Tensor>* args, - std::vector<TensorShape>* arg_shapes, - std::vector<PartialTensorShape>* output_shapes, OpOutputList* output, - mutex* mu, DoneCallback& done, const Status& status) { - delete args; - delete arg_shapes; - delete output; - delete output_shapes; - delete mu; + [](OpKernelContext* ctx, ComputeOptions* compute_opts, + DoneCallback& done, const Status& status) { + delete compute_opts; ctx->SetStatus(status); done(); }, - ctx, args, arg_shapes, output_shapes, output, mu, std::move(done), - std::placeholders::_1); + ctx, compute_opts, std::move(done), std::placeholders::_1); auto* refcounted = new ReffedStatusCallback(std::move(callback)); - for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) { - // Start from i = 1 because refcounted is initialized with refcount = 1 - refcounted->Ref(); - } + CancellationManager* parent_mgr = ctx->cancellation_manager(); - for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) { - auto* call_frame = new MapFunctionCallFrame( - *args, *arg_shapes, output_shapes, mu, output, this, i, - static_cast<size_t>(batch_size)); + for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) { + // We use a different cancellation manager each time the function is run + // to avoid the race condition between a function run error and other + // functions being cancelled as a result. CancellationManager* c_mgr = new CancellationManager; - opts_.cancellation_manager = c_mgr; - ctx->function_library()->Run( - opts_, func_handle_, call_frame, - [call_frame, refcounted, c_mgr](const Status& func_status) { - delete call_frame; - delete c_mgr; - refcounted->UpdateStatus(func_status); - refcounted->Unref(); - }); + CancellationToken token = parent_mgr->get_cancellation_token(); + const bool success = parent_mgr->RegisterCallback( + token, [c_mgr]() { c_mgr->StartCancel(); }); + + opts.cancellation_manager = c_mgr; + if (!success) { + delete c_mgr; + refcounted->UpdateStatus(errors::Cancelled( + "MapDefunOp functions cancelled because parent graph cancelled")); + break; + } + + auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i); + + refcounted->Ref(); + ctx->function_library()->Run(opts, func_handle_, call_frame, + [call_frame, refcounted, c_mgr, parent_mgr, + token](const Status& func_status) { + parent_mgr->DeregisterCallback(token); + delete c_mgr; + delete call_frame; + refcounted->UpdateStatus(func_status); + refcounted->Unref(); + }); } + + // Unref 1 because refcounted is initialized with refcount = 1 + refcounted->Unref(); } private: FunctionLibraryRuntime::Handle func_handle_; - FunctionLibraryRuntime::Options opts_; std::vector<PartialTensorShape> output_shapes_; + struct ComputeOptions { + // These vary per MapDefunOp::ComputeAsync call, but must persist until + // all calls to the function are complete. This struct also encapsulates + // all the components that need to be passed to each MapFunctionCallFrame. + + const std::vector<Tensor> args; + const std::vector<TensorShape> arg_shapes; + const int64 batch_size; + + // Output of a compute call + std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu); + OpOutputList output GUARDED_BY(mu); + mutex mu; + + // Create a copy of output_shapes because every `Compute` may expect a + // different output shape. + ComputeOptions(std::vector<Tensor> args, + std::vector<TensorShape> arg_shapes, int64 batch_size, + const std::vector<PartialTensorShape>& output_shapes_attr) + : args(std::move(args)), + arg_shapes(std::move(arg_shapes)), + batch_size(batch_size), + output_shapes(output_shapes_attr) {} + }; + + // Get inputs to Compute and check that they are valid. + Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { + int64 batch_size = + ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. Input ", i, + " has a rank of 0."); + } else if (ctx->input(i).dim_size(0) != batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", batch_size); + } + } + + std::vector<Tensor> args; + std::vector<TensorShape> arg_shapes; + args.reserve(ctx->num_inputs()); + arg_shapes.reserve(ctx->num_inputs()); + + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + args.push_back(ctx->input(i)); + arg_shapes.push_back(ctx->input(i).shape()); + arg_shapes.at(i).RemoveDim(0); + } + + *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes), + batch_size, output_shapes_); + return Status::OK(); + } + + Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { + mutex_lock l(opts->mu); + TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output)); + + for (size_t i = 0; i < output_types().size(); ++i) { + if (output_shapes_.at(i).IsFullyDefined()) { + Tensor* out = nullptr; + TensorShape output_shape; + output_shapes_.at(i).AsTensorShape(&output_shape); + output_shape.InsertDim(0, opts->batch_size); + TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out)); + } + } + return Status::OK(); + } + class MapFunctionCallFrame : public CallFrameInterface { public: - MapFunctionCallFrame(const std::vector<Tensor>& args, - const std::vector<TensorShape>& arg_shapes, - std::vector<PartialTensorShape>* output_shapes, - mutex* output_shapes_mutex, OpOutputList* output, - OpKernel* kernel, size_t iter, size_t batch_size) - : args_(args), - arg_shapes_(arg_shapes), - output_shapes_(output_shapes), - output_shapes_mutex_(output_shapes_mutex), - output_(output), - kernel_(kernel), - iter_(iter), - batch_size_(batch_size) {} + MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel, + size_t iter) + : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {} ~MapFunctionCallFrame() override {} - size_t num_args() const override { return args_.size(); } + size_t num_args() const override { return compute_opts_->args.size(); } + size_t num_retvals() const override { return static_cast<size_t>(kernel_->num_outputs()); } Status GetArg(int index, Tensor* val) const override { - if (index < 0 || index >= args_.size()) { + if (index < 0 || index >= compute_opts_->args.size()) { return errors::InvalidArgument( "Mismatch in number of function inputs."); } - bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1), - arg_shapes_.at(index)); + bool result = + val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1), + compute_opts_->arg_shapes.at(index)); if (!result) { return errors::Internal("GetArg failed."); } else if (!val->IsAligned()) { @@ -217,36 +265,34 @@ class MapDefunOp : public AsyncOpKernel { index); } { // Locking scope - mutex_lock l(*output_shapes_mutex_); - if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) { + mutex_lock l(compute_opts_->mu); + if (!compute_opts_->output_shapes.at(index).IsCompatibleWith( + val.shape())) { return errors::InvalidArgument( "Mismatch in function retval shape, ", val.shape(), - ", and expected output shape,", - output_shapes_->at(index).DebugString(), "."); + ", and expected output shape, ", + compute_opts_->output_shapes.at(index).DebugString(), "."); } - if (!output_shapes_->at(index).IsFullyDefined()) { + if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) { // Given val, we have new information about the output shape at // this index. Store the shape and allocate the output accordingly. - output_shapes_->at(index) = val.shape(); + compute_opts_->output_shapes.at(index) = val.shape(); Tensor* out = nullptr; TensorShape actual_shape = val.shape(); - actual_shape.InsertDim(0, batch_size_); - TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out)); + actual_shape.InsertDim(0, compute_opts_->batch_size); + TF_RETURN_IF_ERROR( + compute_opts_->output.allocate(index, actual_shape, &out)); } + return batch_util::CopyElementToSlice( + val, (compute_opts_->output)[index], iter_); } - return batch_util::CopyElementToSlice(val, (*output_)[index], iter_); } private: - const std::vector<Tensor>& args_; - const std::vector<TensorShape>& arg_shapes_; - std::vector<PartialTensorShape>* output_shapes_; - mutex* output_shapes_mutex_; - OpOutputList* output_; + ComputeOptions* const compute_opts_; // Not owned const OpKernel* kernel_; const size_t iter_; - const size_t batch_size_; }; }; |