author    Rachel Lim <rachelim@google.com>    2018-10-04 13:01:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-10-04 13:12:57 -0700
commit    7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch)
tree      84087a64563d10c3390f991c6263c7fa2cc65b11
parent    074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff)
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
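
For context, a minimal sketch of the user-facing effect, closely mirroring the test added to map_defun_op_test.py in this change (TF 1.x experimental API; expected values shown in comments):

    from tensorflow.python.data.experimental.ops import map_defun
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import function

    c = constant_op.constant(2)  # captured by the defun, not mapped over

    @function.Defun(dtypes.int32)
    def fn(x):
      return x + c

    x = constant_op.constant([1, 2, 3, 4])
    # map_defun now forwards fn.captured_inputs to the MapDefun op, so the
    # capture is passed whole to every per-element call of fn.
    y = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0]  # => [3, 4, 5, 6]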
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt | 23
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_vectorization.cc | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/map_defun_op.cc | 68
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc | 11
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py | 12
-rw-r--r--  tensorflow/python/data/experimental/ops/map_defun.py | 8
7 files changed, 77 insertions, 49 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
index 4433693759..d158f4b502 100644
--- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
@@ -4,16 +4,23 @@ op {
in_arg {
name: "arguments"
description: <<END
- A list of tensors whose types are Targuments, corresponding to the inputs the
- function should be mapped over.
+ A list of tensors whose types are `Targuments`, corresponding to the inputs
+ the function should be mapped over.
+END
+ }
+ in_arg {
+ name: "captured_inputs"
+ description: <<END
+ A list of tensors whose types are `Tcaptured`, corresponding to the captured
+ inputs of the defun.
END
}
out_arg {
name: "output"
description: <<END
- A list of output tensors whose types are output_types and whose dimensions 0
- are the same as the dimensions 0 of the tensors in arguments, and whose
- remaining dimensions correspond to those in output_shapes.
+ A list of output tensors whose types are `output_types` and whose dimensions
+ 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose
+ remaining dimensions correspond to those in `output_shapes`.
END
}
attr {
@@ -21,6 +28,10 @@ END
description: "A list of types."
}
attr {
+ name: "Tcaptured"
+ description: "A list of types."
+ }
+ attr {
name: "output_types"
description: "A list of types."
}
@@ -29,6 +40,6 @@ END
description: "A list of shapes."
}
summary: <<END
- Maps a function on the list of tensors unpacked from inputs on dimension 0.
+ Maps a function on the list of tensors unpacked from arguments on dimension 0.
END
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index ba521e79bc..a9254ed58b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
map_defun_node->add_input(input.name());
}
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
+ AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index a958d706c1..a6020e36bb 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
func.set_name(function_name);
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
graph_transforms::SetNodeAttr("output_types", output_types, node);
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
graph_transforms::SetNodeAttr("f", func, node);
@@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
+ Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
+ LOG(ERROR) << s;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 6657f2b2b3..705b0393de 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
- Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
- // Validates inputs and gets the size of their leading dimension.
- *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
- return errors::InvalidArgument(
- "All inputs must have rank at least 1. Input ", i,
- " has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != *batch_size) {
- return errors::InvalidArgument(
- "All inputs must have the same dimension 0. Input ", i,
- " has leading dimension ", ctx->input(i).dim_size(0),
- ", while all previous inputs have leading dimension ", batch_size);
- }
- }
- return Status::OK();
- }
-
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ComputeOptions* compute_opts = nullptr;
@@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel {
// all calls to the function are complete. This struct also encapsulates
// all the components that need to be passed to each MapFunctionCallFrame.
- const std::vector<Tensor> args;
+ OpInputList args;
const std::vector<TensorShape> arg_shapes;
+ OpInputList captured_inputs;
const int64 batch_size;
// Output of a compute call
@@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel {
// Create a copy of output_shapes because every `Compute` may expect a
// different output shape.
- ComputeOptions(std::vector<Tensor> args,
+ ComputeOptions(OpInputList args, OpInputList captured_inputs,
std::vector<TensorShape> arg_shapes, int64 batch_size,
const std::vector<PartialTensorShape>& output_shapes_attr)
- : args(std::move(args)),
+ : args(args),
arg_shapes(std::move(arg_shapes)),
+ captured_inputs(captured_inputs),
batch_size(batch_size),
output_shapes(output_shapes_attr) {}
};
// Get inputs to Compute and check that they are valid.
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
- int64 batch_size =
- ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+ OpInputList arguments;
+ TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
+ OpInputList captured_inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
+
+ int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ if (arguments[i].dims() == 0) {
return errors::InvalidArgument(
"All inputs must have rank at least 1. Input ", i,
" has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != batch_size) {
+ } else if (arguments[i].dim_size(0) != batch_size) {
return errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
@@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel {
}
}
- std::vector<Tensor> args;
std::vector<TensorShape> arg_shapes;
- args.reserve(ctx->num_inputs());
- arg_shapes.reserve(ctx->num_inputs());
+ arg_shapes.reserve(arguments.size());
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args.push_back(ctx->input(i));
- arg_shapes.push_back(ctx->input(i).shape());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ arg_shapes.push_back(arguments[i].shape());
arg_shapes.at(i).RemoveDim(0);
}
- *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
- batch_size, output_shapes_);
+ *compute_opts =
+ new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes),
+ batch_size, output_shapes_);
return Status::OK();
}
@@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel {
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= compute_opts_->args.size()) {
+ if (index < 0 || index >= compute_opts_->args.size() +
+ compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
+
+ if (index >= compute_opts_->args.size()) {
+ // The function is calling for a captured input
+ *val =
+ compute_opts_->captured_inputs[index - compute_opts_->args.size()];
+ return Status::OK();
+ }
+
bool result =
- val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
@@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
-
return Status::OK();
}
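
The call frame now resolves function inputs against two lists: the mapped arguments, sliced per element along dimension 0, followed by the captured inputs, handed through whole. A rough Python sketch of that index mapping (hypothetical names, not the kernel's actual API):

    def get_arg(index, args, captured_inputs, iter_):
      # Hypothetical sketch of the GetArg index mapping after this change.
      if index < 0 or index >= len(args) + len(captured_inputs):
        raise ValueError("Mismatch in number of function inputs.")
      if index >= len(args):
        # Indices past the mapped arguments refer to captured inputs,
        # which are passed to the function whole (no slicing).
        return captured_inputs[index - len(args)]
      # Mapped arguments are sliced to the current element along dimension 0.
      return args[index][iter_]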
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 889a6a4640..ec22eee874 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset")
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
+ .Input("captured_inputs: Tcaptured")
.Output("output: output_types")
.Attr("Targuments: list(type) >= 1")
+ .Attr("Tcaptured: list(type) >= 0 = []")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ DataTypeVector t_args;
+ TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
@@ -918,10 +922,11 @@ REGISTER_OP("MapDefun")
}
int64 dim_zero = -1;
- for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ for (size_t i = 0; i < t_args.size(); ++i) {
if (c->Rank(c->input(i)) == 0) {
return errors::InvalidArgument(
- "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ "Arguments must have rank at least 1. Input ", i,
+ " has rank of 0.");
}
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
@@ -929,7 +934,7 @@ REGISTER_OP("MapDefun")
dim_zero = c->Value(dim_handle);
} else if (c->Value(dim_handle) != dim_zero) {
return errors::InvalidArgument(
- "Inputs must have the same dimension 0.");
+ "Arguments must have the same dimension 0.");
}
}
}
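
Because the shape function now loops over only the first `Targuments` inputs, a rank-0 captured tensor is exempt from the rank and dimension-0 checks that still apply to mapped arguments. A hedged sketch of the expected behavior, reusing fn, c, and the imports from the sketch near the top of this page (not part of this change's tests):

    xs = constant_op.constant([1, 2, 3])  # mapped argument: rank >= 1 required
    map_defun.map_defun(fn, [xs], [dtypes.int32], [()])   # OK: c is rank 0 but captured
    # map_defun.map_defun(fn, [c], [dtypes.int32], [()])  # expected to be rejected at
    #   graph construction: "Arguments must have rank at least 1. Input 0 has rank of 0."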
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 612ee332c4..ae9dedb0ab 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -235,6 +235,18 @@ class MapDefunTest(test_base.DatasetTestBase):
sess.close()
thread.join()
+ def testMapDefunWithCapturedInputs(self):
+ c = constant_op.constant(2)
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x + c
+
+ x = constant_op.constant([1, 2, 3, 4])
+ map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0]
+ expected = x + c
+ self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
+
class MapDefunBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py
index 3d0d0993c9..3ac1158d8b 100644
--- a/tensorflow/python/data/experimental/ops/map_defun.py
+++ b/tensorflow/python/data/experimental/ops/map_defun.py
@@ -47,10 +47,12 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
if not isinstance(elems, list):
raise ValueError("`elems` must be a list of tensors.")
if not isinstance(output_dtypes, list):
- raise ValueError("`output_dtypes` must be a list of tensors.")
+ raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.")
if not isinstance(output_shapes, list):
- raise ValueError("`output_shapes` must be a list of tensors.")
+ raise ValueError("`output_shapes` must be a list of `tf.TensorShape` "
+ "objects.")
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
+ return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes,
+ output_shapes, fn)