diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-28 20:51:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 20:55:49 -0700 |
commit | 2e0e934e0b3c00863918c78bf55524eea3f0c0dc (patch) | |
tree | 61f464403123a3eeb546f0be46307e1181288ed1 /tensorflow/contrib/crf | |
parent | b5c66300d2c15a9bf1a8631161efa1a057e6ed31 (diff) |
Make tf.contrib.crf compatible with TPUs by using utils.smart_cond instead of tf.cond, which allows the static shape to be propagated correctly when available.
PiperOrigin-RevId: 215034102
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r-- | tensorflow/contrib/crf/python/ops/crf.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 2a91dcb63a..43bb43129b 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -56,7 +56,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn @@ -214,10 +213,11 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): log_norm) return log_norm - max_seq_len = array_ops.shape(inputs)[1] - return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), - true_fn=_single_seq_fn, - false_fn=_multi_seq_fn) + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or + array_ops.shape(inputs)[1], 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_likelihood(inputs, |