diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-10 02:42:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-10 02:47:15 -0700 |
commit | d6a3d6a8295359364c86aecc479e6392bcde0ce4 (patch) | |
tree | 98658454a85871179cf61e734d2edeb4abab024a /tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | |
parent | dd7d31fa7bfa357e58987c2f3881d99c8050b6de (diff) |
Automated rollback of commit 950cf87104bfee28e2165fe368f66337b8a1336d
PiperOrigin-RevId: 216500702
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 21 |
1 files changed, 9 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 8b93b1f2b8..d977ff3198 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -64,18 +64,9 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, } } -// Update node attrs to keep its properties consistent with the function -void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) { - map_defun_node->AddAttr("output_types", map_defun_fn->ret_types); - - // TODO(rachelim): Propagate precise shapes if they're known, which may enable - // subsequent optimizations. - map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>( - map_defun_fn->ret_types.size())); -} - Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, const TensorDesc& output) { + // Note that we don't update MapDefun attrs as we go, only when we are done DataType type = output.first->output_type(output.second); int index = map_defun_fn->ret_nodes.size(); @@ -92,13 +83,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0); map_defun_fn->ret_nodes.push_back(ret_node); map_defun_fn->ret_types.push_back(type); - UpdateMapDefunAttrs(map_defun_fn, map_defun_node); return s; } void RemoveMapDefunOutput(int output_position, Graph* outer_scope, FunctionBody* map_defun_fn, Node* map_defun_node) { + // Note that we don't update MapDefun attrs as we go, only when we are done DCHECK_LT(output_position, map_defun_fn->ret_nodes.size()) << "Trying to remove output that doesn't exist. Output number: " << output_position; @@ -111,7 +102,6 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope, output_position); map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() + output_position); - UpdateMapDefunAttrs(map_defun_fn, map_defun_node); // Renumber the nodes and edges that come after for (int i = 0; i < num_later_outputs; ++i) { @@ -352,6 +342,13 @@ void Vectorization::VectorizeHelper() { // need the MapDefun node and can delete it. if (map_defun_fn_->ret_nodes.empty()) { outer_scope_->RemoveNode(map_defun_node_); + } else { + // Update MapDefun node attrs accordingly + DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size()); + map_defun_node_->AddAttr( + "output_shapes", + std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size())); + map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); } } |