diff options
author | cbockman <c.bockman@gmail.com> | 2018-08-03 16:58:21 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-03 16:58:21 -0700 |
commit | a2dc1d52ab068b531e85cabaf5043d920b56c0f4 (patch) | |
tree | 43aff5c2b08bc63118837f391a50f49b572f874a /tensorflow/contrib/crf | |
parent | 66dd14547dd9edb4eba13d22361ddad4a1cd3353 (diff) |
fix var type issue which breaks crf_decode
CRF decode can fail when default type of "0" (as viewed by math_ops.maximum) does not match the type of sequence_length.
This change is parallel in motivation and solution to the fix in https://github.com/tensorflow/tensorflow/commit/c0f1080188c5c6955cfa3b3c086ac262b1e5ec02, for crf_log_norm()=>_multi_seq_fn().
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r-- | tensorflow/contrib/crf/python/ops/crf.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 8a7ff61bc8..2a91dcb63a 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -548,7 +548,9 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] # Sequence length is not allowed to be less than zero. - sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) + sequence_length_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_length.dtype), + sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, inputs=inputs, |