diff options
author | 2018-01-10 16:13:09 -0800 | |
---|---|---|
committer | 2018-01-10 16:16:18 -0800 | |
commit | 4b277703f7ce0f8f0e63bbadd1cb9dd0a8cb1181 (patch) | |
tree | daefe9e923f2a4fa1c2172e786622c30db4e0d36 | |
parent | b977ebeca140189b02ba003f9a456d86c34459cd (diff) |
Fix a bug in updating NodeMap, where the node name without port number should
have been used.
PiperOrigin-RevId: 181532901
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 9 | ||||
-rw-r--r-- | tensorflow/python/grappler/layout_optimizer_test.py | 11 |
2 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 870b5289b5..ea7b05d381 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -625,7 +625,8 @@ class NodeProcessor : public GraphProcessor { node_name, node_->input(pos), const_name, dtype, input_node->attr().at("_output_shapes").list().shape(output_pos), true); - node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name); + node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(), + node_name); node_map_->AddOutput(node_name, node_->name()); *node_->mutable_input(pos) = node_name; } @@ -917,7 +918,7 @@ class NodeProcessor : public GraphProcessor { auto added_node = AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true); *node_->mutable_input(input_pos) = added_node->name(); - node_map_->UpdateOutput(added_node->input(0), node_->name(), + node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(), added_node->name()); node_map_->AddOutput(added_node->name(), node_->name()); } @@ -1328,8 +1329,8 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { AddNodeReshape(reshape_node_name, node_->input(vector_index), shape_const_node_name, node_->attr().at("T").type()); node_map_->AddOutput(shape_const_node_name, reshape_node_name); - node_map_->UpdateOutput(node_->input(vector_index), node_->name(), - reshape_node_name); + node_map_->UpdateOutput(NodeName(node_->input(vector_index)), + node_->name(), reshape_node_name); node_map_->AddOutput(reshape_node_name, node_->name()); *node_->mutable_input(vector_index) = reshape_node_name; } diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 4fdd779ddd..25c5ef6b68 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -283,7 +283,12 @@ class LayoutOptimizerTest(test.TestCase): conv = _two_layer_model(x) dim = array_ops.placeholder(dtype='int32') split = array_ops.split(conv, 2, axis=dim) - output = math_ops.reduce_sum(split[0]) + scale = constant_op.constant(0.1, shape=[32]) + offset = constant_op.constant(0.3, shape=[32]) + bn0 = nn.fused_batch_norm(split[0], scale, offset) + bn1 = nn.fused_batch_norm(split[1], scale, offset) + add = bn0[0] + bn1[0] + output = array_ops.identity(add) with session.Session() as sess: output_val_ref = sess.run(output, feed_dict={dim: 3}) @@ -299,12 +304,10 @@ class LayoutOptimizerTest(test.TestCase): num_transposes += 1 nodes.append(node.name) - # Four transposes were initially added in the Expand phase of - # LayoutOptimizer; two of them are cancelled out in the Collapse phase. expected_num_transposes = 2 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) - self._assert_trans_nchw_to_nhwc('split-0-0', nodes) + self._assert_trans_nchw_to_nhwc('add_2-0-0', nodes) self._assert_map_nhwc_to_nchw('split-0', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) |