aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/ndlstm/python/lstm2d_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/ndlstm/python/lstm2d_test.py')
-rw-r--r--tensorflow/contrib/ndlstm/python/lstm2d_test.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/ndlstm/python/lstm2d_test.py b/tensorflow/contrib/ndlstm/python/lstm2d_test.py
index 3dbbb81796..f1b37d701b 100644
--- a/tensorflow/contrib/ndlstm/python/lstm2d_test.py
+++ b/tensorflow/contrib/ndlstm/python/lstm2d_test.py
@@ -69,6 +69,14 @@ class Lstm2DTest(test_util.TensorFlowTestCase):
result = outputs.eval()
self.assertEqual(tuple(result.shape), (2, 7, 11, 8))
+ def testSeparableLstmDimsBlocks(self):
+ with self.test_session():
+ inputs = constant_op.constant(_rand(2, 7, 11, 5))
+ outputs = lstm2d.separable_lstm(inputs, 8, kernel_size=[2, 2])
+ variables.global_variables_initializer().run()
+ result = outputs.eval()
+ self.assertEqual(tuple(result.shape), (2, 4, 6, 8))
+
def testReduceToSequenceDims(self):
with self.test_session():
inputs = constant_op.constant(_rand(2, 7, 11, 5))