aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-10-04 13:01:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 13:12:57 -0700
commit7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch)
tree84087a64563d10c3390f991c6263c7fa2cc65b11 /tensorflow/core/grappler
parent074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff)
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc1
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc3
2 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index ba521e79bc..a9254ed58b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
map_defun_node->add_input(input.name());
}
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
+ AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index a958d706c1..a6020e36bb 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
func.set_name(function_name);
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
graph_transforms::SetNodeAttr("output_types", output_types, node);
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
graph_transforms::SetNodeAttr("f", func, node);
@@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
+ Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
+ LOG(ERROR) << s;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));