diff options
author | 2017-10-17 08:48:29 -0700 | |
---|---|---|
committer | 2017-10-17 08:56:25 -0700 | |
commit | 18f89c81d288f191abd56501ec6f86fe29265bdd (patch) | |
tree | 2f7a4d78af10e9cdfa7b4badf3bd91b4aca1b9ed | |
parent | a1ba9f3bf16cb53b8468b93021611311a9be55b4 (diff) |
[tf.contrib.seq2seq] Bugfixes to BeamSearchDecoder and GatherTree.
1. Begin the gather tree at the maximum sequence length across all beams (within the batch).
2. Take a second pass starting from t=0 and mask out any beam ids past the *first* beam occurrence of end_token.
3. Update the final sequence lengths to include the first <eos> token in the beam.
4. Update dynamic_decode to allow the BeamSearchDecoder to keep track of its own "finished" states, as the shuffling in the decoder confused the tracking mechanism in dynamic_decode. This fixes a bug where beam search decoding stops early.
5. Cap sequence length used in GatherTree to min(max_time, max_seq_len(b)) to avoid accessing memory outside the dimensions of input matrices.
Bugs caught by @bdaskalov on github and Pavel Sountsov. Proper solution and analysis thanks to Rui Zhao. Thanks all!
Fixes #13536.
PiperOrigin-RevId: 172471462
8 files changed, 217 insertions, 147 deletions
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index aab0f3f494..95273e2b33 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -49,40 +49,46 @@ class GatherTreeOp : public OpKernel { const Device& device = ctx->eigen_device<Device>(); const Tensor& step_ids = ctx->input(0); const Tensor& parent_ids = ctx->input(1); - const Tensor& sequence_length = ctx->input(2); + const Tensor& max_sequence_lengths = ctx->input(2); + const Tensor& end_token = ctx->input(3); const TensorShape& step_ids_shape = step_ids.shape(); OP_REQUIRES( ctx, step_ids_shape.dims() == 3, errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ", step_ids_shape.DebugString())); - OP_REQUIRES( - ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()), - errors::InvalidArgument("sequence_length must be a matrix, saw shape: ", - sequence_length.shape().DebugString())); - OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1), - errors::InvalidArgument( - "Inconsistent batch sizes: sequence_length.shape[0] (", - sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (", - step_ids_shape.dim_size(1), ")")); - OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2), + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()), errors::InvalidArgument( - "Inconsistent batch sizes: sequence_length.shape[1] (", - sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (", - step_ids_shape.dim_size(2), ")")); + "max_sequence_lengths must be a vector, saw shape: ", + max_sequence_lengths.shape().DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(end_token.shape()), + errors::InvalidArgument("end_token must be a scalar, saw shape: ", + end_token.shape().DebugString())); OP_REQUIRES( ctx, step_ids_shape == parent_ids.shape(), errors::InvalidArgument( "step_ids.shape must match parent_ids.shape. but shapes are: ", step_ids_shape.DebugString(), " and ", parent_ids.shape().DebugString())); + OP_REQUIRES( + ctx, + step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0), + errors::InvalidArgument("batch size dimensions step_ids.shape[1] and " + "max_seqeuence_lengths.shape[0] must match. " + "but shapes are: ", + step_ids_shape.DebugString(), " and ", + max_sequence_lengths.shape().DebugString())); Tensor* beams; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>(); typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>(); - typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>(); + typename TTypes<int32>::ConstVec max_seq_lens_t = + max_sequence_lengths.vec<int32>(); + typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>(); typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>(); + const T end_token_value = end_token_t(); functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t, - seq_len_t, beams_t); + max_seq_lens_t, end_token_value, beams_t); } }; @@ -99,27 +105,29 @@ namespace functor { template <> struct GatherTree<CPUDevice, int32> { void operator()(OpKernelContext* ctx, const CPUDevice& d, - typename TTypes<int32, 3>::ConstTensor step_ids, - typename TTypes<int32, 3>::ConstTensor parent_ids, - typename TTypes<int32>::ConstMatrix sequence_length, - typename TTypes<int32, 3>::Tensor beams) { - const int64 max_time = parent_ids.dimension(0); - const int64 batch_size = parent_ids.dimension(1); - const int64 beam_width = parent_ids.dimension(2); + TTypes<int32, 3>::ConstTensor step_ids, + TTypes<int32, 3>::ConstTensor parent_ids, + TTypes<int32>::ConstVec max_sequence_lengths, + const int32 end_token, TTypes<int32, 3>::Tensor beams) { + const int32 max_time = parent_ids.dimension(0); + const int32 batch_size = parent_ids.dimension(1); + const int32 beam_width = parent_ids.dimension(2); beams.setConstant(-1); - auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) { + auto DoWork = [&, ctx, end_token](int start_batch_beam, + int limit_batch_beam) { for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - int32 seq_len_b = sequence_length(batch, beam); - if (seq_len_b <= 0) { + const int32 max_seq_len_b = + Eigen::numext::mini(max_time, max_sequence_lengths(batch)); + if (max_seq_len_b <= 0) { continue; } - beams(seq_len_b - 1, batch, beam) = - step_ids(seq_len_b - 1, batch, beam); - int32 parent = parent_ids(seq_len_b - 1, batch, beam); - for (int32 level = seq_len_b - 2; level >= 0; --level) { + beams(max_seq_len_b - 1, batch, beam) = + step_ids(max_seq_len_b - 1, batch, beam); + int32 parent = parent_ids(max_seq_len_b - 1, batch, beam); + for (int32 level = max_seq_len_b - 2; level >= 0; --level) { if (parent < 0 || parent > beam_width) { ctx->SetStatus( errors::InvalidArgument("Saw invalid parent id ", parent, @@ -130,6 +138,14 @@ struct GatherTree<CPUDevice, int32> { beams(level, batch, beam) = step_ids(level, batch, parent); parent = parent_ids(level, batch, parent); } + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + if (finished) { + beams(time, batch, beam) = -1; + } else if (beams(time, batch, beam) == end_token) { + finished = true; + } + } } }; // Guesstimate of cost; ~5 lookup/store/compare per inner beam @@ -137,7 +153,7 @@ struct GatherTree<CPUDevice, int32> { const int64 batch_beam_cost = Eigen::TensorOpCost::DivCost<int32>() + 6 * Eigen::TensorOpCost::AddCost<int32>() + - max_time * (5 * Eigen::TensorOpCost::AddCost<int32>()); + 2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>()); auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); Shard(worker_threads.num_threads, worker_threads.workers, batch_size * beam_width, batch_beam_cost, DoWork); @@ -148,24 +164,26 @@ struct GatherTree<CPUDevice, int32> { #if GOOGLE_CUDA namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void GatherTree<GPUDevice, T>::operator()( \ - OpKernelContext* ctx, const GPUDevice& d, \ - typename TTypes<T, 3>::ConstTensor step_ids, \ - typename TTypes<T, 3>::ConstTensor parent_ids, \ - typename TTypes<T>::ConstMatrix sequence_length, \ - typename TTypes<T, 3>::Tensor beams); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void GatherTree<GPUDevice, T>::operator()( \ + OpKernelContext* ctx, const GPUDevice& d, \ + typename TTypes<T, 3>::ConstTensor step_ids, \ + typename TTypes<T, 3>::ConstTensor parent_ids, \ + TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \ + typename TTypes<T, 3>::Tensor beams); \ extern template struct GatherTree<GPUDevice, T>; DECLARE_GPU_SPEC(int32); #undef DECLARE_GPU_SPEC } // end namespace functor -#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("GatherTree").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ - GatherTreeOp<GPUDevice, T>); +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("GatherTree") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("end_token"), \ + GatherTreeOp<GPUDevice, T>); REGISTER_GPU_KERNEL(int32); #undef REGISTER_GPU_KERNEL diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h index 124d07264e..693b02dc43 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h @@ -31,8 +31,8 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const Device& d, typename TTypes<T, 3>::ConstTensor step_ids, typename TTypes<T, 3>::ConstTensor parent_ids, - typename TTypes<T>::ConstMatrix sequence_length, - typename TTypes<T, 3>::Tensor beams); + TTypes<int32>::ConstVec max_sequence_lengths, + const T end_token, typename TTypes<T, 3>::Tensor beams); }; } // namespace functor diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index ee68b55d20..e71efc48ce 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -29,20 +29,24 @@ template <typename T> __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, const int32 beam_width, const T* step_ids, const T* parent_ids, - const T* sequence_length, T* beams) { + const int32* max_sequence_lengths, + const T end_token, T* beams) { CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam); - if (seq_len_b <= 0) continue; + const int32 max_seq_len_b = + Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch)); + if (max_seq_len_b <= 0) { + continue; + } #define GET_IX(time_ix, beam_ix) \ (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix)) - const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam); + const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam); beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix); int32 parent = ldg(parent_ids + initial_beam_ix); - for (int32 level = seq_len_b - 2; level >= 0; --level) { + for (int32 level = max_seq_len_b - 2; level >= 0; --level) { const int32 level_beam_ix = GET_IX(level, beam); const int32 level_parent_ix = GET_IX(level, parent); if (parent < 0 || parent > beam_width) { @@ -53,6 +57,15 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, parent = ldg(parent_ids + level_parent_ix); } } + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + const int32 level_beam_ix = GET_IX(time, beam); + if (finished) { + beams[level_beam_ix] = -1; + } else if (beams[level_beam_ix] == end_token) { + finished = true; + } + } #undef GET_IX } } @@ -62,8 +75,8 @@ struct GatherTree<GPUDevice, T> { void operator()(OpKernelContext* ctx, const GPUDevice& d, typename TTypes<T, 3>::ConstTensor step_ids, typename TTypes<T, 3>::ConstTensor parent_ids, - typename TTypes<T>::ConstMatrix sequence_length, - typename TTypes<T, 3>::Tensor beams) { + TTypes<int32>::ConstVec max_sequence_length, + const T end_token, typename TTypes<T, 3>::Tensor beams) { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); @@ -75,7 +88,10 @@ struct GatherTree<GPUDevice, T> { GatherTreeOpKernel<T> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( batch_size, max_time, beam_width, - step_ids.data(), parent_ids.data(), sequence_length.data(), + step_ids.data(), + parent_ids.data(), + max_sequence_length.data(), + end_token, beams.data()); // clang-format on } diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc index 6c445cd460..231504bfbb 100644 --- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -25,27 +25,27 @@ using shape_inference::ShapeHandle; REGISTER_OP("GatherTree") .Input("step_ids: T") .Input("parent_ids: T") - .Input("sequence_length: T") + .Input("max_sequence_lengths: int32") + .Input("end_token: T") .Output("beams: T") .Attr("T: {int32}") .SetShapeFn([](InferenceContext* c) { - ShapeHandle step_ids, parent_ids, sequence_length; + ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token; // step_ids, parent_ids, and output are all shaped: // [max_time, batch_size, beam_width]. - // sequence_length is shaped [batch_size, beam_width]. + // max_sequence_length is shaped [batch_size] and end_token is a scalar. TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids)); TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length)); - - DimensionHandle batch_size = c->Dim(step_ids, 1); - DimensionHandle beam_width = c->Dim(step_ids, 2); - + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token)); TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids)); + DimensionHandle batch_size = c->Dim(step_ids, 1); TF_RETURN_IF_ERROR( - c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size)); - TF_RETURN_IF_ERROR( - c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width)); + c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size)); + ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size); + TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids, + &step_ids_prefix)); c->set_output(0, step_ids); return tensorflow::Status::OK(); @@ -61,7 +61,8 @@ TODO(ebrevdo): fill in step_ids: `[max_time, batch_size, beam_width]`. parent_ids: `[max_time, batch_size, beam_width]`. -sequence_length: `[batch_size, beam_width]`. +max_sequence_lengths: `[batch_size]`. +end_token: `[]`. beams: `[max_time, batch_size, beam_width]`. )doc"); diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 8d4ec4b4db..d2beac5f31 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -54,15 +54,18 @@ class TestGatherTree(test.TestCase): [[0, 0, 0], [1, 2, 0], [2, 1, 1]]], dtype=np.int32).transpose([1, 0, 2]) - # sequence_lengths is shaped (batch_size = 2, beam_width = 3) - sequence_lengths = [[3, 3, 3], [3, 3, 3]] + # sequence_lengths is shaped (batch_size = 3) + max_sequence_lengths = [3, 3] expected_result = np.array( [[[2, 2, 2], [6, 5, 6], [7, 8, 9]], [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2]) res = beam_search_ops.gather_tree( - predicted_ids, parent_ids, sequence_lengths) + predicted_ids, + parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=11) with self.test_session() as sess: res_ = sess.run(res) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index 50cccf392f..f301314872 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function # pylint: enable=unused-import +import itertools + import numpy as np from tensorflow.contrib.seq2seq.python.ops import beam_search_ops @@ -38,12 +40,14 @@ class GatherTreeTest(test.TestCase): [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] + max_sequence_lengths = [3] expected_result = _transpose_batch_time( [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=10) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) @@ -54,11 +58,13 @@ class GatherTreeTest(test.TestCase): [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] + max_sequence_lengths = [3] with ops.device("/cpu:0"): beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=10) with self.test_session(): with self.assertRaisesOpError( r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): @@ -75,78 +81,58 @@ class GatherTreeTest(test.TestCase): [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] + max_sequence_lengths = [3] expected_result = _transpose_batch_time( [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) with ops.device("/device:GPU:0"): beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=10) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testGatherTreeBatch(self): - # sequence_length is [batch_size, beam_width] = [4, 5] - sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5] + batch_size = 10 + beam_width = 15 + max_time = 8 + max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0] + end_token = 5 with self.test_session(use_gpu=True): - # (max_time = 4, batch_size = 4, beam_width = 5) - step_ids = _transpose_batch_time( - [[[3, 4, 0, 4, 0], - [4, 2, 0, 3, 1], - [1, 1, 3, 2, 2], - [3, 1, 2, 3, 4]], - [[3, 4, 0, 4, 0], - [4, 2, 0, 3, 1], - [1, 1, 3, 2, 2], - [3, 1, 2, 3, 4]], - [[1, 2, 3, 4, 2], - [2, 1, 1, 3, 2], - [3, 0, 1, 0, 0], - [3, 4, 0, 2, 4]], - [[0, 2, 2, 3, 1], - [3, 2, 2, 2, 3], - [3, 4, 3, 0, 3], - [1, 2, 2, 2, 4]]]) - parent_ids = _transpose_batch_time( - [[[4, 2, 4, 3, 4], - [3, 4, 0, 2, 0], - [3, 1, 3, 2, 2], - [0, 2, 1, 4, 2]], - [[4, 2, 4, 3, 4], - [3, 4, 0, 2, 0], - [3, 1, 3, 2, 2], - [0, 2, 1, 4, 2]], - [[3, 0, 0, 4, 0], - [1, 2, 4, 2, 2], - [4, 4, 0, 3, 0], - [2, 4, 4, 3, 0]], - [[3, 1, 4, 1, 3], - [3, 2, 4, 0, 4], - [1, 0, 1, 4, 2], - [0, 3, 2, 0, 1]]]) - expected_beams = _transpose_batch_time( - [[[-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[3, 4, 0, 4, 0], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[2, 3, 2, 3, 3], - [2, 1, 1, 3, 2], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[2, 3, 2, 1, 1], - [2, 3, 2, 3, 2], - [3, 4, 3, 0, 3], - [-1, -1, -1, -1, -1]]]) + step_ids = np.random.randint( + 0, high=end_token + 1, size=(max_time, batch_size, beam_width)) + parent_ids = np.random.randint( + 0, high=beam_width - 1, size=(max_time, batch_size, beam_width)) beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) - self.assertAllEqual(expected_beams, beams.eval()) + step_ids=step_ids.astype(np.int32), + parent_ids=parent_ids.astype(np.int32), + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) + + self.assertEqual((max_time, batch_size, beam_width), beams.shape) + beams_value = beams.eval() + for b in range(batch_size): + # Past max_sequence_lengths[b], we emit all -1s. + b_value = beams_value[max_sequence_lengths[b]:, b, :] + self.assertAllClose(b_value, -1. * np.ones_like(b_value)) + for batch, beam in itertools.product( + range(batch_size), range(beam_width)): + v = np.squeeze(beams_value[:, batch, beam]) + if end_token in v: + found = np.where(v == end_token)[0] + # Should be up to 1 instance of end_token per beam. + self.assertEqual(len(found), 1) + found = found[0] + # If an end_token is found, everything before it should be a + # valid id and everything after it should be -1. + if found > 0: + self.assertAllEqual( + v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool)) + self.assertAllClose( + v[found + 1:], -1 * np.ones_like(v[found + 1:])) if __name__ == "__main__": diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 112ac57a1b..a88d4f5b8b 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -254,6 +254,20 @@ class BeamSearchDecoder(decoder.Decoder): return nest.map_structure(lambda s: s[1:], layer_output_shape) @property + def tracks_own_finished(self): + """The BeamSearchDecoder shuffles its beams and their finished state. + + For this reason, it conflicts with the `dynamic_decode` function's + tracking of finished states. Setting this property to true avoids + early stopping of decoding due to mismanagement of the finished state + in `dynamic_decode`. + + Returns: + `True`. + """ + return True + + @property def output_size(self): # Return the cell output and the id return BeamSearchDecoderOutput( @@ -303,15 +317,23 @@ class BeamSearchDecoder(decoder.Decoder): output. sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`. The sequence lengths determined for each beam during decode. + **NOTE** These are ignored; the updated sequence lengths are stored in + `final_state.lengths`. Returns: - outputs: An instance of FinalBeamSearchDecoderOutput where the + outputs: An instance of `FinalBeamSearchDecoderOutput` where the predicted_ids are the result of calling _gather_tree. - final_state: The same input instance of BeamSearchDecoderState. + final_state: The same input instance of `BeamSearchDecoderState`. """ + del sequence_lengths + # Get max_sequence_length across all beams for each batch. + max_sequence_lengths = math_ops.to_int32( + math_ops.reduce_max(final_state.lengths, axis=1)) predicted_ids = beam_search_ops.gather_tree( - outputs.predicted_ids, outputs.parent_ids, - sequence_length=sequence_lengths) + outputs.predicted_ids, + outputs.parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=self._end_token) outputs = FinalBeamSearchDecoderOutput( beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state @@ -588,10 +610,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, name="next_beam_finished") # Calculate the length of the next predictions. - # 1. Finished beams remain unchanged - # 2. Beams that are now finished (EOS predicted) remain unchanged - # 3. Beams that are not yet finished have their length increased by 1 - lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished)) + # 1. Finished beams remain unchanged. + # 2. Beams that are now finished (EOS predicted) have their length + # increased by 1. + # 3. Beams that are not yet finished have their length increased by 1. + lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished)) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index fbe53fc60a..f14974b9d5 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -100,16 +100,36 @@ class Decoder(object): Returns: `(outputs, next_state, next_inputs, finished)`: `outputs` is an object - containing the decoder output, `next_state` is a (structure of) state tensors - and TensorArrays, `next_inputs` is the tensor that should be used as input for - the next step, `finished` is a boolean tensor telling whether the sequence - is complete, for each sequence in the batch. + containing the decoder output, `next_state` is a (structure of) state + tensors and TensorArrays, `next_inputs` is the tensor that should be used + as input for the next step, `finished` is a boolean tensor telling whether + the sequence is complete, for each sequence in the batch. """ raise NotImplementedError def finalize(self, outputs, final_state, sequence_lengths): raise NotImplementedError + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `dynamic_decode` function keeps track + of which batch entries are already finished, and performs a logical OR to + insert new batches to the finished set. + + Some decoders, however, shuffle batches / beams between time steps and + `dynamic_decode` will mix up the finished state across these entries because + it does not track the reshuffle across time steps. In this case, it is + up to the decoder to declare that it will keep track of its own finished + state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + def _create_zero_outputs(size, dtype, batch_size): """Create a zero outputs Tensor structure.""" @@ -232,7 +252,10 @@ def dynamic_decode(decoder, """ (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state) - next_finished = math_ops.logical_or(decoder_finished, finished) + if decoder.tracks_own_finished: + next_finished = decoder_finished + else: + next_finished = math_ops.logical_or(decoder_finished, finished) if maximum_iterations is not None: next_finished = math_ops.logical_or( next_finished, time + 1 >= maximum_iterations) |