aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc21
1 files changed, 9 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 8b93b1f2b8..d977ff3198 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -64,18 +64,9 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
}
}
-// Update node attrs to keep its properties consistent with the function
-void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
- map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);
-
- // TODO(rachelim): Propagate precise shapes if they're known, which may enable
- // subsequent optimizations.
- map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
- map_defun_fn->ret_types.size()));
-}
-
Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
DataType type = output.first->output_type(output.second);
int index = map_defun_fn->ret_nodes.size();
@@ -92,13 +83,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
map_defun_fn->ret_nodes.push_back(ret_node);
map_defun_fn->ret_types.push_back(type);
- UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
return s;
}
void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
<< "Trying to remove output that doesn't exist. Output number: "
<< output_position;
@@ -111,7 +102,6 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
output_position);
map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
output_position);
- UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
// Renumber the nodes and edges that come after
for (int i = 0; i < num_later_outputs; ++i) {
@@ -352,6 +342,13 @@ void Vectorization::VectorizeHelper() {
// need the MapDefun node and can delete it.
if (map_defun_fn_->ret_nodes.empty()) {
outer_scope_->RemoveNode(map_defun_node_);
+ } else {
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
}