aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-28 20:51:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 20:55:49 -0700
commit2e0e934e0b3c00863918c78bf55524eea3f0c0dc (patch)
tree61f464403123a3eeb546f0be46307e1181288ed1 /tensorflow/contrib/crf
parentb5c66300d2c15a9bf1a8631161efa1a057e6ed31 (diff)
Make tf.contrib.crf compatible with TPUs by using utils.smart_cond instead of tf.cond, which allows the static shape to be propagated correctly when available.
PiperOrigin-RevId: 215034102
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 2a91dcb63a..43bb43129b 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -56,7 +56,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
@@ -214,10 +213,11 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
log_norm)
return log_norm
- max_seq_len = array_ops.shape(inputs)[1]
- return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
- true_fn=_single_seq_fn,
- false_fn=_multi_seq_fn)
+ return utils.smart_cond(
+ pred=math_ops.equal(inputs.shape[1].value or
+ array_ops.shape(inputs)[1], 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
def crf_log_likelihood(inputs,