Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.cc  116
1 file changed, 55 insertions(+), 61 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 344c420902..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -45,22 +45,6 @@ namespace {
// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;
-// Equivalent to python Pfor's WrappedTensor struct
-struct WrappedTensor {
- TensorDesc tensor;
-
- // Whether the tensor is stacked, i.e. represents the results of applying
- // the operation on all slices of the input, where each row i of the
- // tensor corresponds to the op's output on slice i of the input. False
- // if the tensor is not stacked, i.e. represents the result of the op on
- // a single slice of the input, where the result does not vary between
- // slices.
- bool stacked;
-
- WrappedTensor(TensorDesc&& tensor, bool stacked)
- : tensor(std::move(tensor)), stacked(stacked) {}
-};
-
const char* const kRetValOp = "_Retval";
void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
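The struct removed above is not redefined in this file, yet the hunks below brace-initialize WrappedTensor with three values and read found->node, found->output_index, and found->stacked, so it has presumably been moved to a shared header as a plain three-field struct. A minimal reconstruction consistent with those usages (a sketch, not the actual header):

// Hypothetical reconstruction; field names are inferred from the usages in
// the hunks below, not copied from the real header.
struct WrappedTensor {
  Node* node;        // op producing the tensor
  int output_index;  // output slot on `node`

  // Whether the tensor is stacked, i.e. row i holds the result of the op on
  // slice i of the input; false if the result is the same for every slice.
  bool stacked;

  WrappedTensor(Node* node, int output_index, bool stacked)
      : node(node), output_index(output_index), stacked(stacked) {}
};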
@@ -239,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
return errors::Unimplemented("No vectorizer registered for op: ",
op_node->type_string());
}
- std::vector<Port> input_ports, output_ports;
- input_ports.reserve(op_node->num_inputs());
- output_ports.reserve(op_node->num_outputs());
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- &input_ports, &output_ports));
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
- if (op_node->num_outputs() != output_ports.size() ||
- op_node->num_inputs() != input_ports.size() ||
- input_edges.size() != input_ports.size()) {
- return errors::Internal("Vectorizer inputs/outputs don't match.");
- }
-
- // Promote the inputs of the op to MapDefun outputs and connect the edges
- // accordingly.
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
for (size_t i = 0; i < op_node->num_inputs(); ++i) {
auto edge = input_edges[i];
- TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
- {edge->src(), edge->src_output()}));
- outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
- input_ports[i].first, input_ports[i].second);
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
+
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
// Add output mappings.
for (size_t i = 0; i < op_node->num_outputs(); ++i) {
- conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}});
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
return Status::OK();
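The call vectorizer->Vectorize(*op_node, outer_scope_.get(), std::move(inputs), &outputs) above implies that the vectorizer interface now exchanges WrappedTensors instead of the old Port vectors. A plausible shape for that interface, inferred from this call site alone (the real declaration lives in the vectorizer header, which this diff does not touch):

// Hypothetical sketch of the updated interface, not the actual declaration.
class Vectorizer {
 public:
  virtual ~Vectorizer() {}

  // Vectorizes `node` by adding ops to `outer_scope`, consuming the wrapped
  // (possibly stacked) `inputs` and emitting one WrappedTensor per output of
  // `node` into `outputs`.
  virtual Status Vectorize(const Node& node, Graph* outer_scope,
                           std::vector<WrappedTensor>&& inputs,
                           std::vector<WrappedTensor>* outputs) = 0;
};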
@@ -281,25 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) {
TensorDesc output({ret_edge->src(), ret_edge->src_output()});
TensorDesc converted_output;
- if (auto found = gtl::FindOrNull(conversion_map_, output)) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (found->stacked) {
- converted_output = found->tensor;
- } else {
- // Some outputs may be unstacked if they don't derive from arg nodes
- // (for example, if a function returns a constant). For these, we
- // have to add extra nodes to tile it in the 0th dimension.
- TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
- }
- } else {
- // Note: All unstacked nodes are converted ahead of time in `Initialize`,
- // and here we assume that all op vectorizers create only stacked outputs.
- // This may not hold in the future, as more vectorizers are added that
- // may actually create unstacked outputs. For example, see the `Shape`
- // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
- converted_output = conversion_map_.at(output).tensor;
+ found = &conversion_map_.at(output);
+ }
+
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile it in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
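Stripped of the graph plumbing, the restructured logic above is: reuse an existing mapping, otherwise convert the producer first, and finally stack anything that is still unstacked. A small self-contained analogue of that flow using plain containers (illustrative names only; the stand-ins are noted in the comments):

#include <map>
#include <utility>

struct MiniTensor { int node_id; int output_index; bool stacked; };
using MiniKey = std::pair<int, int>;  // (producer node id, output slot)

MiniTensor ResolveOutput(std::map<MiniKey, MiniTensor>* conversion_map,
                         const MiniKey& key) {
  auto it = conversion_map->find(key);
  if (it == conversion_map->end()) {
    // Stand-in for AddConversionMapping(output.first), which vectorizes the
    // producer and records its outputs in the map.
    it = conversion_map
             ->emplace(key,
                       MiniTensor{key.first, key.second, /*stacked=*/true})
             .first;
  }
  MiniTensor result = it->second;
  if (!result.stacked) {
    // Stand-in for StackTensor, which adds ExpandDims/Tile nodes; here we
    // only flip the flag to keep the sketch self-contained.
    result.stacked = true;
  }
  return result;
}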
@@ -455,7 +450,7 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* ones_shape;
TF_RETURN_IF_ERROR(node_builder("Shape")
- .Input(unstacked->tensor.first) // input
+ .Input(unstacked->node) // input
.Finalize(g, &ones_shape));
Node* ones;
@@ -473,8 +468,8 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* expand_dims;
TF_RETURN_IF_ERROR(node_builder("ExpandDims")
- .Input(unstacked->tensor.first) // input
- .Input(const_0) // dim
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
.Finalize(g, &expand_dims));
TF_RETURN_IF_ERROR(node_builder("Tile")
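The subgraph patched above (Shape, ExpandDims, Tile now fed from the wrapped node rather than the old pair) gives the unstacked tensor a new leading dimension and repeats it once per input slice along that axis. A TF-free illustration of the resulting shape (hypothetical helper, for exposition only):

#include <vector>

// Shape of the stacked result: an unstacked tensor of shape `dims` is
// expanded to rank + 1 and tiled `num_slices` times along dimension 0.
std::vector<int> StackedShape(const std::vector<int>& dims, int num_slices) {
  std::vector<int> out;
  out.reserve(dims.size() + 1);
  out.push_back(num_slices);                        // new leading slice axis
  out.insert(out.end(), dims.begin(), dims.end());  // original dims unchanged
  return out;
}
// e.g. StackedShape({3, 4}, /*num_slices=*/8) yields {8, 3, 4}.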
@@ -491,11 +486,11 @@ Status Vectorization::AddArgNodeMappings() {
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
// Control inputs
conversion_map_.insert({{arg_node, Graph::kControlSlot},
- {{input_node, Graph::kControlSlot}, true}});
+ {input_node, Graph::kControlSlot, true}});
}
return Status::OK();
}
@@ -541,7 +536,7 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
if (auto found = gtl::FindOrNull(conversion_map_,
{edge->src(), edge->src_output()})) {
- outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node,
+ outer_scope_->AddEdge(found->node, found->output_index, node,
edge->dst_input());
} else {
status->Update(errors::Internal(
@@ -552,11 +547,10 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
// Add output mappings
for (int i = 0; i < tensor.first->num_outputs(); ++i) {
- conversion_map_.insert(
- {{tensor.first, i}, WrappedTensor({node, i}, false)});
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
}
conversion_map_.insert({{tensor.first, Graph::kControlSlot},
- WrappedTensor({node, Graph::kControlSlot}, false)});
+ WrappedTensor(node, Graph::kControlSlot, false)});
return true;
}
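The mapping convention in this last hunk: every data output i of the original node maps to output i of its copy in the outer scope, marked unstacked, and the control slot (Graph::kControlSlot) is mapped as well so control edges can be redirected later. A minimal analogue with plain containers (illustrative types, not the real API):

#include <map>
#include <utility>

struct MiniWrapped { int node_id; int output_index; bool stacked; };

// Mirrors the tail of AddUnstackedNodeMappingsHelper: map each data output of
// `original_id`, plus its control slot, to the copied node `copy_id`.
void AddUnstackedOutputMappings(
    std::map<std::pair<int, int>, MiniWrapped>* conversion_map,
    int original_id, int copy_id, int num_outputs) {
  const int kControlSlot = -1;  // value of Graph::kControlSlot in TensorFlow
  for (int i = 0; i < num_outputs; ++i) {
    conversion_map->insert(
        {{original_id, i}, {copy_id, i, /*stacked=*/false}});
  }
  conversion_map->insert({{original_id, kControlSlot},
                          {copy_id, kControlSlot, /*stacked=*/false}});
}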