diff options
Diffstat (limited to 'tensorflow/contrib/crf/python/ops/crf.py')
-rw-r--r-- | tensorflow/contrib/crf/python/ops/crf.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 2a91dcb63a..43bb43129b 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -56,7 +56,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn @@ -214,10 +213,11 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): log_norm) return log_norm - max_seq_len = array_ops.shape(inputs)[1] - return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), - true_fn=_single_seq_fn, - false_fn=_multi_seq_fn) + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or + array_ops.shape(inputs)[1], 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_likelihood(inputs, |