diff options
author | Yao Zhang <yaozhang@google.com> | 2018-02-10 01:45:11 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-10 01:48:49 -0800 |
commit | 69fbb7717102dba11c471cd77088a5d6c1274b71 (patch) | |
tree | 2acbccb3413e772c4da4c8f4be196d24dbbbd1b1 /tensorflow/python/grappler | |
parent | 7ace7f14caf81c9acbac2e3ba26a754cbe78ead5 (diff) |
Do not convert layout for Select if condition input is of unknown shape.
PiperOrigin-RevId: 185242138
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/layout_optimizer_test.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 25b1cdcbc5..30dcdf31aa 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -771,6 +771,37 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nchw_to_nhwc('Select-0-0', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testSelectOpConditionUnknownShape(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + add = math_ops.add(conv, conv) + condition = array_ops.placeholder(dtype='bool') + select = gen_math_ops._select(condition, conv, add) + output = array_ops.identity(select) + + condition_val = np.zeros((1, 7, 7, 64)) + with session.Session() as sess: + output_val_ref = sess.run(output, feed_dict={condition: condition_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={condition: condition_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 3 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testSelectOpScalarCondition(self): if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(0) |