diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-10 10:22:07 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:42 -0800 |
commit | 51889acee1a266478b578afad3fbe7b3a90fc17a (patch) | |
tree | edbbb81c4f685dee7fd51ab7c40576b184849c89 /tensorflow | |
parent | 10d1827987b0eca4d0e6f8f56506c93c67e03f83 (diff) |
Add a suffix (the Conv2D node's name) to the newly created Mul op in the arithmetic optimizer to avoid a name collision
when two Conv2D nodes depend on the same Const.
PiperOrigin-RevId: 175305425
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 46 |
2 files changed, 47 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index f2277a9b79..e8ef0e94b5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -794,7 +794,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
             scale_tensor.tensor_shape().dim_size() == 0) {
           // Create new node `scaled_weights`.
           NodeDef* scaled_weights = graph_def->add_node();
-          scaled_weights->set_name(weights->name() + "_scaled");
+          scaled_weights->set_name(weights->name() + "_scaled_" +
+                                   conv->name());
           scaled_weights->set_op("Mul");
           scaled_weights->set_device(weights->device());
           (*scaled_weights->mutable_attr())["T"] =
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 60fb47f51a..4fcbb0120e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -887,7 +887,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
       CHECK_NOTNULL(node_map.GetNode("Transpose_uint8"));
   const NodeDef* cast_node = CHECK_NOTNULL(node_map.GetNode("Cast_new"));
   const NodeDef* weights_node =
-      CHECK_NOTNULL(node_map.GetNode("weights_scaled"));
+      CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D"));
   const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
 
   EXPECT_EQ(output.node_size(), 7);
@@ -897,6 +897,50 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
   EXPECT_EQ(conv_node->input(1), weights_node->name());
 }
 
+TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
+  // This unit test exercises optimization of folding mul into conv for
+  // multiple nodes in the graph.
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
+
+  GrapplerItem item;
+  Output conv[2];
+
+  for (int i = 0; i < 2; ++i) {
+    Output inputs =
+        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
+    Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
+    Output weights = ops::Const(s.WithOpName("weights"),
+                                Input::Initializer(127.0f, {5, 5, 3, 16}));
+    conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
+                          ops::Conv2D::DataFormat("NCHW"));
+  }
+  Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);
+
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+
+  item.graph = output;
+  TF_EXPECT_OK(
+      ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
+
+  item.graph = output;
+  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+  NodeMap node_map(&output);
+  const NodeDef* weights_node =
+      CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D"));
+  const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
+
+  const NodeDef* weights_node_1 =
+      CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D_1"));
+  const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1"));
+  EXPECT_EQ(conv_node->input(1), weights_node->name());
+  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
+}
+
 TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output inputs =