aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2018-05-01 19:02:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 19:04:52 -0700
commitc0f1080188c5c6955cfa3b3c086ac262b1e5ec02 (patch)
tree819ea50344584528918469391811cf9d792d91a4 /tensorflow/contrib/crf
parent69b2c639f55b065a5dbf829351034441bebc8437 (diff)
Make the CRF work when sequence_lengths are int32.
PiperOrigin-RevId: 195034218
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index d2beff849e..2d2cbdc199 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -52,6 +52,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
@@ -147,7 +148,9 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
# partition function.
forward_cell = CrfForwardRnnCell(transition_params)
# Sequence length is not allowed to be less than zero.
- sequence_lengths_less_one = math_ops.maximum(0, sequence_lengths - 1)
+ sequence_lengths_less_one = math_ops.maximum(
+ constant_op.constant(0, dtype=sequence_lengths.dtype),
+ sequence_lengths - 1)
_, alphas = rnn.dynamic_rnn(
cell=forward_cell,
inputs=rest_of_input,