 tensorflow/core/kernels/data/map_defun_op.cc | 234 ++++++++++++++----------
 1 file changed, 140 insertions(+), 94 deletions(-)
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index b87d61ee44..6657f2b2b3 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -81,119 +81,167 @@ class MapDefunOp : public AsyncOpKernel {
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size;
- OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done);
+ ComputeOptions* compute_opts = nullptr;
- // Inputs
- auto* args = new std::vector<Tensor>;
- auto* arg_shapes = new std::vector<TensorShape>;
+ OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
- // Create a copy because every `Compute` may have different output shapes.
- auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_);
- arg_shapes->reserve(ctx->num_inputs());
- args->reserve(ctx->num_inputs());
+ Status s = SetupOutputs(ctx, compute_opts);
+ if (!s.ok()) delete compute_opts;
+ OP_REQUIRES_OK_ASYNC(ctx, s, done);
- auto* mu = new mutex;
-
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args->push_back(ctx->input(i));
- arg_shapes->push_back(ctx->input(i).shape());
- arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- }
-
- // Outputs
- auto* output = new OpOutputList;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
-
- for (size_t i = 0; i < output_types().size(); ++i) {
- if (output_shapes_.at(i).IsFullyDefined()) {
- Tensor* out = nullptr;
- TensorShape output_shape;
- output_shapes_.at(i).AsTensorShape(&output_shape);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out),
- done);
- }
- }
-
- SetRunOptions(ctx, &opts_, false);
+ FunctionLibraryRuntime::Options opts;
+ SetRunOptions(ctx, &opts, false);
// Run loop
StatusCallback callback = std::bind(
- [](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes,
- std::vector<PartialTensorShape>* output_shapes, OpOutputList* output,
- mutex* mu, DoneCallback& done, const Status& status) {
- delete args;
- delete arg_shapes;
- delete output;
- delete output_shapes;
- delete mu;
+ [](OpKernelContext* ctx, ComputeOptions* compute_opts,
+ DoneCallback& done, const Status& status) {
+ delete compute_opts;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output_shapes, output, mu, std::move(done),
- std::placeholders::_1);
+ ctx, compute_opts, std::move(done), std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
- for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
- // Start from i = 1 because refcounted is initialized with refcount = 1
- refcounted->Ref();
- }
+ CancellationManager* parent_mgr = ctx->cancellation_manager();
- for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame = new MapFunctionCallFrame(
- *args, *arg_shapes, output_shapes, mu, output, this, i,
- static_cast<size_t>(batch_size));
+ for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
+ // We use a different cancellation manager each time the function is run
+ // to avoid the race condition between a function run error and other
+ // functions being cancelled as a result.
CancellationManager* c_mgr = new CancellationManager;
- opts_.cancellation_manager = c_mgr;
- ctx->function_library()->Run(
- opts_, func_handle_, call_frame,
- [call_frame, refcounted, c_mgr](const Status& func_status) {
- delete call_frame;
- delete c_mgr;
- refcounted->UpdateStatus(func_status);
- refcounted->Unref();
- });
+ CancellationToken token = parent_mgr->get_cancellation_token();
+ const bool success = parent_mgr->RegisterCallback(
+ token, [c_mgr]() { c_mgr->StartCancel(); });
+
+ opts.cancellation_manager = c_mgr;
+ if (!success) {
+ delete c_mgr;
+ refcounted->UpdateStatus(errors::Cancelled(
+ "MapDefunOp functions cancelled because parent graph cancelled"));
+ break;
+ }
+
+ auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
+
+ refcounted->Ref();
+ ctx->function_library()->Run(opts, func_handle_, call_frame,
+ [call_frame, refcounted, c_mgr, parent_mgr,
+ token](const Status& func_status) {
+ parent_mgr->DeregisterCallback(token);
+ delete c_mgr;
+ delete call_frame;
+ refcounted->UpdateStatus(func_status);
+ refcounted->Unref();
+ });
}
+
+ // Unref 1 because refcounted is initialized with refcount = 1
+ refcounted->Unref();
}
private:
FunctionLibraryRuntime::Handle func_handle_;
- FunctionLibraryRuntime::Options opts_;
std::vector<PartialTensorShape> output_shapes_;
+ struct ComputeOptions {
+ // These vary per MapDefunOp::ComputeAsync call, but must persist until
+ // all calls to the function are complete. This struct also encapsulates
+ // all the components that need to be passed to each MapFunctionCallFrame.
+
+ const std::vector<Tensor> args;
+ const std::vector<TensorShape> arg_shapes;
+ const int64 batch_size;
+
+ // Output of a compute call
+ std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
+ OpOutputList output GUARDED_BY(mu);
+ mutex mu;
+
+ // Create a copy of output_shapes because every `Compute` may expect a
+ // different output shape.
+ ComputeOptions(std::vector<Tensor> args,
+ std::vector<TensorShape> arg_shapes, int64 batch_size,
+ const std::vector<PartialTensorShape>& output_shapes_attr)
+ : args(std::move(args)),
+ arg_shapes(std::move(arg_shapes)),
+ batch_size(batch_size),
+ output_shapes(output_shapes_attr) {}
+ };
+
+ // Get inputs to Compute and check that they are valid.
+ Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
+ int64 batch_size =
+ ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+
+ std::vector<Tensor> args;
+ std::vector<TensorShape> arg_shapes;
+ args.reserve(ctx->num_inputs());
+ arg_shapes.reserve(ctx->num_inputs());
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ arg_shapes.push_back(ctx->input(i).shape());
+ arg_shapes.at(i).RemoveDim(0);
+ }
+
+ *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
+ batch_size, output_shapes_);
+ return Status::OK();
+ }
+
+ Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
+ mutex_lock l(opts->mu);
+ TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
+
+ for (size_t i = 0; i < output_types().size(); ++i) {
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, opts->batch_size);
+ TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
+ }
+ }
+ return Status::OK();
+ }
+
class MapFunctionCallFrame : public CallFrameInterface {
public:
- MapFunctionCallFrame(const std::vector<Tensor>& args,
- const std::vector<TensorShape>& arg_shapes,
- std::vector<PartialTensorShape>* output_shapes,
- mutex* output_shapes_mutex, OpOutputList* output,
- OpKernel* kernel, size_t iter, size_t batch_size)
- : args_(args),
- arg_shapes_(arg_shapes),
- output_shapes_(output_shapes),
- output_shapes_mutex_(output_shapes_mutex),
- output_(output),
- kernel_(kernel),
- iter_(iter),
- batch_size_(batch_size) {}
+ MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
+ size_t iter)
+ : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
~MapFunctionCallFrame() override {}
- size_t num_args() const override { return args_.size(); }
+ size_t num_args() const override { return compute_opts_->args.size(); }
+
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= args_.size()) {
+ if (index < 0 || index >= compute_opts_->args.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
- bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
- arg_shapes_.at(index));
+ bool result =
+ val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
@@ -217,36 +265,34 @@ class MapDefunOp : public AsyncOpKernel {
index);
}
{ // Locking scope
- mutex_lock l(*output_shapes_mutex_);
- if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) {
+ mutex_lock l(compute_opts_->mu);
+ if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
+ val.shape())) {
return errors::InvalidArgument(
"Mismatch in function retval shape, ", val.shape(),
- ", and expected output shape,",
- output_shapes_->at(index).DebugString(), ".");
+ ", and expected output shape, ",
+ compute_opts_->output_shapes.at(index).DebugString(), ".");
}
- if (!output_shapes_->at(index).IsFullyDefined()) {
+ if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
// Given val, we have new information about the output shape at
// this index. Store the shape and allocate the output accordingly.
- output_shapes_->at(index) = val.shape();
+ compute_opts_->output_shapes.at(index) = val.shape();
Tensor* out = nullptr;
TensorShape actual_shape = val.shape();
- actual_shape.InsertDim(0, batch_size_);
- TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out));
+ actual_shape.InsertDim(0, compute_opts_->batch_size);
+ TF_RETURN_IF_ERROR(
+ compute_opts_->output.allocate(index, actual_shape, &out));
}
+ return batch_util::CopyElementToSlice(
+ val, (compute_opts_->output)[index], iter_);
}
- return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
}
private:
- const std::vector<Tensor>& args_;
- const std::vector<TensorShape>& arg_shapes_;
- std::vector<PartialTensorShape>* output_shapes_;
- mutex* output_shapes_mutex_;
- OpOutputList* output_;
+ ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
- const size_t batch_size_;
};
};
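
For readers skimming the diff, the new ComputeAsync hinges on a ref-counted status callback: each per-element function run holds one reference, the first error wins, and the caller's `done` fires exactly once when the last reference (including the initial one dropped at the end of ComputeAsync) is released. Below is a minimal, standalone sketch of that pattern. It is not the TensorFlow implementation; `RefCountedStatusCallback` and the `Status` stand-in are hypothetical names used for illustration of the same idea as `ReffedStatusCallback`.

#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Stand-in for a status type; only what this sketch needs.
struct Status {
  bool ok;
  std::string message;
  static Status OK() { return {true, ""}; }
};

// Aggregates the statuses of N concurrent runs and invokes `done` once,
// after the initial reference and every per-run reference are released.
class RefCountedStatusCallback {
 public:
  explicit RefCountedStatusCallback(std::function<void(const Status&)> done)
      : done_(std::move(done)), refs_(1) {}  // starts with refcount = 1

  void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }

  // Records the first non-OK status; later errors are ignored.
  void UpdateStatus(const Status& s) {
    if (!s.ok) {
      std::lock_guard<std::mutex> l(mu_);
      if (status_.ok) status_ = s;
    }
  }

  // Dropping the last reference fires `done` and self-deletes.
  void Unref() {
    if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
      done_(status_);
      delete this;
    }
  }

 private:
  std::function<void(const Status&)> done_;
  std::atomic<int> refs_;
  std::mutex mu_;
  Status status_ = Status::OK();
};

int main() {
  const int batch_size = 4;
  auto* refcounted = new RefCountedStatusCallback([](const Status& s) {
    std::cout << "done, ok=" << s.ok << " msg=" << s.message << "\n";
  });

  std::vector<std::thread> runs;
  for (int i = 0; i < batch_size; ++i) {
    refcounted->Ref();  // one reference per in-flight "function run"
    runs.emplace_back([refcounted, i] {
      Status s = (i == 2) ? Status{false, "element 2 failed"} : Status::OK();
      refcounted->UpdateStatus(s);
      refcounted->Unref();
    });
  }
  // Drop the initial reference, mirroring the trailing Unref() in ComputeAsync.
  refcounted->Unref();
  for (auto& t : runs) t.join();
}

The per-run CancellationManager in the diff follows a related design choice: each function run gets its own manager, registered as a callback on the parent context's manager, so that one run failing (or the parent graph being cancelled) can cancel the others without a failing run's teardown racing against cancellation of its siblings.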