diff options
Diffstat (limited to 'tensorflow/contrib/seq2seq/kernels')
-rw-r--r-- | tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index 64973ccccd..dfa12e873a 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -80,12 +80,12 @@ class GatherTreeOp : public OpKernel { max_sequence_lengths.shape().DebugString())); Tensor* beams; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); - typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>(); - typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>(); + typename TTypes<T, 3>::ConstTensor step_ids_t(step_ids.tensor<T, 3>()); + typename TTypes<T, 3>::ConstTensor parent_ids_t(parent_ids.tensor<T, 3>()); typename TTypes<int32>::ConstVec max_seq_lens_t = max_sequence_lengths.vec<int32>(); - typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>(); - typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>(); + typename TTypes<T>::ConstScalar end_token_t(end_token.scalar<T>()); + typename TTypes<T, 3>::Tensor beams_t(beams->tensor<T, 3>()); const T end_token_value = end_token_t(); functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t, max_seq_lens_t, end_token_value, beams_t); |