author     Benoit Steiner <bsteiner@google.com>             2018-01-08 17:21:43 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-01-08 17:25:13 -0800
commit     e513d1f516046abc4a5831e1347720922118e81b (patch)
tree       ba369e818774b6c4709fe3ba160d347702f865e5
parent     e523d9370ce36e50ba64f9dd1260eaddbbb8244e (diff)
Materialize BroadcastGradientArgs by default instead of just doing so in aggressive mode. This ensures that we optimize gradient computations in the presence of variable batch sizes.
PiperOrigin-RevId: 181242749
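
For context: BroadcastGradientArgs takes the shapes of the two inputs of a broadcasting binary op (such as Add or Mul) and returns, for each input, the axes that input's gradient must be summed over. When the shapes are known statically, those axes are compile-time constants, which is what this patch now materializes by default. A minimal sketch of the computation it performs, hypothetical and independent of TensorFlow's BCast class:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Sketch (not TensorFlow code): given the input shapes of a broadcasting
// binary op, compute the axes each input's gradient must be summed over --
// the values BroadcastGradientArgs produces at runtime and that this patch
// folds into constants when the shapes allow it.
std::pair<std::vector<int>, std::vector<int>> GradReduceAxes(
    std::vector<long long> x, std::vector<long long> y) {
  // Left-pad the shorter shape with 1s so both have the same rank.
  const size_t rank = std::max(x.size(), y.size());
  x.insert(x.begin(), rank - x.size(), 1);
  y.insert(y.begin(), rank - y.size(), 1);
  std::vector<int> rx, ry;
  for (size_t i = 0; i < rank; ++i) {
    // A size-1 dimension matched against a larger one is broadcast, so the
    // incoming gradient must be summed over that axis.
    if (x[i] == 1 && y[i] != 1) rx.push_back(i);
    if (y[i] == 1 && x[i] != 1) ry.push_back(i);
  }
  return {rx, ry};
}

int main() {
  // [2,3,4] op [1,4]: dx needs no reduction; dy is summed over axes 0 and 1.
  const auto axes = GradReduceAxes({2, 3, 4}, {1, 4});
  for (int a : axes.second) std::printf("%d ", a);  // prints: 0 1
  return 0;
}

With a variable batch size, the batch dimension is symbolic, so before this change (outside aggressive mode) these reduction axes stayed dynamic and the gradient subgraph could not be constant-folded.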
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc       | 75
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc  | 11
2 files changed, 54 insertions, 32 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 9f24f1c768..68feedbcbb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -433,13 +433,42 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
       id = --min_id;
     }
   }
+
+  // Beware: the reduction dimensions computed by the BCast class are valid iff
+  // we assume that two distinct symbolic dimensions can't be equal and a
+  // symbolic dimension can't be equal to 1. This is often but not always true,
+  // so to make this optimization safe we filter out these cases.
+  const int common_dims = std::min(shape1.size(), shape2.size());
+  for (int i = 0; i < common_dims; ++i) {
+    if (shape1[i] >= 0 && shape2[i] >= 0) {
+      continue;
+    }
+    if (shape1[i] != shape2[i]) {
+      // We're either dealing with 2 different symbolic dimensions or a
+      // symbolic and a known dimension. We can't be sure whether both are
+      // equal or not, so we can't be sure whether we'll be broadcasting.
+      return Status::OK();
+    }
+  }
+  // These extra dims could be equal to 1, in which case there is no
+  // broadcasting. They could also be greater than 1, in which case there
+  // would be broadcasting. Since we don't know, we'll just punt.
+  for (int i = common_dims; i < shape1.size(); ++i) {
+    if (shape1[i] < 0) {
+      return Status::OK();
+    }
+  }
+  for (int i = common_dims; i < shape2.size(); ++i) {
+    if (shape2[i] < 0) {
+      return Status::OK();
+    }
+  }
+
   BCast bcast(shape1, shape2);
   if (!bcast.IsValid()) {
     return Status::OK();
   }
-  // Beware: the reduction dimensions are valid iff we assume that two distinct
-  // symbolic dimensions can't be equal. This is often but not always true, so
-  // this optimization isn't safe.
+
   BCast::Vec reduce_dims[2];
   reduce_dims[0] = bcast.grad_x_reduce_idx();
   reduce_dims[1] = bcast.grad_y_reduce_idx();
@@ -447,26 +476,27 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
   const DataType type = node.attr().at("T").type();
   NodeDef* out[2];
   for (int j = 0; j < 2; ++j) {
-    if (!reduce_dims[j].empty()) {
-      // This is the case when a tensor dimension of 1 is matched against an
-      // unknown dimension. The unknown dimension could also be equal to 1, in
-      // which case there would be no reduction.
-      out[j] = nullptr;
-    } else {
-      string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
-      out[j] = node_map_->GetNode(const_name);
-      if (out[j] == nullptr) {
-        out[j] = graph_->add_node();
-        Tensor value(type, TensorShape({0}));
-        *out[j] = CreateNodeDef(const_name, TensorValue(&value));
-        out[j]->set_device(node.device());
-        node_map_->AddNode(const_name, out[j]);
-        string ctrl_dep =
-            AddControlDependency(node.name(), graph_, node_map_.get());
-        *out[j]->add_input() = ctrl_dep;
-        node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+    int reduction_indices = reduce_dims[j].size();
+    Tensor value(type, TensorShape({reduction_indices}));
+    for (int i = 0; i < reduction_indices; ++i) {
+      if (type == DT_INT32) {
+        value.vec<int32>()(i) = reduce_dims[j][i];
+      } else {
+        value.vec<int64>()(i) = reduce_dims[j][i];
       }
     }
+    string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
+    out[j] = node_map_->GetNode(const_name);
+    if (out[j] == nullptr) {
+      out[j] = graph_->add_node();
+      *out[j] = CreateNodeDef(const_name, TensorValue(&value));
+      out[j]->set_device(node.device());
+      node_map_->AddNode(const_name, out[j]);
+      string ctrl_dep =
+          AddControlDependency(node.name(), graph_, node_map_.get());
+      *out[j]->add_input() = ctrl_dep;
+      node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+    }
   }
   const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
@@ -584,12 +614,11 @@ Status ConstantFolding::MaterializeReductionIndices(
 Status ConstantFolding::MaterializeConstants(
     const GraphProperties& properties) {
-  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   const int node_count = graph_->node_size();
   for (int i = 0; i < node_count; ++i) {
     NodeDef& node = *graph_->mutable_node(i);
     const string& op = node.op();
-    if (is_aggressive && op == "BroadcastGradientArgs") {
+    if (op == "BroadcastGradientArgs") {
       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
     } else if (IsReduction(node)) {
       TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index a3b3e522eb..c53678f727 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1373,21 +1373,14 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
     } else if (node.name() == "p1") {
       ++found;
       EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("ConstantFolding/i-0", node.input(0));
+      EXPECT_EQ("i", node.input(0));
     } else if (node.name() == "p2") {
       ++found;
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("i:1", node.input(0));
-    } else if (node.name() == "ConstantFolding/i-0") {
-      ++found;
-      EXPECT_EQ("Const", node.op());
-      EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("^i", node.input(0));
-      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
-                       .num_elements());
     }
   }
-  EXPECT_EQ(7, found);
+  EXPECT_EQ(6, found);
 }

 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {