author    Rachel Lim <rachelim@google.com>            2018-10-09 10:40:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-09 10:45:19 -0700
commit 11f32ebbdcd4eaf5e9e09fe27571e26ec0bd9dd8
tree   032a600b39f926c9ec6ab62625b5a5fd03f20e87
parent 1b4402137a76c8085c160edfcc0c3be3cfa8fa3a
[tf.data vectorization] Handle captured inputs in MapVectorization optimization
PiperOrigin-RevId: 216381943
 tensorflow/core/grappler/optimizers/data/map_vectorization.cc                           | 31
 tensorflow/core/grappler/optimizers/data/vectorization_utils.cc                          | 35
 tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py  |  9
 3 files changed, 53 insertions(+), 22 deletions(-)
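
For context, the pipeline shape this change targets can be sketched at the Python level. This is a hedged illustration mirroring the updated test further below; the dataset contents and the trailing batch call are assumptions for the example, not part of this commit.

import tensorflow as tf

# `y` and `z` are captured by `map_fn` from the enclosing scope rather than
# passed as per-element arguments; previously such captures disabled the
# MapVectorization optimization.
y = tf.constant(1, shape=(2,))
z = tf.constant(2, shape=(2,))

def map_fn(x):
  return x, y, z

dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]]).repeat(5)
# A map followed by a batch is the pattern the MapVectorization pass rewrites;
# after this change it can do so even when the map function has captured inputs.
dataset = dataset.map(map_fn).batch(2)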
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a9254ed58b..0576d075c2 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -60,14 +60,24 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
graph_utils::CopyAttribute(k, map_node, map_defun_node);
}
+ // Note that the inputs to the function are either regular arguments (for
+ // which the function is mapped across their 0th dimension) or captured inputs
+ // (for which the function takes the argument wholesale). We can infer
+ // the split between these arguments from the `map_node`'s attrs.
+ // The Targuments attr on `map_node` corresponds to a list of types of
+ // MapDataset's captured inputs.
+ auto t_captured = map_node.attr().at("Targuments");
+
// Get types of input arguments from original map function
- AttrValue t_args;
+ DataTypeVector t_args; // Regular arguments
for (const auto& input : vectorized_func->signature().input_arg()) {
- t_args.mutable_list()->add_type(input.type());
+ t_args.push_back(input.type());
map_defun_node->add_input(input.name());
}
- (*map_defun_node->mutable_attr())["Targuments"] = t_args;
- AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
+ // Erase the captured arguments from Targuments
+ t_args.erase(t_args.end() - t_captured.list().type_size(), t_args.end());
+ AddNodeAttr("Targuments", t_args, map_defun_node);
+ AddNodeAttr("Tcaptured", t_captured, map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
@@ -96,7 +106,9 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
*vectorized_func, map_defun_node, library, &result);
if (!s.ok()) {
- LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ LOG(WARNING) << "VectorizeMapDefun failed. The function will only be "
+ "naively vectorized with MapDefun. Reason: "
+ << s;
return vectorized_func;
}
return result;
@@ -129,10 +141,6 @@ bool IsStatefulFn(const FunctionLibraryDefinition& library,
return false;
}
-bool HasCapturedInputs(const NodeDef& map_node) {
- return map_node.attr().at("Targuments").list().type_size() > 0;
-}
-
NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
const NodeDef& input_node,
const FunctionDef& vectorized_func,
@@ -239,15 +247,12 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
// Check that this is a valid optimization.
if (!IsOutputShapesFullyDefined(*input_node) ||
!IsOutputShapesFullyDefined(*map_node) ||
- IsStatefulFn(function_library, *orig_func) ||
- HasCapturedInputs(*map_node)) {
+ IsStatefulFn(function_library, *orig_func)) {
// 1. If any of the inputs have an unknown shape, don't optimize, since
// inputs might not be batchable.
// 2. If any of the map func outputs have an unknown shape, don't
// optimize, so that batching errors surface as before.
// 3. If the function is stateful, don't vectorize it.
- // 4. TODO(rachelim): Make this work for MapDataset with captured inputs
- // by tiling inputs or modifying the signature of MapDefun.
continue;
}
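
The attr bookkeeping described in the comment above can be illustrated with a small plain-Python sketch (hypothetical type lists, not the grappler code): the map function's signature lists regular arguments first and captured inputs last, and MapDataset's own Targuments attr holds only the captured types, so the split point is recovered by counting from the end.

# Hypothetical input signature of the map function: x (regular), then y, z (captured).
func_input_types = ["DT_INT32", "DT_INT32", "DT_INT32"]
# MapDataset's "Targuments" attr lists the types of its captured inputs only.
map_dataset_targuments = ["DT_INT32", "DT_INT32"]

num_captured = len(map_dataset_targuments)
num_regular = len(func_input_types) - num_captured
# Mirrors the t_args.erase(...) above: trim the captured types off the tail.
map_defun_targuments = func_input_types[:num_regular]
map_defun_tcaptured = func_input_types[num_regular:]

assert map_defun_targuments == ["DT_INT32"]
assert map_defun_tcaptured == ["DT_INT32", "DT_INT32"]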
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index ba857ab5d9..d977ff3198 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -253,8 +253,13 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
}
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- std::move(inputs), &outputs));
+ Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs);
+ if (!s.ok()) {
+ VLOG(2) << "Vectorizer for op \"" << op_node->type_string()
+ << "\" failed with error: " << s;
+ return s;
+ }
if (op_node->num_outputs() != outputs.size()) {
return errors::Internal(
@@ -481,17 +486,37 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
}
Status Vectorization::AddArgNodeMappings() {
- for (auto arg_node : map_defun_fn_->arg_nodes) {
+ // Note that inputs to map_defun_fn_ are either regular arguments (for which
+ // the operations are mapped across their 0th dimension) or captured inputs
+ // (for which the operations apply to the argument wholesale).
+ int num_args =
+ map_defun_node_->attrs().Find("Targuments")->list().type_size();
+
+ auto add_conversion = [this](Node* arg_node, bool stacked) {
Node* input_node;
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}});
// Control inputs
conversion_map_.insert({{arg_node, Graph::kControlSlot},
- {input_node, Graph::kControlSlot, true}});
+ {input_node, Graph::kControlSlot, stacked}});
+
+ return Status::OK();
+ };
+
+ // Regular arguments
+ for (int i = 0; i < num_args; ++i) {
+ TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true));
+ }
+
+ // Captured inputs. These are applied (without slicing) to every iteration of
+ // the map function, hence are mapped to unstacked nodes.
+ for (int i = num_args; i < map_defun_fn_->arg_nodes.size(); ++i) {
+ TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false));
}
+
return Status::OK();
}
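
The stacked/unstacked distinction used in the mapping above can be pictured numerically. This is a rough NumPy sketch of the semantics, not the grappler rewrite itself: a regular argument gains a leading batch dimension and is mapped over it, while a captured input is used wholesale by every element, which is exactly what broadcasting gives the vectorized computation.

import numpy as np

batch = np.array([[1, 2], [3, 4]])   # regular argument: stacked along dim 0
y = np.array([1, 1])                 # captured input: unstacked, used wholesale

# What MapDefun computes: apply the per-element function to each slice.
per_element = np.stack([x + y for x in batch])
# What the vectorized graph computes: one op over the whole stacked input.
vectorized = batch + y

assert np.array_equal(per_element, vectorized)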
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index 971a2d94b9..803ff87924 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -105,15 +105,16 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def testOptimizationWithCapturedInputs(self):
# Tests that vectorization works with captured inputs
+ y = constant_op.constant(1, shape=(2,))
+ z = constant_op.constant(2, shape=(2,))
+
def map_fn(x):
- return x + y
+ return x, y, z
- y = constant_op.constant(1, shape=(2,))
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
- # TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
+ base_dataset, map_fn, expect_optimized=True)
self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):