aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc21
1 files changed, 12 insertions, 9 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index d977ff3198..8b93b1f2b8 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -64,9 +64,18 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
}
}
+// Update node attrs to keep its properties consistent with the function
+void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
+ map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);
+
+ // TODO(rachelim): Propagate precise shapes if they're known, which may enable
+ // subsequent optimizations.
+ map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
+ map_defun_fn->ret_types.size()));
+}
+
Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
const TensorDesc& output) {
- // Note that we don't update MapDefun attrs as we go, only when we are done
DataType type = output.first->output_type(output.second);
int index = map_defun_fn->ret_nodes.size();
@@ -83,13 +92,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
map_defun_fn->ret_nodes.push_back(ret_node);
map_defun_fn->ret_types.push_back(type);
+ UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
return s;
}
void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
FunctionBody* map_defun_fn, Node* map_defun_node) {
- // Note that we don't update MapDefun attrs as we go, only when we are done
DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
<< "Trying to remove output that doesn't exist. Output number: "
<< output_position;
@@ -102,6 +111,7 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
output_position);
map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
output_position);
+ UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
// Renumber the nodes and edges that come after
for (int i = 0; i < num_later_outputs; ++i) {
@@ -342,13 +352,6 @@ void Vectorization::VectorizeHelper() {
// need the MapDefun node and can delete it.
if (map_defun_fn_->ret_nodes.empty()) {
outer_scope_->RemoveNode(map_defun_node_);
- } else {
- // Update MapDefun node attrs accordingly
- DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
- map_defun_node_->AddAttr(
- "output_shapes",
- std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
- map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
}