diff options
author | 2018-07-10 11:41:13 +0200 | |
---|---|---|
committer | 2018-07-10 11:41:13 +0200 | |
commit | 5ed41f6da78bf79ad1ec75918d5e7b315ad4bec8 (patch) | |
tree | 9a15817078e58747da9037b71dabc5d8893d14b5 /tensorflow/contrib/seq2seq | |
parent | 10aacf16954764581def538dbea975c8946c9ea7 (diff) |
Fix masking of beam ids in gather_tree_from_array
The `sequence_length` argument that is passed to the function holds the
lengths of the **reordered** predictions and was incorrectly used to
mask beam ids *before* reordering. Instead, we can reorder beam ids
without caring about out-of-range steps and only select the reordered
ids that are in bounds.
The added test covers a beam trajectory that previously produced an
out-of-range error because `gather_tree` returned `end_token` (here
`beam_width + 1`) for some steps.
Diffstat (limited to 'tensorflow/contrib/seq2seq')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py | 42 | ||||
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 14 |
2 files changed, 47 insertions, 9 deletions
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 178328619f..4073b390fc 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -132,6 +132,48 @@ class TestGatherTree(test.TestCase): def test_gather_tree_from_array_2d(self): self._test_gather_tree_from_array(depth_ndims=2) + def test_gather_tree_from_array_complex_trajectory(self): + # Max. time = 7, batch = 1, beam = 5. + array = np.expand_dims(np.array( + [[[25, 12, 114, 89, 97]], + [[9, 91, 64, 11, 162]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 2, 4]], + [[2, 3, 6, 2, 2]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + parent_ids = np.array( + [[[0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 4]], + [[0, 0, 1, 2, 1]], + [[0, 1, 1, 2, 3]], + [[0, 1, 3, 1, 2]], + [[0, 1, 2, 3, 4]]]) + expected_array = np.expand_dims(np.array( + [[[25, 25, 25, 25, 25]], + [[9, 9, 91, 9, 9]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 4, 4]], + [[2, 3, 6, 3, 6]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + sequence_length = [[4, 6, 4, 7, 6]] + + array = ops.convert_to_tensor( + array, dtype=dtypes.float32) + parent_ids = ops.convert_to_tensor( + parent_ids, dtype=dtypes.int32) + expected_array = ops.convert_to_tensor( + expected_array, dtype=dtypes.float32) + + sorted_array = beam_search_decoder.gather_tree_from_array( + array, parent_ids, sequence_length) + + with self.test_session() as sess: + sorted_array, expected_array = sess.run([sorted_array, expected_array]) + self.assertAllEqual(expected_array, sorted_array) + class TestArrayShapeChecks(test.TestCase): diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index c7fbeea310..f17dbb0fe3 100644 --- 
a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -145,24 +145,20 @@ def gather_tree_from_array(t, parent_ids, sequence_length): array_ops.expand_dims(math_ops.range(beam_width), 0), 0) beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) - mask = array_ops.sequence_mask( - sequence_length, maxlen=max_time, dtype=dtypes.int32) - mask = array_ops.transpose(mask, perm=[2, 0, 1]) - - # Use beam_width + 1 to mark the end of beam. - masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1) - max_sequence_lengths = math_ops.to_int32( math_ops.reduce_max(sequence_length, axis=1)) sorted_beam_ids = beam_search_ops.gather_tree( - step_ids=masked_beam_ids, + step_ids=beam_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=beam_width + 1) # For out of range steps, simply copy the same beam. + in_bound_steps = array_ops.transpose( + array_ops.sequence_mask(sequence_length, maxlen=max_time), + perm=[2, 0, 1]) sorted_beam_ids = array_ops.where( - math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids) + in_bound_steps, x=sorted_beam_ids, y=beam_ids) # Generate indices for gather_nd. time_ind = array_ops.tile(array_ops.reshape( |