aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py')
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 178328619f..4073b390fc 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -132,6 +132,48 @@ class TestGatherTree(test.TestCase):
def test_gather_tree_from_array_2d(self):
self._test_gather_tree_from_array(depth_ndims=2)
+ def test_gather_tree_from_array_complex_trajectory(self):
+ # Max. time = 7, batch = 1, beam = 5.
+ array = np.expand_dims(np.array(
+ [[[25, 12, 114, 89, 97]],
+ [[9, 91, 64, 11, 162]],
+ [[34, 34, 34, 34, 34]],
+ [[2, 4, 2, 2, 4]],
+ [[2, 3, 6, 2, 2]],
+ [[2, 2, 2, 3, 2]],
+ [[2, 2, 2, 2, 2]]]), -1)
+ parent_ids = np.array(
+ [[[0, 0, 0, 0, 0]],
+ [[0, 0, 0, 0, 0]],
+ [[0, 1, 2, 3, 4]],
+ [[0, 0, 1, 2, 1]],
+ [[0, 1, 1, 2, 3]],
+ [[0, 1, 3, 1, 2]],
+ [[0, 1, 2, 3, 4]]])
+ expected_array = np.expand_dims(np.array(
+ [[[25, 25, 25, 25, 25]],
+ [[9, 9, 91, 9, 9]],
+ [[34, 34, 34, 34, 34]],
+ [[2, 4, 2, 4, 4]],
+ [[2, 3, 6, 3, 6]],
+ [[2, 2, 2, 3, 2]],
+ [[2, 2, 2, 2, 2]]]), -1)
+ sequence_length = [[4, 6, 4, 7, 6]]
+
+ array = ops.convert_to_tensor(
+ array, dtype=dtypes.float32)
+ parent_ids = ops.convert_to_tensor(
+ parent_ids, dtype=dtypes.int32)
+ expected_array = ops.convert_to_tensor(
+ expected_array, dtype=dtypes.float32)
+
+ sorted_array = beam_search_decoder.gather_tree_from_array(
+ array, parent_ids, sequence_length)
+
+ with self.test_session() as sess:
+ sorted_array, expected_array = sess.run([sorted_array, expected_array])
+ self.assertAllEqual(expected_array, sorted_array)
+
class TestArrayShapeChecks(test.TestCase):