diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 35 |
1 files changed, 30 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index ba857ab5d9..d977ff3198 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -253,8 +253,13 @@ Status Vectorization::AddConversionMapping(Node* op_node) { } } - TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), - std::move(inputs), &outputs)); + Status s = vectorizer->Vectorize(*op_node, outer_scope_.get(), + std::move(inputs), &outputs); + if (!s.ok()) { + VLOG(2) << "Vectorizer for op \"" << op_node->type_string() + << "\" failed with error: " << s; + return s; + } if (op_node->num_outputs() != outputs.size()) { return errors::Internal( @@ -481,17 +486,37 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked, } Status Vectorization::AddArgNodeMappings() { - for (auto arg_node : map_defun_fn_->arg_nodes) { + // Note that inputs to map_defun_fn_ are either regular arguments (for which + // the operations are mapped across their 0th dimension) or captured inputs + // (for which the operations apply to the argument wholesale). + int num_args = + map_defun_node_->attrs().Find("Targuments")->list().type_size(); + + auto add_conversion = [this](Node* arg_node, bool stacked) { Node* input_node; TF_RETURN_IF_ERROR(map_defun_node_->input_node( arg_node->attrs().Find("index")->i(), &input_node)); - conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}}); + conversion_map_.insert({{arg_node, 0}, {input_node, 0, stacked}}); // Control inputs conversion_map_.insert({{arg_node, Graph::kControlSlot}, - {input_node, Graph::kControlSlot, true}}); + {input_node, Graph::kControlSlot, stacked}}); + + return Status::OK(); + }; + + // Regular arguments + for (int i = 0; i < num_args; ++i) { + TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], true)); + } + + // Captured inputs. These are applied (without slicing) to every iteration of + // the map function, hence are mapped to unstacked nodes. + for (int i = num_args; i < map_defun_fn_->arg_nodes.size(); ++i) { + TF_RETURN_IF_ERROR(add_conversion(map_defun_fn_->arg_nodes[i], false)); } + return Status::OK(); } |