diff options
Diffstat (limited to 'tensorflow/contrib/ndlstm/python/lstm2d_test.py')
-rw-r--r-- | tensorflow/contrib/ndlstm/python/lstm2d_test.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/ndlstm/python/lstm2d_test.py b/tensorflow/contrib/ndlstm/python/lstm2d_test.py index 3dbbb81796..f1b37d701b 100644 --- a/tensorflow/contrib/ndlstm/python/lstm2d_test.py +++ b/tensorflow/contrib/ndlstm/python/lstm2d_test.py @@ -69,6 +69,14 @@ class Lstm2DTest(test_util.TensorFlowTestCase): result = outputs.eval() self.assertEqual(tuple(result.shape), (2, 7, 11, 8)) + def testSeparableLstmDimsBlocks(self): + with self.test_session(): + inputs = constant_op.constant(_rand(2, 7, 11, 5)) + outputs = lstm2d.separable_lstm(inputs, 8, kernel_size=[2, 2]) + variables.global_variables_initializer().run() + result = outputs.eval() + self.assertEqual(tuple(result.shape), (2, 4, 6, 8)) + def testReduceToSequenceDims(self): with self.test_session(): inputs = constant_op.constant(_rand(2, 7, 11, 5)) |