aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-02-10 01:45:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-10 01:48:49 -0800
commit69fbb7717102dba11c471cd77088a5d6c1274b71 (patch)
tree2acbccb3413e772c4da4c8f4be196d24dbbbd1b1 /tensorflow/python/grappler
parent7ace7f14caf81c9acbac2e3ba26a754cbe78ead5 (diff)
Do not convert layout for Select if condition input is of unknown shape.
PiperOrigin-RevId: 185242138
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 25b1cdcbc5..30dcdf31aa 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -771,6 +771,37 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_trans_nchw_to_nhwc('Select-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+ def testSelectOpConditionUnknownShape(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ add = math_ops.add(conv, conv)
+ condition = array_ops.placeholder(dtype='bool')
+ select = gen_math_ops._select(condition, conv, add)
+ output = array_ops.identity(select)
+
+ condition_val = np.zeros((1, 7, 7, 64))
+ with session.Session() as sess:
+ output_val_ref = sess.run(output, feed_dict={condition: condition_val})
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(
+ output, run_metadata=metadata, feed_dict={condition: condition_val})
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if _is_transpose(node.name):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ expected_num_transposes = 3
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
def testSelectOpScalarCondition(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)