aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-26 12:50:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-26 14:01:27 -0700
commitb0bdff4827f867a67f572ed99d85f9a847788326 (patch)
tree750f4a99478f8bd5ad1c99abd1af0179647e8977 /tensorflow/python/kernel_tests/reverse_sequence_op_test.py
parent0f867ebf83cd594aeb73b1577044c91a749ddac5 (diff)
Merge changes from github.
Change: 131437429
Diffstat (limited to 'tensorflow/python/kernel_tests/reverse_sequence_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/reverse_sequence_op_test.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
index 48fbccd686..cea9711d32 100644
--- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -47,7 +47,7 @@ class ReverseSequenceTest(tf.test.TestCase):
self._testReverseSequence(x, batch_dim, seq_dim, seq_lengths,
truth, False, expected_err_re)
- def _testBasic(self, dtype):
+ def _testBasic(self, dtype, len_dtype=np.int64):
x = np.asarray([
[[1, 2, 3, 4], [5, 6, 7, 8]],
[[9, 10, 11, 12], [13, 14, 15, 16]],
@@ -56,7 +56,7 @@ class ReverseSequenceTest(tf.test.TestCase):
x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2
# reverse dim 2 up to (0:3, none, 0:4) along dim=0
- seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
+ seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)
truth_orig = np.asarray(
[[[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3
@@ -70,6 +70,9 @@ class ReverseSequenceTest(tf.test.TestCase):
batch_dim = 2
self._testBothReverseSequence(x, batch_dim, seq_dim, seq_lengths, truth)
+ def testSeqLenghtInt32(self):
+ self._testBasic(np.float32, np.int32)
+
def testFloatBasic(self):
self._testBasic(np.float32)