aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-10-17 08:48:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-17 08:56:25 -0700
commit18f89c81d288f191abd56501ec6f86fe29265bdd (patch)
tree2f7a4d78af10e9cdfa7b4badf3bd91b4aca1b9ed
parenta1ba9f3bf16cb53b8468b93021611311a9be55b4 (diff)
[tf.contrib.seq2seq] Bugfixes to BeamSearchDecoder and GatherTree.
1. Begin the gather tree at the maximum sequence length across all beams (within the batch). 2. Take a second pass starting from t=0 and mask out any beam ids past the *first* beam occurrence of end_token. 3. Update the final sequence lengths to include the first <eos> token in the beam. 4. Update dynamic_decode to allow the BeamSearchDecoder to keep track of its own "finished" states, as the shuffling in the decoder confused the tracking mechanism in dynamic_decode. This fixes a bug where beam search decoding stops early. 5. Cap sequence length used in GatherTree to min(max_time, max_seq_len(b)) to avoid accessing memory outside the dimensions of input matrices. Bugs caught by @bdaskalov on github and Pavel Sountsov. Proper solution and analysis thanks to Rui Zhao. Thanks all! Fixes #13536. PiperOrigin-RevId: 172471462
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc104
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops.h4
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc32
-rw-r--r--tensorflow/contrib/seq2seq/ops/beam_search_ops.cc25
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py9
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py118
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py39
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/decoder.py33
8 files changed, 217 insertions, 147 deletions
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
index aab0f3f494..95273e2b33 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
@@ -49,40 +49,46 @@ class GatherTreeOp : public OpKernel {
const Device& device = ctx->eigen_device<Device>();
const Tensor& step_ids = ctx->input(0);
const Tensor& parent_ids = ctx->input(1);
- const Tensor& sequence_length = ctx->input(2);
+ const Tensor& max_sequence_lengths = ctx->input(2);
+ const Tensor& end_token = ctx->input(3);
const TensorShape& step_ids_shape = step_ids.shape();
OP_REQUIRES(
ctx, step_ids_shape.dims() == 3,
errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
step_ids_shape.DebugString()));
- OP_REQUIRES(
- ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()),
- errors::InvalidArgument("sequence_length must be a matrix, saw shape: ",
- sequence_length.shape().DebugString()));
- OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
- errors::InvalidArgument(
- "Inconsistent batch sizes: sequence_length.shape[0] (",
- sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
- step_ids_shape.dim_size(1), ")"));
- OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2),
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()),
errors::InvalidArgument(
- "Inconsistent batch sizes: sequence_length.shape[1] (",
- sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (",
- step_ids_shape.dim_size(2), ")"));
+ "max_sequence_lengths must be a vector, saw shape: ",
+ max_sequence_lengths.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(end_token.shape()),
+ errors::InvalidArgument("end_token must be a scalar, saw shape: ",
+ end_token.shape().DebugString()));
OP_REQUIRES(
ctx, step_ids_shape == parent_ids.shape(),
errors::InvalidArgument(
"step_ids.shape must match parent_ids.shape. but shapes are: ",
step_ids_shape.DebugString(), " and ",
parent_ids.shape().DebugString()));
+ OP_REQUIRES(
+ ctx,
+ step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0),
+ errors::InvalidArgument("batch size dimensions step_ids.shape[1] and "
+ "max_seqeuence_lengths.shape[0] must match. "
+ "but shapes are: ",
+ step_ids_shape.DebugString(), " and ",
+ max_sequence_lengths.shape().DebugString()));
Tensor* beams;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
- typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>();
+ typename TTypes<int32>::ConstVec max_seq_lens_t =
+ max_sequence_lengths.vec<int32>();
+ typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>();
typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
+ const T end_token_value = end_token_t();
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
- seq_len_t, beams_t);
+ max_seq_lens_t, end_token_value, beams_t);
}
};
@@ -99,27 +105,29 @@ namespace functor {
template <>
struct GatherTree<CPUDevice, int32> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
- typename TTypes<int32, 3>::ConstTensor step_ids,
- typename TTypes<int32, 3>::ConstTensor parent_ids,
- typename TTypes<int32>::ConstMatrix sequence_length,
- typename TTypes<int32, 3>::Tensor beams) {
- const int64 max_time = parent_ids.dimension(0);
- const int64 batch_size = parent_ids.dimension(1);
- const int64 beam_width = parent_ids.dimension(2);
+ TTypes<int32, 3>::ConstTensor step_ids,
+ TTypes<int32, 3>::ConstTensor parent_ids,
+ TTypes<int32>::ConstVec max_sequence_lengths,
+ const int32 end_token, TTypes<int32, 3>::Tensor beams) {
+ const int32 max_time = parent_ids.dimension(0);
+ const int32 batch_size = parent_ids.dimension(1);
+ const int32 beam_width = parent_ids.dimension(2);
beams.setConstant(-1);
- auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
+ auto DoWork = [&, ctx, end_token](int start_batch_beam,
+ int limit_batch_beam) {
for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
- int32 seq_len_b = sequence_length(batch, beam);
- if (seq_len_b <= 0) {
+ const int32 max_seq_len_b =
+ Eigen::numext::mini(max_time, max_sequence_lengths(batch));
+ if (max_seq_len_b <= 0) {
continue;
}
- beams(seq_len_b - 1, batch, beam) =
- step_ids(seq_len_b - 1, batch, beam);
- int32 parent = parent_ids(seq_len_b - 1, batch, beam);
- for (int32 level = seq_len_b - 2; level >= 0; --level) {
+ beams(max_seq_len_b - 1, batch, beam) =
+ step_ids(max_seq_len_b - 1, batch, beam);
+ int32 parent = parent_ids(max_seq_len_b - 1, batch, beam);
+ for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
if (parent < 0 || parent > beam_width) {
ctx->SetStatus(
errors::InvalidArgument("Saw invalid parent id ", parent,
@@ -130,6 +138,14 @@ struct GatherTree<CPUDevice, int32> {
beams(level, batch, beam) = step_ids(level, batch, parent);
parent = parent_ids(level, batch, parent);
}
+ bool finished = false;
+ for (int32 time = 0; time < max_seq_len_b; ++time) {
+ if (finished) {
+ beams(time, batch, beam) = -1;
+ } else if (beams(time, batch, beam) == end_token) {
+ finished = true;
+ }
+ }
}
};
// Guesstimate of cost; ~5 lookup/store/compare per inner beam
@@ -137,7 +153,7 @@ struct GatherTree<CPUDevice, int32> {
const int64 batch_beam_cost =
Eigen::TensorOpCost::DivCost<int32>() +
6 * Eigen::TensorOpCost::AddCost<int32>() +
- max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
+ 2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
batch_size * beam_width, batch_beam_cost, DoWork);
@@ -148,24 +164,26 @@ struct GatherTree<CPUDevice, int32> {
#if GOOGLE_CUDA
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void GatherTree<GPUDevice, T>::operator()( \
- OpKernelContext* ctx, const GPUDevice& d, \
- typename TTypes<T, 3>::ConstTensor step_ids, \
- typename TTypes<T, 3>::ConstTensor parent_ids, \
- typename TTypes<T>::ConstMatrix sequence_length, \
- typename TTypes<T, 3>::Tensor beams); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void GatherTree<GPUDevice, T>::operator()( \
+ OpKernelContext* ctx, const GPUDevice& d, \
+ typename TTypes<T, 3>::ConstTensor step_ids, \
+ typename TTypes<T, 3>::ConstTensor parent_ids, \
+ TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \
+ typename TTypes<T, 3>::Tensor beams); \
extern template struct GatherTree<GPUDevice, T>;
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
} // end namespace functor
-#define REGISTER_GPU_KERNEL(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("GatherTree").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
- GatherTreeOp<GPUDevice, T>);
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("GatherTree") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("end_token"), \
+ GatherTreeOp<GPUDevice, T>);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
index 124d07264e..693b02dc43 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
@@ -31,8 +31,8 @@ struct GatherTree {
void operator()(OpKernelContext* ctx, const Device& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
- typename TTypes<T>::ConstMatrix sequence_length,
- typename TTypes<T, 3>::Tensor beams);
+ TTypes<int32>::ConstVec max_sequence_lengths,
+ const T end_token, typename TTypes<T, 3>::Tensor beams);
};
} // namespace functor
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
index ee68b55d20..e71efc48ce 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
@@ -29,20 +29,24 @@ template <typename T>
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
const int32 beam_width, const T* step_ids,
const T* parent_ids,
- const T* sequence_length, T* beams) {
+ const int32* max_sequence_lengths,
+ const T end_token, T* beams) {
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
- const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
- if (seq_len_b <= 0) continue;
+ const int32 max_seq_len_b =
+ Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch));
+ if (max_seq_len_b <= 0) {
+ continue;
+ }
#define GET_IX(time_ix, beam_ix) \
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
- const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
+ const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
int32 parent = ldg(parent_ids + initial_beam_ix);
- for (int32 level = seq_len_b - 2; level >= 0; --level) {
+ for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
const int32 level_beam_ix = GET_IX(level, beam);
const int32 level_parent_ix = GET_IX(level, parent);
if (parent < 0 || parent > beam_width) {
@@ -53,6 +57,15 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
parent = ldg(parent_ids + level_parent_ix);
}
}
+ bool finished = false;
+ for (int32 time = 0; time < max_seq_len_b; ++time) {
+ const int32 level_beam_ix = GET_IX(time, beam);
+ if (finished) {
+ beams[level_beam_ix] = -1;
+ } else if (beams[level_beam_ix] == end_token) {
+ finished = true;
+ }
+ }
#undef GET_IX
}
}
@@ -62,8 +75,8 @@ struct GatherTree<GPUDevice, T> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
- typename TTypes<T>::ConstMatrix sequence_length,
- typename TTypes<T, 3>::Tensor beams) {
+ TTypes<int32>::ConstVec max_sequence_length,
+ const T end_token, typename TTypes<T, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
@@ -75,7 +88,10 @@ struct GatherTree<GPUDevice, T> {
GatherTreeOpKernel<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
batch_size, max_time, beam_width,
- step_ids.data(), parent_ids.data(), sequence_length.data(),
+ step_ids.data(),
+ parent_ids.data(),
+ max_sequence_length.data(),
+ end_token,
beams.data());
// clang-format on
}
diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
index 6c445cd460..231504bfbb 100644
--- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
@@ -25,27 +25,27 @@ using shape_inference::ShapeHandle;
REGISTER_OP("GatherTree")
.Input("step_ids: T")
.Input("parent_ids: T")
- .Input("sequence_length: T")
+ .Input("max_sequence_lengths: int32")
+ .Input("end_token: T")
.Output("beams: T")
.Attr("T: {int32}")
.SetShapeFn([](InferenceContext* c) {
- ShapeHandle step_ids, parent_ids, sequence_length;
+ ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token;
// step_ids, parent_ids, and output are all shaped:
// [max_time, batch_size, beam_width].
- // sequence_length is shaped [batch_size, beam_width].
+ // max_sequence_length is shaped [batch_size] and end_token is a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length));
-
- DimensionHandle batch_size = c->Dim(step_ids, 1);
- DimensionHandle beam_width = c->Dim(step_ids, 2);
-
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token));
TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
+ DimensionHandle batch_size = c->Dim(step_ids, 1);
TF_RETURN_IF_ERROR(
- c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
- TF_RETURN_IF_ERROR(
- c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width));
+ c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size));
+ ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size);
+ TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids,
+ &step_ids_prefix));
c->set_output(0, step_ids);
return tensorflow::Status::OK();
@@ -61,7 +61,8 @@ TODO(ebrevdo): fill in
step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
-sequence_length: `[batch_size, beam_width]`.
+max_sequence_lengths: `[batch_size]`.
+end_token: `[]`.
beams: `[max_time, batch_size, beam_width]`.
)doc");
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 8d4ec4b4db..d2beac5f31 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -54,15 +54,18 @@ class TestGatherTree(test.TestCase):
[[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])
- # sequence_lengths is shaped (batch_size = 2, beam_width = 3)
- sequence_lengths = [[3, 3, 3], [3, 3, 3]]
+ # sequence_lengths is shaped (batch_size = 3)
+ max_sequence_lengths = [3, 3]
expected_result = np.array(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
[[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
res = beam_search_ops.gather_tree(
- predicted_ids, parent_ids, sequence_lengths)
+ predicted_ids,
+ parent_ids,
+ max_sequence_lengths=max_sequence_lengths,
+ end_token=11)
with self.test_session() as sess:
res_ = sess.run(res)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
index 50cccf392f..f301314872 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
+import itertools
+
import numpy as np
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
@@ -38,12 +40,14 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
- sequence_length = [[3, 3, 3]]
+ max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
beams = beam_search_ops.gather_tree(
- step_ids=step_ids, parent_ids=parent_ids,
- sequence_length=sequence_length)
+ step_ids=step_ids,
+ parent_ids=parent_ids,
+ max_sequence_lengths=max_sequence_lengths,
+ end_token=10)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
@@ -54,11 +58,13 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
- sequence_length = [[3, 3, 3]]
+ max_sequence_lengths = [3]
with ops.device("/cpu:0"):
beams = beam_search_ops.gather_tree(
- step_ids=step_ids, parent_ids=parent_ids,
- sequence_length=sequence_length)
+ step_ids=step_ids,
+ parent_ids=parent_ids,
+ max_sequence_lengths=max_sequence_lengths,
+ end_token=10)
with self.test_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
@@ -75,78 +81,58 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
- sequence_length = [[3, 3, 3]]
+ max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
with ops.device("/device:GPU:0"):
beams = beam_search_ops.gather_tree(
- step_ids=step_ids, parent_ids=parent_ids,
- sequence_length=sequence_length)
+ step_ids=step_ids,
+ parent_ids=parent_ids,
+ max_sequence_lengths=max_sequence_lengths,
+ end_token=10)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
def testGatherTreeBatch(self):
- # sequence_length is [batch_size, beam_width] = [4, 5]
- sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5]
+ batch_size = 10
+ beam_width = 15
+ max_time = 8
+ max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
+ end_token = 5
with self.test_session(use_gpu=True):
- # (max_time = 4, batch_size = 4, beam_width = 5)
- step_ids = _transpose_batch_time(
- [[[3, 4, 0, 4, 0],
- [4, 2, 0, 3, 1],
- [1, 1, 3, 2, 2],
- [3, 1, 2, 3, 4]],
- [[3, 4, 0, 4, 0],
- [4, 2, 0, 3, 1],
- [1, 1, 3, 2, 2],
- [3, 1, 2, 3, 4]],
- [[1, 2, 3, 4, 2],
- [2, 1, 1, 3, 2],
- [3, 0, 1, 0, 0],
- [3, 4, 0, 2, 4]],
- [[0, 2, 2, 3, 1],
- [3, 2, 2, 2, 3],
- [3, 4, 3, 0, 3],
- [1, 2, 2, 2, 4]]])
- parent_ids = _transpose_batch_time(
- [[[4, 2, 4, 3, 4],
- [3, 4, 0, 2, 0],
- [3, 1, 3, 2, 2],
- [0, 2, 1, 4, 2]],
- [[4, 2, 4, 3, 4],
- [3, 4, 0, 2, 0],
- [3, 1, 3, 2, 2],
- [0, 2, 1, 4, 2]],
- [[3, 0, 0, 4, 0],
- [1, 2, 4, 2, 2],
- [4, 4, 0, 3, 0],
- [2, 4, 4, 3, 0]],
- [[3, 1, 4, 1, 3],
- [3, 2, 4, 0, 4],
- [1, 0, 1, 4, 2],
- [0, 3, 2, 0, 1]]])
- expected_beams = _transpose_batch_time(
- [[[-1, -1, -1, -1, -1],
- [-1, -1, -1, -1, -1],
- [-1, -1, -1, -1, -1],
- [-1, -1, -1, -1, -1]],
- [[3, 4, 0, 4, 0],
- [-1, -1, -1, -1, -1],
- [-1, -1, -1, -1, -1],
- [-1, -1, -1, -1, -1]],
- [[2, 3, 2, 3, 3],
- [2, 1, 1, 3, 2],
- [-1, -1, -1, -1, -1],
- [-1, -1, -1, -1, -1]],
- [[2, 3, 2, 1, 1],
- [2, 3, 2, 3, 2],
- [3, 4, 3, 0, 3],
- [-1, -1, -1, -1, -1]]])
+ step_ids = np.random.randint(
+ 0, high=end_token + 1, size=(max_time, batch_size, beam_width))
+ parent_ids = np.random.randint(
+ 0, high=beam_width - 1, size=(max_time, batch_size, beam_width))
beams = beam_search_ops.gather_tree(
- step_ids=step_ids, parent_ids=parent_ids,
- sequence_length=sequence_length)
- self.assertAllEqual(expected_beams, beams.eval())
+ step_ids=step_ids.astype(np.int32),
+ parent_ids=parent_ids.astype(np.int32),
+ max_sequence_lengths=max_sequence_lengths,
+ end_token=end_token)
+
+ self.assertEqual((max_time, batch_size, beam_width), beams.shape)
+ beams_value = beams.eval()
+ for b in range(batch_size):
+ # Past max_sequence_lengths[b], we emit all -1s.
+ b_value = beams_value[max_sequence_lengths[b]:, b, :]
+ self.assertAllClose(b_value, -1. * np.ones_like(b_value))
+ for batch, beam in itertools.product(
+ range(batch_size), range(beam_width)):
+ v = np.squeeze(beams_value[:, batch, beam])
+ if end_token in v:
+ found = np.where(v == end_token)[0]
+ # Should be up to 1 instance of end_token per beam.
+ self.assertEqual(len(found), 1)
+ found = found[0]
+ # If an end_token is found, everything before it should be a
+ # valid id and everything after it should be -1.
+ if found > 0:
+ self.assertAllEqual(
+ v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
+ self.assertAllClose(
+ v[found + 1:], -1 * np.ones_like(v[found + 1:]))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 112ac57a1b..a88d4f5b8b 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -254,6 +254,20 @@ class BeamSearchDecoder(decoder.Decoder):
return nest.map_structure(lambda s: s[1:], layer_output_shape)
@property
+ def tracks_own_finished(self):
+ """The BeamSearchDecoder shuffles its beams and their finished state.
+
+ For this reason, it conflicts with the `dynamic_decode` function's
+ tracking of finished states. Setting this property to true avoids
+ early stopping of decoding due to mismanagement of the finished state
+ in `dynamic_decode`.
+
+ Returns:
+ `True`.
+ """
+ return True
+
+ @property
def output_size(self):
# Return the cell output and the id
return BeamSearchDecoderOutput(
@@ -303,15 +317,23 @@ class BeamSearchDecoder(decoder.Decoder):
output.
sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
The sequence lengths determined for each beam during decode.
+ **NOTE** These are ignored; the updated sequence lengths are stored in
+ `final_state.lengths`.
Returns:
- outputs: An instance of FinalBeamSearchDecoderOutput where the
+ outputs: An instance of `FinalBeamSearchDecoderOutput` where the
predicted_ids are the result of calling _gather_tree.
- final_state: The same input instance of BeamSearchDecoderState.
+ final_state: The same input instance of `BeamSearchDecoderState`.
"""
+ del sequence_lengths
+ # Get max_sequence_length across all beams for each batch.
+ max_sequence_lengths = math_ops.to_int32(
+ math_ops.reduce_max(final_state.lengths, axis=1))
predicted_ids = beam_search_ops.gather_tree(
- outputs.predicted_ids, outputs.parent_ids,
- sequence_length=sequence_lengths)
+ outputs.predicted_ids,
+ outputs.parent_ids,
+ max_sequence_lengths=max_sequence_lengths,
+ end_token=self._end_token)
outputs = FinalBeamSearchDecoderOutput(
beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
return outputs, final_state
@@ -588,10 +610,11 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
name="next_beam_finished")
# Calculate the length of the next predictions.
- # 1. Finished beams remain unchanged
- # 2. Beams that are now finished (EOS predicted) remain unchanged
- # 3. Beams that are not yet finished have their length increased by 1
- lengths_to_add = math_ops.to_int64(math_ops.logical_not(next_finished))
+ # 1. Finished beams remain unchanged.
+ # 2. Beams that are now finished (EOS predicted) have their length
+ # increased by 1.
+ # 3. Beams that are not yet finished have their length increased by 1.
+ lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished))
next_prediction_len = _tensor_gather_helper(
gather_indices=next_beam_ids,
gather_from=beam_state.lengths,
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index fbe53fc60a..f14974b9d5 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -100,16 +100,36 @@ class Decoder(object):
Returns:
`(outputs, next_state, next_inputs, finished)`: `outputs` is an object
- containing the decoder output, `next_state` is a (structure of) state tensors
- and TensorArrays, `next_inputs` is the tensor that should be used as input for
- the next step, `finished` is a boolean tensor telling whether the sequence
- is complete, for each sequence in the batch.
+ containing the decoder output, `next_state` is a (structure of) state
+ tensors and TensorArrays, `next_inputs` is the tensor that should be used
+ as input for the next step, `finished` is a boolean tensor telling whether
+ the sequence is complete, for each sequence in the batch.
"""
raise NotImplementedError
def finalize(self, outputs, final_state, sequence_lengths):
raise NotImplementedError
+ @property
+ def tracks_own_finished(self):
+ """Describes whether the Decoder keeps track of finished states.
+
+ Most decoders will emit a true/false `finished` value independently
+ at each time step. In this case, the `dynamic_decode` function keeps track
+ of which batch entries are already finished, and performs a logical OR to
+ insert new batches to the finished set.
+
+ Some decoders, however, shuffle batches / beams between time steps and
+ `dynamic_decode` will mix up the finished state across these entries because
+ it does not track the reshuffle across time steps. In this case, it is
+ up to the decoder to declare that it will keep track of its own finished
+ state by setting this property to `True`.
+
+ Returns:
+ Python bool.
+ """
+ return False
+
def _create_zero_outputs(size, dtype, batch_size):
"""Create a zero outputs Tensor structure."""
@@ -232,7 +252,10 @@ def dynamic_decode(decoder,
"""
(next_outputs, decoder_state, next_inputs,
decoder_finished) = decoder.step(time, inputs, state)
- next_finished = math_ops.logical_or(decoder_finished, finished)
+ if decoder.tracks_own_finished:
+ next_finished = decoder_finished
+ else:
+ next_finished = math_ops.logical_or(decoder_finished, finished)
if maximum_iterations is not None:
next_finished = math_ops.logical_or(
next_finished, time + 1 >= maximum_iterations)