diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-04 13:01:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:12:57 -0700 |
commit | 7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch) | |
tree | 84087a64563d10c3390f991c6263c7fa2cc65b11 /tensorflow/core/kernels | |
parent | 074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff) |
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/data/map_defun_op.cc | 68 |
1 files changed, 31 insertions, 37 deletions
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 6657f2b2b3..705b0393de 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} - Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { - // Validates inputs and gets the size of their leading dimension. - *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { - return errors::InvalidArgument( - "All inputs must have rank at least 1. Input ", i, - " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != *batch_size) { - return errors::InvalidArgument( - "All inputs must have the same dimension 0. Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size); - } - } - return Status::OK(); - } - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { ComputeOptions* compute_opts = nullptr; @@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel { // all calls to the function are complete. This struct also encapsulates // all the components that need to be passed to each MapFunctionCallFrame. - const std::vector<Tensor> args; + OpInputList args; const std::vector<TensorShape> arg_shapes; + OpInputList captured_inputs; const int64 batch_size; // Output of a compute call @@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel { // Create a copy of output_shapes because every `Compute` may expect a // different output shape. - ComputeOptions(std::vector<Tensor> args, + ComputeOptions(OpInputList args, OpInputList captured_inputs, std::vector<TensorShape> arg_shapes, int64 batch_size, const std::vector<PartialTensorShape>& output_shapes_attr) - : args(std::move(args)), + : args(args), arg_shapes(std::move(arg_shapes)), + captured_inputs(captured_inputs), batch_size(batch_size), output_shapes(output_shapes_attr) {} }; // Get inputs to Compute and check that they are valid. Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { - int64 batch_size = - ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + OpInputList arguments; + TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments)); + OpInputList captured_inputs; + TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs)); + + int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { + for (size_t i = 0; i < arguments.size(); ++i) { + if (arguments[i].dims() == 0) { return errors::InvalidArgument( "All inputs must have rank at least 1. Input ", i, " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != batch_size) { + } else if (arguments[i].dim_size(0) != batch_size) { return errors::InvalidArgument( "All inputs must have the same dimension 0. Input ", i, " has leading dimension ", ctx->input(i).dim_size(0), @@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel { } } - std::vector<Tensor> args; std::vector<TensorShape> arg_shapes; - args.reserve(ctx->num_inputs()); - arg_shapes.reserve(ctx->num_inputs()); + arg_shapes.reserve(arguments.size()); - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - args.push_back(ctx->input(i)); - arg_shapes.push_back(ctx->input(i).shape()); + for (size_t i = 0; i < arguments.size(); ++i) { + arg_shapes.push_back(arguments[i].shape()); arg_shapes.at(i).RemoveDim(0); } - *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes), - batch_size, output_shapes_); + *compute_opts = + new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes), + batch_size, output_shapes_); return Status::OK(); } @@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel { } Status GetArg(int index, Tensor* val) const override { - if (index < 0 || index >= compute_opts_->args.size()) { + if (index < 0 || index >= compute_opts_->args.size() + + compute_opts_->captured_inputs.size()) { return errors::InvalidArgument( "Mismatch in number of function inputs."); } + + if (index >= compute_opts_->args.size()) { + // The function is calling for a captured input + *val = + compute_opts_->captured_inputs[index - compute_opts_->args.size()]; + return Status::OK(); + } + bool result = - val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1), + val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1), compute_opts_->arg_shapes.at(index)); if (!result) { return errors::Internal("GetArg failed."); @@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel { // Ensure alignment *val = tensor::DeepCopy(*val); } - return Status::OK(); } |