diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py')
-rw-r--r-- | tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 184144f64a..f17dbb0fe3 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -145,24 +145,20 @@ def gather_tree_from_array(t, parent_ids, sequence_length): array_ops.expand_dims(math_ops.range(beam_width), 0), 0) beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) - mask = array_ops.sequence_mask( - sequence_length, maxlen=max_time, dtype=dtypes.int32) - mask = array_ops.transpose(mask, perm=[2, 0, 1]) - - # Use beam_width + 1 to mark the end of beam. - masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1) - max_sequence_lengths = math_ops.to_int32( math_ops.reduce_max(sequence_length, axis=1)) sorted_beam_ids = beam_search_ops.gather_tree( - step_ids=masked_beam_ids, + step_ids=beam_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=beam_width + 1) # For out of range steps, simply copy the same beam. + in_bound_steps = array_ops.transpose( + array_ops.sequence_mask(sequence_length, maxlen=max_time), + perm=[2, 0, 1]) sorted_beam_ids = array_ops.where( - math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids) + in_bound_steps, x=sorted_beam_ids, y=beam_ids) # Generate indices for gather_nd. time_ind = array_ops.tile(array_ops.reshape( @@ -250,7 +246,7 @@ class BeamSearchDecoder(decoder.Decoder): ``` tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( encoder_final_state, multiplier=beam_width) tiled_sequence_length = tf.contrib.seq2seq.tile_batch( sequence_length, multiplier=beam_width) |