From 673827730a3ae6a86f9cad86d19ec9e7d4597f3a Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Sat, 6 Jan 2018 09:51:38 -0800 Subject: Correctly connect control dependencies to switch nodes when doing arithmetic simplifications PiperOrigin-RevId: 181043095 --- .../core/grappler/optimizers/constant_folding.cc | 30 +++++++++++++--------- .../core/grappler/optimizers/constant_folding.h | 7 ++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index d5259bf177..9f24f1c768 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1276,26 +1276,31 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { } void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward, - NodeDef* node) { + NodeDef* node, + GraphDef* graph) { node->set_op("Identity"); // Propagate the designated input through the identity. node->mutable_input()->SwapElements(0, input_to_forward); // Add all other inputs as control dependencies. for (int i = 1; i < node->input_size(); ++i) { - node->set_input(i, AsControlDependency(node->input(i))); + node->set_input( + i, AddControlDependency(node->input(i), graph, node_map_.get())); } graph_modified_ = true; } -void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node) { +void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node, + GraphDef* graph) { node->set_op("Reciprocal"); node->mutable_input()->SwapElements(0, 1); - node->set_input(1, AsControlDependency(node->input(1))); + node->set_input(1, + AddControlDependency(node->input(1), graph, node_map_.get())); graph_modified_ = true; } Status ConstantFolding::ReplaceOperationWithConstant( - double value, const TensorShapeProto& shape, NodeDef* node) { + double value, const TensorShapeProto& shape, NodeDef* node, + GraphDef* graph) { AttrValue tensor_attr; AttrValue dtype_attr = node->attr().at("T"); TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value, @@ -1309,7 +1314,8 @@ Status ConstantFolding::ReplaceOperationWithConstant( if (IsControlInput(node->input(i))) { break; } - node->set_input(i, AsControlDependency(node->input(i))); + node->set_input( + i, AddControlDependency(node->input(i), graph, node_map_.get())); } graph_modified_ = true; return Status::OK(); @@ -1380,7 +1386,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, ((is_mul && x_is_one) || (is_add && x_is_zero))) { // TODO(rmlarsen): Handle subtraction 0 - y. // 1 * y = y or 0 + y = y. - ReplaceOperationWithIdentity(1, node); + ReplaceOperationWithIdentity(1, node, output); continue; } @@ -1388,7 +1394,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, if (y_matches_output_shape && is_any_div && x_is_one) { DataType type = node->attr().at("T").type(); if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) { - ReplaceDivisionOfOnesByReciprocal(node); + ReplaceDivisionOfOnesByReciprocal(node, output); continue; } } @@ -1402,7 +1408,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, (((is_mul || is_any_div) && y_is_one) || ((is_add || is_sub) && y_is_zero && is_aggressive))) { // x * 1 = x or x / 1 = x or x +/- 0 = x - ReplaceOperationWithIdentity(0, node); + ReplaceOperationWithIdentity(0, node, output); continue; } @@ -1416,17 +1422,17 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const PartialTensorShape shp(output_shape); if (shp.IsFullyDefined()) { TF_RETURN_IF_ERROR( - ReplaceOperationWithConstant(0, output_shape, node)); + ReplaceOperationWithConstant(0, output_shape, node, output)); continue; } // Even if an input shape is only partially known, we may known that it // matches the output shape and thus forward the corresponding zero // input. if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) { - ReplaceOperationWithIdentity(0, node); + ReplaceOperationWithIdentity(0, node, output); continue; } else if (is_mul && y_is_zero && y_matches_output_shape) { - ReplaceOperationWithIdentity(1, node); + ReplaceOperationWithIdentity(1, node, output); continue; } } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 87f275c1c0..6aadd97508 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -78,11 +78,12 @@ class ConstantFolding : public GraphOptimizer { bool IsOnes(const NodeDef& node) const; bool IsZeros(const NodeDef& node) const; - void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node); + void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node, + GraphDef* graph); Status ReplaceOperationWithConstant(double value, const TensorShapeProto& shape, - NodeDef* node); - void ReplaceDivisionOfOnesByReciprocal(NodeDef* node); + NodeDef* node, GraphDef* graph); + void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); Status FoldGraph(GraphDef* output); bool IsSimplifiableReduction(const NodeDef& node) const; -- cgit v1.2.3