diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-04 13:01:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:12:57 -0700 |
commit | 7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch) | |
tree | 84087a64563d10c3390f991c6263c7fa2cc65b11 | |
parent | 074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff) |
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
7 files changed, 77 insertions, 49 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt index 4433693759..d158f4b502 100644 --- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt @@ -4,16 +4,23 @@ op { in_arg { name: "arguments" description: <<END - A list of tensors whose types are Targuments, corresponding to the inputs the - function should be mapped over. + A list of tensors whose types are `Targuments`, corresponding to the inputs + the function should be mapped over. +END + } + in_arg { + name: "captured_inputs" + description: <<END + A list of tensors whose types are `Tcaptured`, corresponding to the captured + inputs of the defun. END } out_arg { name: "output" description: <<END - A list of output tensors whose types are output_types and whose dimensions 0 - are the same as the dimensions 0 of the tensors in arguments, and whose - remaining dimensions correspond to those in output_shapes. + A list of output tensors whose types are `output_types` and whose dimensions + 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose + remaining dimensions correspond to those in `output_shapes`. END } attr { @@ -21,6 +28,10 @@ END description: "A list of types." } attr { + name: "Tcaptured" + description: "A list of types." + } + attr { name: "output_types" description: "A list of types." } @@ -29,6 +40,6 @@ END description: "A list of shapes." } summary: <<END - Maps a function on the list of tensors unpacked from inputs on dimension 0. + Maps a function on the list of tensors unpacked from arguments on dimension 0. END } diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index ba521e79bc..a9254ed58b 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, map_defun_node->add_input(input.name()); } (*map_defun_node->mutable_attr())["Targuments"] = t_args; + AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node); // Set return values to match output names string output_prefix = strings::StrCat(map_defun_node->name(), ":output:"); diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index a958d706c1..a6020e36bb 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, func.set_name(function_name); NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn); graph_transforms::SetNodeAttr("Targuments", t_arguments, node); + graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node); graph_transforms::SetNodeAttr("output_types", output_types, node); graph_transforms::SetNodeAttr("output_shapes", output_shapes, node); graph_transforms::SetNodeAttr("f", func, node); @@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; + Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized); + LOG(ERROR) << s; EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 6657f2b2b3..705b0393de 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} - Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { - // Validates inputs and gets the size of their leading dimension. - *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { - return errors::InvalidArgument( - "All inputs must have rank at least 1. Input ", i, - " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != *batch_size) { - return errors::InvalidArgument( - "All inputs must have the same dimension 0. Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size); - } - } - return Status::OK(); - } - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { ComputeOptions* compute_opts = nullptr; @@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel { // all calls to the function are complete. This struct also encapsulates // all the components that need to be passed to each MapFunctionCallFrame. - const std::vector<Tensor> args; + OpInputList args; const std::vector<TensorShape> arg_shapes; + OpInputList captured_inputs; const int64 batch_size; // Output of a compute call @@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel { // Create a copy of output_shapes because every `Compute` may expect a // different output shape. - ComputeOptions(std::vector<Tensor> args, + ComputeOptions(OpInputList args, OpInputList captured_inputs, std::vector<TensorShape> arg_shapes, int64 batch_size, const std::vector<PartialTensorShape>& output_shapes_attr) - : args(std::move(args)), + : args(args), arg_shapes(std::move(arg_shapes)), + captured_inputs(captured_inputs), batch_size(batch_size), output_shapes(output_shapes_attr) {} }; // Get inputs to Compute and check that they are valid. Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { - int64 batch_size = - ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + OpInputList arguments; + TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments)); + OpInputList captured_inputs; + TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs)); + + int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { + for (size_t i = 0; i < arguments.size(); ++i) { + if (arguments[i].dims() == 0) { return errors::InvalidArgument( "All inputs must have rank at least 1. Input ", i, " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != batch_size) { + } else if (arguments[i].dim_size(0) != batch_size) { return errors::InvalidArgument( "All inputs must have the same dimension 0. Input ", i, " has leading dimension ", ctx->input(i).dim_size(0), @@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel { } } - std::vector<Tensor> args; std::vector<TensorShape> arg_shapes; - args.reserve(ctx->num_inputs()); - arg_shapes.reserve(ctx->num_inputs()); + arg_shapes.reserve(arguments.size()); - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - args.push_back(ctx->input(i)); - arg_shapes.push_back(ctx->input(i).shape()); + for (size_t i = 0; i < arguments.size(); ++i) { + arg_shapes.push_back(arguments[i].shape()); arg_shapes.at(i).RemoveDim(0); } - *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes), - batch_size, output_shapes_); + *compute_opts = + new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes), + batch_size, output_shapes_); return Status::OK(); } @@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel { } Status GetArg(int index, Tensor* val) const override { - if (index < 0 || index >= compute_opts_->args.size()) { + if (index < 0 || index >= compute_opts_->args.size() + + compute_opts_->captured_inputs.size()) { return errors::InvalidArgument( "Mismatch in number of function inputs."); } + + if (index >= compute_opts_->args.size()) { + // The function is calling for a captured input + *val = + compute_opts_->captured_inputs[index - compute_opts_->args.size()]; + return Status::OK(); + } + bool result = - val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1), + val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1), compute_opts_->arg_shapes.at(index)); if (!result) { return errors::Internal("GetArg failed."); @@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel { // Ensure alignment *val = tensor::DeepCopy(*val); } - return Status::OK(); } diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 889a6a4640..ec22eee874 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset") REGISTER_OP("MapDefun") .Input("arguments: Targuments") + .Input("captured_inputs: Tcaptured") .Output("output: output_types") .Attr("Targuments: list(type) >= 1") + .Attr("Tcaptured: list(type) >= 0 = []") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector<PartialTensorShape> output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + DataTypeVector t_args; + TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( "`output_shapes` must be the same length as `output_types` (", @@ -918,10 +922,11 @@ REGISTER_OP("MapDefun") } int64 dim_zero = -1; - for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + for (size_t i = 0; i < t_args.size(); ++i) { if (c->Rank(c->input(i)) == 0) { return errors::InvalidArgument( - "Inputs must have rank at least 1. Input ", i, " has rank of 0"); + "Arguments must have rank at least 1. Input ", i, + " has rank of 0."); } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { @@ -929,7 +934,7 @@ REGISTER_OP("MapDefun") dim_zero = c->Value(dim_handle); } else if (c->Value(dim_handle) != dim_zero) { return errors::InvalidArgument( - "Inputs must have the same dimension 0."); + "Arguments must have the same dimension 0."); } } } diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py index 612ee332c4..ae9dedb0ab 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py @@ -235,6 +235,18 @@ class MapDefunTest(test_base.DatasetTestBase): sess.close() thread.join() + def testMapDefunWithCapturedInputs(self): + c = constant_op.constant(2) + + @function.Defun(dtypes.int32) + def fn(x): + return x + c + + x = constant_op.constant([1, 2, 3, 4]) + map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0] + expected = x + c + self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op)) + class MapDefunBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py index 3d0d0993c9..3ac1158d8b 100644 --- a/tensorflow/python/data/experimental/ops/map_defun.py +++ b/tensorflow/python/data/experimental/ops/map_defun.py @@ -47,10 +47,12 @@ def map_defun(fn, elems, output_dtypes, output_shapes): if not isinstance(elems, list): raise ValueError("`elems` must be a list of tensors.") if not isinstance(output_dtypes, list): - raise ValueError("`output_dtypes` must be a list of tensors.") + raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.") if not isinstance(output_shapes, list): - raise ValueError("`output_shapes` must be a list of tensors.") + raise ValueError("`output_shapes` must be a list of `tf.TensorShape` " + "objects.") elems = [ops.convert_to_tensor(e) for e in elems] output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] - return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) + return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes, + output_shapes, fn) |