diff options
author | Patrick Nguyen <drpng@google.com> | 2018-05-01 19:02:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-01 19:04:52 -0700 |
commit | c0f1080188c5c6955cfa3b3c086ac262b1e5ec02 (patch) | |
tree | 819ea50344584528918469391811cf9d792d91a4 /tensorflow/contrib/crf | |
parent | 69b2c639f55b065a5dbf829351034441bebc8437 (diff) |
Make the CRF work when sequence_lengths are int32.
PiperOrigin-RevId: 195034218
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r-- | tensorflow/contrib/crf/python/ops/crf.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index d2beff849e..2d2cbdc199 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -52,6 +52,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -147,7 +148,9 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # partition function. forward_cell = CrfForwardRnnCell(transition_params) # Sequence length is not allowed to be less than zero. - sequence_lengths_less_one = math_ops.maximum(0, sequence_lengths - 1) + sequence_lengths_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_lengths.dtype), + sequence_lengths - 1) _, alphas = rnn.dynamic_rnn( cell=forward_cell, inputs=rest_of_input, |