aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-01-10 16:13:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-10 16:16:18 -0800
commit4b277703f7ce0f8f0e63bbadd1cb9dd0a8cb1181 (patch)
treedaefe9e923f2a4fa1c2172e786622c30db4e0d36
parentb977ebeca140189b02ba003f9a456d86c34459cd (diff)
Fix a bug in updating NodeMap, where the node name without port number should
have been used. PiperOrigin-RevId: 181532901
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc9
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py11
2 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 870b5289b5..ea7b05d381 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -625,7 +625,8 @@ class NodeProcessor : public GraphProcessor {
node_name, node_->input(pos), const_name, dtype,
input_node->attr().at("_output_shapes").list().shape(output_pos),
true);
- node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
+ node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
+ node_name);
node_map_->AddOutput(node_name, node_->name());
*node_->mutable_input(pos) = node_name;
}
@@ -917,7 +918,7 @@ class NodeProcessor : public GraphProcessor {
auto added_node =
AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
*node_->mutable_input(input_pos) = added_node->name();
- node_map_->UpdateOutput(added_node->input(0), node_->name(),
+ node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
added_node->name());
node_map_->AddOutput(added_node->name(), node_->name());
}
@@ -1328,8 +1329,8 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
AddNodeReshape(reshape_node_name, node_->input(vector_index),
shape_const_node_name, node_->attr().at("T").type());
node_map_->AddOutput(shape_const_node_name, reshape_node_name);
- node_map_->UpdateOutput(node_->input(vector_index), node_->name(),
- reshape_node_name);
+ node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
+ node_->name(), reshape_node_name);
node_map_->AddOutput(reshape_node_name, node_->name());
*node_->mutable_input(vector_index) = reshape_node_name;
}
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 4fdd779ddd..25c5ef6b68 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -283,7 +283,12 @@ class LayoutOptimizerTest(test.TestCase):
conv = _two_layer_model(x)
dim = array_ops.placeholder(dtype='int32')
split = array_ops.split(conv, 2, axis=dim)
- output = math_ops.reduce_sum(split[0])
+ scale = constant_op.constant(0.1, shape=[32])
+ offset = constant_op.constant(0.3, shape=[32])
+ bn0 = nn.fused_batch_norm(split[0], scale, offset)
+ bn1 = nn.fused_batch_norm(split[1], scale, offset)
+ add = bn0[0] + bn1[0]
+ output = array_ops.identity(add)
with session.Session() as sess:
output_val_ref = sess.run(output, feed_dict={dim: 3})
@@ -299,12 +304,10 @@ class LayoutOptimizerTest(test.TestCase):
num_transposes += 1
nodes.append(node.name)
- # Four transposes were initially added in the Expand phase of
- # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
- self._assert_trans_nchw_to_nhwc('split-0-0', nodes)
+ self._assert_trans_nchw_to_nhwc('add_2-0-0', nodes)
self._assert_map_nhwc_to_nchw('split-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)