From e513d1f516046abc4a5831e1347720922118e81b Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Mon, 8 Jan 2018 17:21:43 -0800
Subject: Materialize BroadcastGradientArgs by default instead of just doing so
 in aggressive mode. This ensures that we optimize gradient computations in
 the presence of variable batch sizes.

PiperOrigin-RevId: 181242749
---
 .../core/grappler/optimizers/constant_folding.cc   | 75 +++++++++++++++-------
 .../grappler/optimizers/constant_folding_test.cc   | 11 +---
 2 files changed, 54 insertions(+), 32 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 9f24f1c768..68feedbcbb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -433,13 +433,42 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
       id = --min_id;
     }
   }
+
+  // Beware: the reduction dimensions computed by the BCast class are valid iff
+  // we assume that two distinct symbolic dimensions can't be equal and a
+  // symbolic dimension can't be equal to 1. This is often but not always true,
+  // so to make this optimization safe we filter out these cases.
+  const int common_dims = std::min(shape1.size(), shape2.size());
+  for (int i = 0; i < common_dims; ++i) {
+    if (shape1[i] >= 0 && shape2[i] >= 0) {
+      continue;
+    }
+    if (shape1[i] != shape2[i]) {
+      // We're either dealing with 2 different symbolic dimensions or a symbolic
+      // and a known dimension. We can't be sure whether the two are equal or
+      // not, so we can't be sure whether we'll be broadcasting or not.
+      return Status::OK();
+    }
+  }
+  // These extra dims could be equal to 1, in which case there is no
+  // broadcasting. They could also be greater than 1, in which case there would
+  // be broadcasting. Since we don't know, we'll just punt.
+  for (int i = common_dims; i < shape1.size(); ++i) {
+    if (shape1[i] < 0) {
+      return Status::OK();
+    }
+  }
+  for (int i = common_dims; i < shape2.size(); ++i) {
+    if (shape2[i] < 0) {
+      return Status::OK();
+    }
+  }
+
   BCast bcast(shape1, shape2);
   if (!bcast.IsValid()) {
     return Status::OK();
   }
-  // Beware: the reduction dimensions are valid iff we assume that two distinct
-  // symbolic dimensions can't be equal. This is often but not always true, so
-  // this optimization isn't safe.
+
   BCast::Vec reduce_dims[2];
   reduce_dims[0] = bcast.grad_x_reduce_idx();
   reduce_dims[1] = bcast.grad_y_reduce_idx();
@@ -447,26 +476,27 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
   const DataType type = node.attr().at("T").type();
   NodeDef* out[2];
   for (int j = 0; j < 2; ++j) {
-    if (!reduce_dims[j].empty()) {
-      // This is the case when a tensor dimension of 1 is matched against an
-      // unknown dimension. The unknown dimension could also be equal to 1, in
-      // which case there would be no reduction.
-      out[j] = nullptr;
-    } else {
-      string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
-      out[j] = node_map_->GetNode(const_name);
-      if (out[j] == nullptr) {
-        out[j] = graph_->add_node();
-        Tensor value(type, TensorShape({0}));
-        *out[j] = CreateNodeDef(const_name, TensorValue(&value));
-        out[j]->set_device(node.device());
-        node_map_->AddNode(const_name, out[j]);
-        string ctrl_dep =
-            AddControlDependency(node.name(), graph_, node_map_.get());
-        *out[j]->add_input() = ctrl_dep;
-        node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+    int reduction_indices = reduce_dims[j].size();
+    Tensor value(type, TensorShape({reduction_indices}));
+    for (int i = 0; i < reduction_indices; ++i) {
+      if (type == DT_INT32) {
+        value.vec<int32>()(i) = reduce_dims[j][i];
+      } else {
+        value.vec<int64>()(i) = reduce_dims[j][i];
       }
     }
+    string const_name = OptimizedNodeName(node, strings::StrCat("-", j));
+    out[j] = node_map_->GetNode(const_name);
+    if (out[j] == nullptr) {
+      out[j] = graph_->add_node();
+      *out[j] = CreateNodeDef(const_name, TensorValue(&value));
+      out[j]->set_device(node.device());
+      node_map_->AddNode(const_name, out[j]);
+      string ctrl_dep =
+          AddControlDependency(node.name(), graph_, node_map_.get());
+      *out[j]->add_input() = ctrl_dep;
+      node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+    }
   }
 
   const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
@@ -584,12 +614,11 @@ Status ConstantFolding::MaterializeReductionIndices(
 
 Status ConstantFolding::MaterializeConstants(
     const GraphProperties& properties) {
-  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   const int node_count = graph_->node_size();
   for (int i = 0; i < node_count; ++i) {
     NodeDef& node = *graph_->mutable_node(i);
     const string& op = node.op();
-    if (is_aggressive && op == "BroadcastGradientArgs") {
+    if (op == "BroadcastGradientArgs") {
       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
     } else if (IsReduction(node)) {
       TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index a3b3e522eb..c53678f727 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1373,21 +1373,14 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
     } else if (node.name() == "p1") {
       ++found;
       EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("ConstantFolding/i-0", node.input(0));
+      EXPECT_EQ("i", node.input(0));
     } else if (node.name() == "p2") {
       ++found;
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("i:1", node.input(0));
-    } else if (node.name() == "ConstantFolding/i-0") {
-      ++found;
-      EXPECT_EQ("Const", node.op());
-      EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("^i", node.input(0));
-      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
-                       .num_elements());
     }
   }
-  EXPECT_EQ(7, found);
+  EXPECT_EQ(6, found);
 }
 
 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
--
cgit v1.2.3
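
For context on the change above: BroadcastGradientArgs takes the shapes of the
two inputs of a broadcasting binary op and returns, for each input, the axes
along which that input's gradient must be summed. The sketch below is a
minimal standalone model of that computation plus the "punt on ambiguous
symbolic dimensions" filter this patch adds. It is an illustration under
simplified assumptions, not TensorFlow's implementation (the real pass works
on Grappler's GraphProperties and the BCast helper); the Shape alias and the
GradientReduceAxes function are invented for this example.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;  // A negative value marks a symbolic dim.

// Computes the axes along which the gradients of x and y must be summed
// after broadcasting -- the values a materialized BroadcastGradientArgs
// constant holds. Returns false ("punt") when symbolic dimensions make
// the answer ambiguous, mirroring the filter added by the patch.
static bool GradientReduceAxes(Shape x, Shape y, Shape* rx, Shape* ry) {
  // Right-align the shapes by padding the shorter one with leading 1s,
  // following NumPy/TensorFlow broadcasting rules.
  const std::size_t n = std::max(x.size(), y.size());
  x.insert(x.begin(), n - x.size(), 1);
  y.insert(y.begin(), n - y.size(), 1);
  for (std::size_t i = 0; i < n; ++i) {
    const int64_t a = x[i], b = y[i];
    if (a < 0 || b < 0) {        // At least one symbolic dimension.
      if (a != b) return false;  // Could be equal, or 1, or neither: punt.
      continue;                  // Same symbol on both sides: no broadcast.
    }
    if (a == 1 && b != 1) rx->push_back(static_cast<int64_t>(i));
    if (b == 1 && a != 1) ry->push_back(static_cast<int64_t>(i));
  }
  return true;
}

int main() {
  Shape rx, ry;
  // Variable batch size on both inputs (the case the commit message cites):
  // the same symbol appears on both sides, so it is safe to materialize
  // empty reduction-index constants. Prints 1.
  std::cout << GradientReduceAxes({-1, 32, 3}, {-1, 32, 3}, &rx, &ry) << "\n";
  // Fully known shapes: y = [3, 1] is broadcast along axes 0 and 2, so
  // grad_y must be summed over those axes.
  GradientReduceAxes({2, 3, 4}, {3, 1}, &rx, &ry);
  for (int64_t axis : ry) std::cout << axis << " ";  // Prints: 0 2
  std::cout << "\n";
  // A symbolic dim matched against a different known dim is ambiguous,
  // so the function punts. Prints 0.
  std::cout << GradientReduceAxes({-1, 3}, {5, 3}, &rx, &ry) << "\n";
  return 0;
}

Once these axes are known at graph-construction time, each
BroadcastGradientArgs output can be replaced by a constant holding them,
which is what lets the gradient computation be optimized even when the
batch dimension is symbolic.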