diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/map_vectorization.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/map_vectorization.cc | 31 |
1 files changed, 18 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index a9254ed58b..0576d075c2 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -60,14 +60,24 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, graph_utils::CopyAttribute(k, map_node, map_defun_node); } + // Note that the inputs to the function are either regular arguments (for + // which the function is mapped across their 0th dimension) or captured inputs + // (for which the function takes the argument wholesale). We can infer + // the split between these arguments from the `map_node`'s attrs. + // The Targuments attr on `map_node` corresponds to a list of types of + // MapDataset's captured inputs. + auto t_captured = map_node.attr().at("Targuments"); + // Get types of input arguments from original map function - AttrValue t_args; + DataTypeVector t_args; // Regular arguments for (const auto& input : vectorized_func->signature().input_arg()) { - t_args.mutable_list()->add_type(input.type()); + t_args.push_back(input.type()); map_defun_node->add_input(input.name()); } - (*map_defun_node->mutable_attr())["Targuments"] = t_args; - AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node); + // Erase the captured arguments from Targuments + t_args.erase(t_args.end() - t_captured.list().type_size(), t_args.end()); + AddNodeAttr("Targuments", t_args, map_defun_node); + AddNodeAttr("Tcaptured", t_captured, map_defun_node); // Set return values to match output names string output_prefix = strings::StrCat(map_defun_node->name(), ":output:"); @@ -96,7 +106,9 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, *vectorized_func, map_defun_node, library, &result); if (!s.ok()) { - LOG(ERROR) << "VectorizeMapDefun failed: " << s; + LOG(WARNING) << "VectorizeMapDefun failed. The function will only be " + "naively vectorized with MapDefun. Reason: " + << s; return vectorized_func; } return result; @@ -129,10 +141,6 @@ bool IsStatefulFn(const FunctionLibraryDefinition& library, return false; } -bool HasCapturedInputs(const NodeDef& map_node) { - return map_node.attr().at("Targuments").list().type_size() > 0; -} - NodeDef MakeNewBatchNode(const NodeDef& old_batch_node, const NodeDef& input_node, const FunctionDef& vectorized_func, @@ -239,15 +247,12 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item, // Check that this is a valid optimization. if (!IsOutputShapesFullyDefined(*input_node) || !IsOutputShapesFullyDefined(*map_node) || - IsStatefulFn(function_library, *orig_func) || - HasCapturedInputs(*map_node)) { + IsStatefulFn(function_library, *orig_func)) { // 1. If any of the inputs have an unknown shape, don't optimize, since // inputs might not be batchable. // 2. If any of the map func outputs have an unknown shape, don't // optimize, so that batching errors surface as before. // 3. If the function is stateful, don't vectorize it. - // 4. TODO(rachelim): Make this work for MapDataset with captured inputs - // by tiling inputs or modifying the signature of MapDefun. continue; } |