diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-04 13:01:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:12:57 -0700 |
commit | 7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (patch) | |
tree | 84087a64563d10c3390f991c6263c7fa2cc65b11 /tensorflow/core/grappler | |
parent | 074ff471fefbcf3bfd49914ad80bd9f9751df363 (diff) |
[tf.data] Add a notion of `captured args` to MapDefun
PiperOrigin-RevId: 215788485
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/map_vectorization.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | 3 |
2 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index ba521e79bc..a9254ed58b 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, map_defun_node->add_input(input.name()); } (*map_defun_node->mutable_attr())["Targuments"] = t_args; + AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node); // Set return values to match output names string output_prefix = strings::StrCat(map_defun_node->name(), ":output:"); diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index a958d706c1..a6020e36bb 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, func.set_name(function_name); NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn); graph_transforms::SetNodeAttr("Targuments", t_arguments, node); + graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node); graph_transforms::SetNodeAttr("output_types", output_types, node); graph_transforms::SetNodeAttr("output_shapes", output_shapes, node); graph_transforms::SetNodeAttr("f", func, node); @@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; + Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized); + LOG(ERROR) << s; EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); |