diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 116 |
1 files changed, 55 insertions, 61 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 344c420902..ba857ab5d9 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -45,22 +45,6 @@ namespace { // Describes a tensor with its operation Node and output position typedef std::pair<Node*, int> TensorDesc; -// Equivalent to python Pfor's WrappedTensor struct -struct WrappedTensor { - TensorDesc tensor; - - // Whether the tensor is stacked, i.e. represents the results of applying - // the operation on all slices of the input, where each row i of the - // tensor corresponds to the op's output on slice i of the input. False - // if the tensor is not stacked, i.e. represents the result of the op on - // a single slice of the input, where the result does not vary between - // slices. - bool stacked; - - WrappedTensor(TensorDesc&& tensor, bool stacked) - : tensor(std::move(tensor)), stacked(stacked) {} -}; - const char* const kRetValOp = "_Retval"; void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, @@ -239,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) { return errors::Unimplemented("No vectorizer registered for op: ", op_node->type_string()); } - std::vector<Port> input_ports, output_ports; - input_ports.reserve(op_node->num_inputs()); - output_ports.reserve(op_node->num_outputs()); - TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), - &input_ports, &output_ports)); + std::vector<WrappedTensor> inputs, outputs; + inputs.reserve(op_node->num_inputs()); + outputs.reserve(op_node->num_outputs()); std::vector<const Edge*> input_edges; TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges)); - if (op_node->num_outputs() != output_ports.size() || - op_node->num_inputs() != input_ports.size() || - input_edges.size() != input_ports.size()) { - return errors::Internal("Vectorizer inputs/outputs don't match."); - } - - // Promote the inputs of the op to MapDefun outputs and connect the edges - // accordingly. + // The inputs for the node to be converted may already have been converted + // themselves. For those that are not, we promote them to MapDefun outputs. for (size_t i = 0; i < op_node->num_inputs(); ++i) { auto edge = input_edges[i]; - TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, - {edge->src(), edge->src_output()})); - outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1, - input_ports[i].first, input_ports[i].second); + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + inputs.push_back(*found); + } else { + // TODO(rachelim): Handle the case where unconverted inputs are unstacked. + // We assume that all unconverted inputs will be stacked, since we + // converted all unstacked nodes in `Initialize`. However, it's actually + // possible that yet-unconverted nodes may produce unstacked outputs after + // they are vectorized. (For example, see the "Shape" converter in + // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects + // an unstacked input but receives a stacked one, vectorizer->Vectorize + // will return an error. + TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, + {edge->src(), edge->src_output()})); + int output_index = map_defun_fn_->ret_nodes.size() - 1; + inputs.push_back({map_defun_node_, output_index, true}); + } + } + + TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), + std::move(inputs), &outputs)); + + if (op_node->num_outputs() != outputs.size()) { + return errors::Internal( + "Number of vectorizer outputs does not match. Expected: ", + op_node->num_outputs(), " Actual: ", outputs.size()); } // Add output mappings. for (size_t i = 0; i < op_node->num_outputs(); ++i) { - conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}}); + conversion_map_.insert({{op_node, i}, outputs[i]}); } return Status::OK(); @@ -281,25 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) { TensorDesc output({ret_edge->src(), ret_edge->src_output()}); TensorDesc converted_output; - if (auto found = gtl::FindOrNull(conversion_map_, output)) { - // It's possible the output already has a mapping, if it comes from a node - // that has already been converted. - if (found->stacked) { - converted_output = found->tensor; - } else { - // Some outputs may be unstacked if they don't derive from arg nodes - // (for example, if a function returns a constant). For these, we - // have to add extra nodes to tile it in the 0th dimension. - TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); - } - } else { - // Note: All unstacked nodes are converted ahead of time in `Initialize`, - // and here we assume that all op vectorizers create only stacked outputs. - // This may not hold in the future, as more vectorizers are added that - // may actually create unstacked outputs. For example, see the `Shape` - // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py + + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. + auto found = gtl::FindOrNull(conversion_map_, output); + if (!found) { TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); - converted_output = conversion_map_.at(output).tensor; + found = &conversion_map_.at(output); + } + + if (found->stacked) { + converted_output = {found->node, found->output_index}; + } else { + // Some outputs may be unstacked if they don't derive from arg nodes + // (for example, if a function returns a constant). For these, we + // have to add extra nodes to tile it in the 0th dimension. + TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); } ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, @@ -455,7 +450,7 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked, Node* ones_shape; TF_RETURN_IF_ERROR(node_builder("Shape") - .Input(unstacked->tensor.first) // input + .Input(unstacked->node) // input .Finalize(g, &ones_shape)); Node* ones; @@ -473,8 +468,8 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked, Node* expand_dims; TF_RETURN_IF_ERROR(node_builder("ExpandDims") - .Input(unstacked->tensor.first) // input - .Input(const_0) // dim + .Input(unstacked->node) // input + .Input(const_0) // dim .Finalize(g, &expand_dims)); TF_RETURN_IF_ERROR(node_builder("Tile") @@ -491,11 +486,11 @@ Status Vectorization::AddArgNodeMappings() { TF_RETURN_IF_ERROR(map_defun_node_->input_node( arg_node->attrs().Find("index")->i(), &input_node)); - conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}}); + conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}}); // Control inputs conversion_map_.insert({{arg_node, Graph::kControlSlot}, - {{input_node, Graph::kControlSlot}, true}}); + {input_node, Graph::kControlSlot, true}}); } return Status::OK(); } @@ -541,7 +536,7 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, if (auto found = gtl::FindOrNull(conversion_map_, {edge->src(), edge->src_output()})) { - outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node, + outer_scope_->AddEdge(found->node, found->output_index, node, edge->dst_input()); } else { status->Update(errors::Internal( @@ -552,11 +547,10 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, // Add output mappings for (int i = 0; i < tensor.first->num_outputs(); ++i) { - conversion_map_.insert( - {{tensor.first, i}, WrappedTensor({node, i}, false)}); + conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)}); } conversion_map_.insert({{tensor.first, Graph::kControlSlot}, - WrappedTensor({node, Graph::kControlSlot}, false)}); + WrappedTensor(node, Graph::kControlSlot, false)}); return true; } |