about summary refs log tree commit diff homepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization_utils.cc 35
1 files changed, 30 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index ba857ab5d9..d977ff3198 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -253,8 +253,13 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
}
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- std::move(inputs), &outputs));
+ Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs);
+ if (!s.ok()) {
+ VLOG(2) << "Vectorizer for op \"" << op_node->type_string()
+ << "\" failed with error: " << s;
+ return s;
+ }
if (op_node->num_outputs() != outputs.size()) {
return errors::Internal(
@@ -481,17 +486,37 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
}
Status Vectorization::AddArgNodeMappings() {
- for (auto arg_node : map_defun_fn_->arg_nodes) {
+ // The first num_args arg_nodes (num_args = size of the "Targuments" type list)
+ // are regular arguments, mapped across dim 0; the rest are captured inputs,
+ // used wholesale. NOTE(review): Find() is dereferenced unchecked -- confirm.
+ int num_args =
+ map_defun_node_->attrs().Find("Targuments")->list().type_size();
+
+ auto add_conversion = [this](Node* arg_node, bool stacked) {
Node* input_node;
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}});
// Control inputs: the same mapping, keyed on Graph::kControlSlot.
conversion_map_.insert({{arg_node, Graph::kControlSlot},
- {input_node, Graph::kControlSlot, true}});
+ {input_node, Graph::kControlSlot, stacked}});
+
+ return Status::OK();
+ };
+
+ // Regular arguments: registered with stacked = true.
+ for (int i = 0; i < num_args; ++i) {
+ TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true));
+ }
+
+ // Captured inputs are applied (without slicing) to every iteration of the map
+ // function, hence unstacked. NOTE(review): int `i` is compared with size().
+ for (int i = num_args; i < map_defun_fn_->arg_nodes.size(); ++i) {
+ TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false));
}
+
return Status::OK();
}