aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/map_vectorization.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc31
1 files changed, 18 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a9254ed58b..0576d075c2 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -60,14 +60,24 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
graph_utils::CopyAttribute(k, map_node, map_defun_node);
}
+ // Note that the inputs to the function are either regular arguments (for
+ // which the function is mapped across their 0th dimension) or captured inputs
+ // (for which the function takes the argument wholesale). We can infer
+ // the split between these arguments from the `map_node`'s attrs.
+ // The Targuments attr on `map_node` corresponds to a list of types of
+ // MapDataset's captured inputs.
+ auto t_captured = map_node.attr().at("Targuments");
+
// Get types of input arguments from original map function
- AttrValue t_args;
+ DataTypeVector t_args; // Regular arguments
for (const auto& input : vectorized_func->signature().input_arg()) {
- t_args.mutable_list()->add_type(input.type());
+ t_args.push_back(input.type());
map_defun_node->add_input(input.name());
}
- (*map_defun_node->mutable_attr())["Targuments"] = t_args;
- AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
+ // Erase the captured arguments from Targuments
+ t_args.erase(t_args.end() - t_captured.list().type_size(), t_args.end());
+ AddNodeAttr("Targuments", t_args, map_defun_node);
+ AddNodeAttr("Tcaptured", t_captured, map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
@@ -96,7 +106,9 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
*vectorized_func, map_defun_node, library, &result);
if (!s.ok()) {
- LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ LOG(WARNING) << "VectorizeMapDefun failed. The function will only be "
+ "naively vectorized with MapDefun. Reason: "
+ << s;
return vectorized_func;
}
return result;
@@ -129,10 +141,6 @@ bool IsStatefulFn(const FunctionLibraryDefinition& library,
return false;
}
-bool HasCapturedInputs(const NodeDef& map_node) {
- return map_node.attr().at("Targuments").list().type_size() > 0;
-}
-
NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
const NodeDef& input_node,
const FunctionDef& vectorized_func,
@@ -239,15 +247,12 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
// Check that this is a valid optimization.
if (!IsOutputShapesFullyDefined(*input_node) ||
!IsOutputShapesFullyDefined(*map_node) ||
- IsStatefulFn(function_library, *orig_func) ||
- HasCapturedInputs(*map_node)) {
+ IsStatefulFn(function_library, *orig_func)) {
// 1. If any of the inputs have an unknown shape, don't optimize, since
// inputs might not be batchable.
// 2. If any of the map func outputs have an unknown shape, don't
// optimize, so that batching errors surface as before.
// 3. If the function is stateful, don't vectorize it.
- // 4. TODO(rachelim): Make this work for MapDataset with captured inputs
- // by tiling inputs or modifying the signature of MapDefun.
continue;
}