aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc')
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
index 64973ccccd..dfa12e873a 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
@@ -80,12 +80,12 @@ class GatherTreeOp : public OpKernel {
max_sequence_lengths.shape().DebugString()));
Tensor* beams;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
- typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
- typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
+ typename TTypes<T, 3>::ConstTensor step_ids_t(step_ids.tensor<T, 3>());
+ typename TTypes<T, 3>::ConstTensor parent_ids_t(parent_ids.tensor<T, 3>());
typename TTypes<int32>::ConstVec max_seq_lens_t =
max_sequence_lengths.vec<int32>();
- typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>();
- typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
+ typename TTypes<T>::ConstScalar end_token_t(end_token.scalar<T>());
+ typename TTypes<T, 3>::Tensor beams_t(beams->tensor<T, 3>());
const T end_token_value = end_token_t();
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
max_seq_lens_t, end_token_value, beams_t);