Diffstat (limited to 'tensorflow/core/kernels/data/map_defun_op.cc')
-rw-r--r--  tensorflow/core/kernels/data/map_defun_op.cc  98
1 file changed, 77 insertions, 21 deletions
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 3c562fc7f3..b87d61ee44 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
@@ -60,26 +62,43 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
+ Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
+ // Validates inputs and gets the size of their leading dimension.
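+ // A rank-0 first input leaves *batch_size at -1; the loop below then
+ // rejects it with an InvalidArgument error.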
+ *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != *batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+ return Status::OK();
+ }
+
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size = ctx->input(0).dim_size(0);
+ int64 batch_size;
+ OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done);
+
// Inputs
auto* args = new std::vector<Tensor>;
auto* arg_shapes = new std::vector<TensorShape>;
+
+ // Create a copy because every `Compute` may have different output shapes.
+ auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_);
arg_shapes->reserve(ctx->num_inputs());
args->reserve(ctx->num_inputs());
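+
+ // Guards output_shapes and the lazy output allocation performed in
+ // MapFunctionCallFrame::SetRetval, which may run concurrently across
+ // function invocations.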
+ auto* mu = new mutex;
+
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
args->push_back(ctx->input(i));
arg_shapes->push_back(ctx->input(i).shape());
arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- OP_REQUIRES_ASYNC(
- ctx, batch_size == ctx->input(i).dim_size(0),
- errors::InvalidArgument(
- "All inputs must have the same dimension 0. Input ", i,
- " has leading dimension ", ctx->input(i).dim_size(0),
- ", while all previous inputs have leading dimension ", batch_size,
- "."),
- done);
}
// Outputs
@@ -87,10 +106,14 @@ class MapDefunOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
for (size_t i = 0; i < output_types().size(); ++i) {
- Tensor* out = nullptr;
- TensorShape output_shape = output_shapes_.at(i);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
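+ // Allocate now only if the output shape is fully known; otherwise defer
+ // allocation to SetRetval, where the first returned value supplies a
+ // concrete shape.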
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, batch_size);
+ OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out),
+ done);
+ }
}
SetRunOptions(ctx, &opts_, false);
@@ -98,15 +121,19 @@ class MapDefunOp : public AsyncOpKernel {
// Run loop
StatusCallback callback = std::bind(
[](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes, OpOutputList* output,
- DoneCallback& done, const Status& status) {
+ std::vector<TensorShape>* arg_shapes,
+ std::vector<PartialTensorShape>* output_shapes, OpOutputList* output,
+ mutex* mu, DoneCallback& done, const Status& status) {
delete args;
delete arg_shapes;
delete output;
+ delete output_shapes;
+ delete mu;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+ ctx, args, arg_shapes, output_shapes, output, mu, std::move(done),
+ std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
@@ -114,9 +141,11 @@ class MapDefunOp : public AsyncOpKernel {
// Start from i = 1 because refcounted is initialized with refcount = 1
refcounted->Ref();
}
+
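+ // Launch one function invocation per batch element; each call frame copies
+ // its result into slice i of the shared outputs.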
for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame =
- new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
+ auto* call_frame = new MapFunctionCallFrame(
+ *args, *arg_shapes, output_shapes, mu, output, this, i,
+ static_cast<size_t>(batch_size));
CancellationManager* c_mgr = new CancellationManager;
opts_.cancellation_manager = c_mgr;
ctx->function_library()->Run(
@@ -133,18 +162,23 @@ class MapDefunOp : public AsyncOpKernel {
private:
FunctionLibraryRuntime::Handle func_handle_;
FunctionLibraryRuntime::Options opts_;
- std::vector<TensorShape> output_shapes_;
+ std::vector<PartialTensorShape> output_shapes_;
class MapFunctionCallFrame : public CallFrameInterface {
public:
MapFunctionCallFrame(const std::vector<Tensor>& args,
const std::vector<TensorShape>& arg_shapes,
- OpOutputList* output, OpKernel* kernel, size_t iter)
+ std::vector<PartialTensorShape>* output_shapes,
+ mutex* output_shapes_mutex, OpOutputList* output,
+ OpKernel* kernel, size_t iter, size_t batch_size)
: args_(args),
arg_shapes_(arg_shapes),
+ output_shapes_(output_shapes),
+ output_shapes_mutex_(output_shapes_mutex),
output_(output),
kernel_(kernel),
- iter_(iter) {}
+ iter_(iter),
+ batch_size_(batch_size) {}
~MapFunctionCallFrame() override {}
@@ -182,15 +216,37 @@ class MapDefunOp : public AsyncOpKernel {
"output: ",
index);
}
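+ // The first invocation to return a value for this output fixes its shape
+ // and allocates the output tensor; the lock keeps concurrent invocations
+ // consistent.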
+ { // Locking scope
+ mutex_lock l(*output_shapes_mutex_);
+ if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) {
+ return errors::InvalidArgument(
+ "Mismatch in function retval shape, ", val.shape(),
+ ", and expected output shape,",
+ output_shapes_->at(index).DebugString(), ".");
+ }
+ if (!output_shapes_->at(index).IsFullyDefined()) {
+ // Given val, we have new information about the output shape at
+ // this index. Store the shape and allocate the output accordingly.
+ output_shapes_->at(index) = val.shape();
+
+ Tensor* out = nullptr;
+ TensorShape actual_shape = val.shape();
+ actual_shape.InsertDim(0, batch_size_);
+ TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out));
+ }
+ }
return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
}
private:
const std::vector<Tensor>& args_;
const std::vector<TensorShape>& arg_shapes_;
+ std::vector<PartialTensorShape>* output_shapes_;
+ mutex* output_shapes_mutex_;
OpOutputList* output_;
const OpKernel* kernel_;
const size_t iter_;
+ const size_t batch_size_;
};
};