aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-01-06 09:51:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-06 09:59:41 -0800
commit673827730a3ae6a86f9cad86d19ec9e7d4597f3a (patch)
tree211e80527f65abc0efc1085fde8af4420f58c2aa
parent4080654c8f03ec34f2822c14db5fd8b75f63d569 (diff)
Correctly connect control dependencies to switch nodes when doing arithmetic
simplifications PiperOrigin-RevId: 181043095
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h7
2 files changed, 22 insertions, 15 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index d5259bf177..9f24f1c768 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1276,26 +1276,31 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
}
void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward,
- NodeDef* node) {
+ NodeDef* node,
+ GraphDef* graph) {
node->set_op("Identity");
// Propagate the designated input through the identity.
node->mutable_input()->SwapElements(0, input_to_forward);
// Add all other inputs as control dependencies.
for (int i = 1; i < node->input_size(); ++i) {
- node->set_input(i, AsControlDependency(node->input(i)));
+ node->set_input(
+ i, AddControlDependency(node->input(i), graph, node_map_.get()));
}
graph_modified_ = true;
}
-void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node) {
+void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
+ GraphDef* graph) {
node->set_op("Reciprocal");
node->mutable_input()->SwapElements(0, 1);
- node->set_input(1, AsControlDependency(node->input(1)));
+ node->set_input(1,
+ AddControlDependency(node->input(1), graph, node_map_.get()));
graph_modified_ = true;
}
Status ConstantFolding::ReplaceOperationWithConstant(
- double value, const TensorShapeProto& shape, NodeDef* node) {
+ double value, const TensorShapeProto& shape, NodeDef* node,
+ GraphDef* graph) {
AttrValue tensor_attr;
AttrValue dtype_attr = node->attr().at("T");
TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
@@ -1309,7 +1314,8 @@ Status ConstantFolding::ReplaceOperationWithConstant(
if (IsControlInput(node->input(i))) {
break;
}
- node->set_input(i, AsControlDependency(node->input(i)));
+ node->set_input(
+ i, AddControlDependency(node->input(i), graph, node_map_.get()));
}
graph_modified_ = true;
return Status::OK();
@@ -1380,7 +1386,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// TODO(rmlarsen): Handle subtraction 0 - y.
// 1 * y = y or 0 + y = y.
- ReplaceOperationWithIdentity(1, node);
+ ReplaceOperationWithIdentity(1, node, output);
continue;
}
@@ -1388,7 +1394,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
if (y_matches_output_shape && is_any_div && x_is_one) {
DataType type = node->attr().at("T").type();
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
- ReplaceDivisionOfOnesByReciprocal(node);
+ ReplaceDivisionOfOnesByReciprocal(node, output);
continue;
}
}
@@ -1402,7 +1408,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
(((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero && is_aggressive))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
- ReplaceOperationWithIdentity(0, node);
+ ReplaceOperationWithIdentity(0, node, output);
continue;
}
@@ -1416,17 +1422,17 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined()) {
TF_RETURN_IF_ERROR(
- ReplaceOperationWithConstant(0, output_shape, node));
+ ReplaceOperationWithConstant(0, output_shape, node, output));
continue;
}
// Even if an input shape is only partially known, we may known that it
// matches the output shape and thus forward the corresponding zero
// input.
if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
- ReplaceOperationWithIdentity(0, node);
+ ReplaceOperationWithIdentity(0, node, output);
continue;
} else if (is_mul && y_is_zero && y_matches_output_shape) {
- ReplaceOperationWithIdentity(1, node);
+ ReplaceOperationWithIdentity(1, node, output);
continue;
}
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 87f275c1c0..6aadd97508 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -78,11 +78,12 @@ class ConstantFolding : public GraphOptimizer {
bool IsOnes(const NodeDef& node) const;
bool IsZeros(const NodeDef& node) const;
- void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node);
+ void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node,
+ GraphDef* graph);
Status ReplaceOperationWithConstant(double value,
const TensorShapeProto& shape,
- NodeDef* node);
- void ReplaceDivisionOfOnesByReciprocal(NodeDef* node);
+ NodeDef* node, GraphDef* graph);
+ void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
Status FoldGraph(GraphDef* output);
bool IsSimplifiableReduction(const NodeDef& node) const;