-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc       | 75
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc  | 11
2 files changed, 54 insertions, 32 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 9f24f1c768..68feedbcbb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -433,13 +433,42 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
       id = --min_id;
     }
   }
+
+  // Beware: the reduction dimensions computed by the BCast class are valid
+  // iff we assume that two distinct symbolic dimensions can't be equal and
+  // that a symbolic dimension can't be equal to 1. This is often but not
+  // always true, so to make this optimization safe we filter out these cases.
+  const int common_dims = std::min(shape1.size(), shape2.size());
+  for (int i = 0; i < common_dims; ++i) {
+    if (shape1[i] >= 0 && shape2[i] >= 0) {
+      continue;
+    }
+    if (shape1[i] != shape2[i]) {
+      // We're either dealing with 2 different symbolic dimensions or with a
+      // symbolic and a known dimension. We can't be sure whether the two are
+      // equal, so we can't be sure whether we'll be broadcasting or not.
+      return Status::OK();
+    }
+  }
+  // These extra dims could be equal to 1, in which case there is no
+  // broadcasting. They could also be greater than 1, in which case there
+  // would be broadcasting. Since we don't know, we'll just punt.
+  for (int i = common_dims; i < shape1.size(); ++i) {
+    if (shape1[i] < 0) {
+      return Status::OK();
+    }
+  }
+  for (int i = common_dims; i < shape2.size(); ++i) {
+    if (shape2[i] < 0) {
+      return Status::OK();
+    }
+  }
+
   BCast bcast(shape1, shape2);
   if (!bcast.IsValid()) {
     return Status::OK();
   }
-  // Beware: the reduction dimensions are valid iff we assume that two distinct
-  // symbolic dimensions can't be equal. This is often but not always true, so
-  // this optimization isn't safe.
+
   BCast::Vec reduce_dims[2];
   reduce_dims[0] = bcast.grad_x_reduce_idx();
   reduce_dims[1] = bcast.grad_y_reduce_idx();
@@ -447,26 +476,27 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
 
   const DataType type = node.attr().at("T").type();
   NodeDef* out[2];
   for (int j = 0; j < 2; ++j) {
-    if (!reduce_dims[j].empty()) {
-      // This is the case when a tensor dimension of 1 is matched against an
-      // unknown dimension. The unknown dimension could also be equal to 1, in
-      // which case there would be no reduction.
-      out[j] = nullptr;
-    } else {
-      string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
-      out[j] = node_map_->GetNode(const_name);
-      if (out[j] == nullptr) {
-        out[j] = graph_->add_node();
-        Tensor value(type, TensorShape({0}));
-        *out[j] = CreateNodeDef(const_name, TensorValue(&value));
-        out[j]->set_device(node.device());
-        node_map_->AddNode(const_name, out[j]);
-        string ctrl_dep =
-            AddControlDependency(node.name(), graph_, node_map_.get());
-        *out[j]->add_input() = ctrl_dep;
-        node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+    int reduction_indices = reduce_dims[j].size();
+    Tensor value(type, TensorShape({reduction_indices}));
+    for (int i = 0; i < reduction_indices; ++i) {
+      if (type == DT_INT32) {
+        value.vec<int32>()(i) = reduce_dims[j][i];
+      } else {
+        value.vec<int64>()(i) = reduce_dims[j][i];
       }
     }
+    string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
+    out[j] = node_map_->GetNode(const_name);
+    if (out[j] == nullptr) {
+      out[j] = graph_->add_node();
+      *out[j] = CreateNodeDef(const_name, TensorValue(&value));
+      out[j]->set_device(node.device());
+      node_map_->AddNode(const_name, out[j]);
+      string ctrl_dep =
+          AddControlDependency(node.name(), graph_, node_map_.get());
+      *out[j]->add_input() = ctrl_dep;
+      node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+    }
   }
   const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
@@ -584,12 +614,11 @@ Status ConstantFolding::MaterializeReductionIndices(
 
 Status ConstantFolding::MaterializeConstants(
     const GraphProperties& properties) {
-  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   const int node_count = graph_->node_size();
   for (int i = 0; i < node_count; ++i) {
     NodeDef& node = *graph_->mutable_node(i);
     const string& op = node.op();
-    if (is_aggressive && op == "BroadcastGradientArgs") {
+    if (op == "BroadcastGradientArgs") {
       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
     } else if (IsReduction(node)) {
       TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index a3b3e522eb..c53678f727 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1373,21 +1373,14 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
     } else if (node.name() == "p1") {
       ++found;
       EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("ConstantFolding/i-0", node.input(0));
+      EXPECT_EQ("i", node.input(0));
     } else if (node.name() == "p2") {
       ++found;
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("i:1", node.input(0));
-    } else if (node.name() == "ConstantFolding/i-0") {
-      ++found;
-      EXPECT_EQ("Const", node.op());
-      EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("^i", node.input(0));
-      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
-                       .num_elements());
     }
   }
-  EXPECT_EQ(7, found);
+  EXPECT_EQ(6, found);
 }
 
 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
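
And a self-contained sketch (again, not part of the commit) of the safety filter added in the first hunk, assuming the patch's convention that known dimensions are >= 0 and that each distinct symbolic dimension is assigned its own negative id (see the `id = --min_id;` context line above). The free function and its name are hypothetical, for illustration only:

#include <algorithm>
#include <cstdint>
#include <vector>

// Returns true iff folding BroadcastGradientArgs is safe for these shapes:
// over the common dimensions, each pair must be fully known or be the very
// same symbolic dimension, and all extra trailing dimensions must be fully
// known. This mirrors the three loops added in the patch.
bool SafeToFoldBroadcastGradientArgs(const std::vector<int64_t>& shape1,
                                     const std::vector<int64_t>& shape2) {
  const size_t common_dims = std::min(shape1.size(), shape2.size());
  for (size_t i = 0; i < common_dims; ++i) {
    if (shape1[i] >= 0 && shape2[i] >= 0) continue;
    // Two different symbolic dimensions, or a symbolic and a known one:
    // they may or may not be equal at runtime, so punt.
    if (shape1[i] != shape2[i]) return false;
  }
  // An extra symbolic dimension could be 1 (no broadcasting) or greater
  // than 1 (broadcasting), so punt on those as well.
  for (size_t i = common_dims; i < shape1.size(); ++i)
    if (shape1[i] < 0) return false;
  for (size_t i = common_dims; i < shape2.size(); ++i)
    if (shape2[i] < 0) return false;
  return true;
}

Under this convention, SafeToFoldBroadcastGradientArgs({2, -1}, {2, -1}) is true, since both inputs share the same symbolic dimension, while ({2, -1}, {2, -2}) and ({2, -1}, {2, 3}) are false: those dimensions may or may not match at runtime.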