diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-09 10:40:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 10:45:19 -0700 |
commit | 11f32ebbdcd4eaf5e9e09fe27571e26ec0bd9dd8 (patch) | |
tree | 032a600b39f926c9ec6ab62625b5a5fd03f20e87 | |
parent | 1b4402137a76c8085c160edfcc0c3be3cfa8fa3a (diff) |
[tf.data vectorization] Handle captured inputs in MapVectorization optimization
PiperOrigin-RevId: 216381943
3 files changed, 53 insertions, 22 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index a9254ed58b..0576d075c2 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -60,14 +60,24 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, graph_utils::CopyAttribute(k, map_node, map_defun_node); } + // Note that the inputs to the function are either regular arguments (for + // which the function is mapped across their 0th dimension) or captured inputs + // (for which the function takes the argument wholesale). We can infer + // the split between these arguments from the `map_node`'s attrs. + // The Targuments attr on `map_node` corresponds to a list of types of + // MapDataset's captured inputs. + auto t_captured = map_node.attr().at("Targuments"); + // Get types of input arguments from original map function - AttrValue t_args; + DataTypeVector t_args; // Regular arguments for (const auto& input : vectorized_func->signature().input_arg()) { - t_args.mutable_list()->add_type(input.type()); + t_args.push_back(input.type()); map_defun_node->add_input(input.name()); } - (*map_defun_node->mutable_attr())["Targuments"] = t_args; - AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node); + // Erase the captured arguments from Targuments + t_args.erase(t_args.end() - t_captured.list().type_size(), t_args.end()); + AddNodeAttr("Targuments", t_args, map_defun_node); + AddNodeAttr("Tcaptured", t_captured, map_defun_node); // Set return values to match output names string output_prefix = strings::StrCat(map_defun_node->name(), ":output:"); @@ -96,7 +106,9 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, *vectorized_func, map_defun_node, library, &result); if (!s.ok()) { - LOG(ERROR) << "VectorizeMapDefun failed: " << s; + LOG(WARNING) << "VectorizeMapDefun failed. The function will only be " + "naively vectorized with MapDefun. Reason: " + << s; return vectorized_func; } return result; @@ -129,10 +141,6 @@ bool IsStatefulFn(const FunctionLibraryDefinition& library, return false; } -bool HasCapturedInputs(const NodeDef& map_node) { - return map_node.attr().at("Targuments").list().type_size() > 0; -} - NodeDef MakeNewBatchNode(const NodeDef& old_batch_node, const NodeDef& input_node, const FunctionDef& vectorized_func, @@ -239,15 +247,12 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item, // Check that this is a valid optimization. if (!IsOutputShapesFullyDefined(*input_node) || !IsOutputShapesFullyDefined(*map_node) || - IsStatefulFn(function_library, *orig_func) || - HasCapturedInputs(*map_node)) { + IsStatefulFn(function_library, *orig_func)) { // 1. If any of the inputs have an unknown shape, don't optimize, since // inputs might not be batchable. // 2. If any of the map func outputs have an unknown shape, don't // optimize, so that batching errors surface as before. // 3. If the function is stateful, don't vectorize it. - // 4. TODO(rachelim): Make this work for MapDataset with captured inputs - // by tiling inputs or modifying the signature of MapDefun. continue; } diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index ba857ab5d9..d977ff3198 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -253,8 +253,13 @@ Status Vectorization::AddConversionMapping(Node* op_node) { } } - TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), - std::move(inputs), &outputs)); + Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(), + std::move(inputs), &outputs); + if (!s.ok()) { + VLOG(2) << "Vectorizer for op \"" << op_node->type_string() + << "\" failed with error: " << s; + return s; + } if (op_node->num_outputs() != outputs.size()) { return errors::Internal( @@ -481,17 +486,37 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked, } Status Vectorization::AddArgNodeMappings() { - for (auto arg_node : map_defun_fn_->arg_nodes) { + // Note that inputs to map_defun_fn_ are either regular arguments (for which + // the operations are mapped across their 0th dimension) or captured inputs + // (for which the operations apply to the argument wholesale). + int num_args = + map_defun_node_->attrs().Find("Targuments")->list().type_size(); + + auto add_conversion = [this](Node* arg_node, bool stacked) { Node* input_node; TF_RETURN_IF_ERROR(map_defun_node_->input_node( arg_node->attrs().Find("index")->i(), &input_node)); - conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}}); + conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}}); // Control inputs conversion_map_.insert({{arg_node, Graph::kControlSlot}, - {input_node, Graph::kControlSlot, true}}); + {input_node, Graph::kControlSlot, stacked}}); + + return Status::OK(); + }; + + // Regular arguments + for (int i = 0; i < num_args; ++i) { + TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true)); + } + + // Captured inputs. These are applied (without slicing) to every iteration of + // the map function, hence are mapped to unstacked nodes. + for (int i = num_args; i < map_defun_fn_->arg_nodes.size(); ++i) { + TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false)); } + return Status::OK(); } diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index 971a2d94b9..803ff87924 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -105,15 +105,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): def testOptimizationWithCapturedInputs(self): # Tests that vectorization works with captured inputs + y = constant_op.constant(1, shape=(2,)) + z = constant_op.constant(2, shape=(2,)) + def map_fn(x): - return x + y + return x, y, z - y = constant_op.constant(1, shape=(2,)) base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], [3, 4]]).repeat(5) - # TODO(rachelim): when this optimization works, turn on expect_optimized unoptimized, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) + base_dataset, map_fn, expect_optimized=True) self.assertDatasetsEqual(optimized, unoptimized) def testOptimizationIgnoreStateful(self): |