aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf
diff options
context:
space:
mode:
authorGravatar cbockman <c.bockman@gmail.com>2018-08-03 16:58:21 -0700
committerGravatar GitHub <noreply@github.com>2018-08-03 16:58:21 -0700
commita2dc1d52ab068b531e85cabaf5043d920b56c0f4 (patch)
tree43aff5c2b08bc63118837f391a50f49b572f874a /tensorflow/contrib/crf
parent66dd14547dd9edb4eba13d22361ddad4a1cd3353 (diff)
fix var type issue which breaks crf_decode
CRF decode can fail when default type of "0" (as viewed by math_ops.maximum) does not match the type of sequence_length. This change is parallel in motivation and solution to the fix in https://github.com/tensorflow/tensorflow/commit/c0f1080188c5c6955cfa3b3c086ac262b1e5ec02, for crf_log_norm()=>_multi_seq_fn().
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 8a7ff61bc8..2a91dcb63a 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -548,7 +548,9 @@ def crf_decode(potentials, transition_params, sequence_length):
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
# Sequence length is not allowed to be less than zero.
- sequence_length_less_one = math_ops.maximum(0, sequence_length - 1)
+ sequence_length_less_one = math_ops.maximum(
+ constant_op.constant(0, dtype=sequence_length.dtype),
+ sequence_length - 1)
backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O]
crf_fwd_cell,
inputs=inputs,