aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-10 10:22:07 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:42 -0800
commit51889acee1a266478b578afad3fbe7b3a90fc17a (patch)
treeedbbb81c4f685dee7fd51ab7c40576b184849c89 /tensorflow
parent10d1827987b0eca4d0e6f8f56506c93c67e03f83 (diff)
Add suffix to newly created Mul op in the optimizer to avoid the name collision
when two Conv2D objects depend on the same Const. PiperOrigin-RevId: 175305425
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc46
2 files changed, 47 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index f2277a9b79..e8ef0e94b5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -794,7 +794,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
scale_tensor.tensor_shape().dim_size() == 0) {
// Create new node `scaled_weights`.
NodeDef* scaled_weights = graph_def->add_node();
- scaled_weights->set_name(weights->name() + "_scaled");
+ scaled_weights->set_name(weights->name() + "_scaled_" +
+ conv->name());
scaled_weights->set_op("Mul");
scaled_weights->set_device(weights->device());
(*scaled_weights->mutable_attr())["T"] =
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 60fb47f51a..4fcbb0120e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -887,7 +887,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
CHECK_NOTNULL(node_map.GetNode("Transpose_uint8"));
const NodeDef* cast_node = CHECK_NOTNULL(node_map.GetNode("Cast_new"));
const NodeDef* weights_node =
- CHECK_NOTNULL(node_map.GetNode("weights_scaled"));
+ CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D"));
const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
EXPECT_EQ(output.node_size(), 7);
@@ -897,6 +897,50 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
EXPECT_EQ(conv_node->input(1), weights_node->name());
}
+TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
+ // This unit test exercises optimization of folding mul into conv for
+ // multiple nodes in the graph.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
+
+ GrapplerItem item;
+ Output conv[2];
+
+ for (int i = 0; i < 2; ++i) {
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
+ Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
+ Output weights = ops::Const(s.WithOpName("weights"),
+ Input::Initializer(127.0f, {5, 5, 3, 16}));
+ conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
+ ops::Conv2D::DataFormat("NCHW"));
+ }
+ Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);
+
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+
+ item.graph = output;
+ TF_EXPECT_OK(
+ ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
+
+ item.graph = output;
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+ NodeMap node_map(&output);
+ const NodeDef* weights_node =
+ CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D"));
+ const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
+
+ const NodeDef* weights_node_1 =
+ CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D_1"));
+ const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1"));
+ EXPECT_EQ(conv_node->input(1), weights_node->name());
+ EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
+}
+
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =