From 950cf87104bfee28e2165fe368f66337b8a1336d Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Tue, 9 Oct 2018 14:36:33 -0700 Subject: [tf.data vectorization] Add vectorizer for `Add` op PiperOrigin-RevId: 216424512 --- .../grappler/optimizers/data/vectorization_utils.cc | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc') diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index d977ff3198..8b93b1f2b8 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -64,9 +64,18 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, } } +// Update node attrs to keep its properties consistent with the function +void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) { + map_defun_node->AddAttr("output_types", map_defun_fn->ret_types); + + // TODO(rachelim): Propagate precise shapes if they're known, which may enable + // subsequent optimizations. + map_defun_node->AddAttr("output_shapes", std::vector( + map_defun_fn->ret_types.size())); +} + Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, const TensorDesc& output) { - // Note that we don't update MapDefun attrs as we go, only when we are done DataType type = output.first->output_type(output.second); int index = map_defun_fn->ret_nodes.size(); @@ -83,13 +92,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0); map_defun_fn->ret_nodes.push_back(ret_node); map_defun_fn->ret_types.push_back(type); + UpdateMapDefunAttrs(map_defun_fn, map_defun_node); return s; } void RemoveMapDefunOutput(int output_position, Graph* outer_scope, FunctionBody* map_defun_fn, Node* map_defun_node) { - // Note that we don't update MapDefun attrs as we go, only when we are done DCHECK_LT(output_position, map_defun_fn->ret_nodes.size()) << "Trying to remove output that doesn't exist. Output number: " << output_position; @@ -102,6 +111,7 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope, output_position); map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() + output_position); + UpdateMapDefunAttrs(map_defun_fn, map_defun_node); // Renumber the nodes and edges that come after for (int i = 0; i < num_later_outputs; ++i) { @@ -342,13 +352,6 @@ void Vectorization::VectorizeHelper() { // need the MapDefun node and can delete it. if (map_defun_fn_->ret_nodes.empty()) { outer_scope_->RemoveNode(map_defun_node_); - } else { - // Update MapDefun node attrs accordingly - DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size()); - map_defun_node_->AddAttr( - "output_shapes", - std::vector(map_defun_fn_->ret_types.size())); - map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); } } -- cgit v1.2.3