about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/crf/python/ops/crf.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/crf/python/ops/crf.py')
-rw-r--r-- tensorflow/contrib/crf/python/ops/crf.py 52
1 files changed, 51 insertions, 1 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 2d2cbdc199..8a7ff61bc8 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -67,7 +67,7 @@ __all__ = [
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
"viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
- "CrfDecodeBackwardRnnCell"
+ "CrfDecodeBackwardRnnCell", "crf_multitag_sequence_score"
]
@@ -114,6 +114,56 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
false_fn=_multi_seq_fn)
+def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
+ transition_params):
+ """Computes the unnormalized score of all tag sequences matching tag_bitmap.
+
+ tag_bitmap enables more than one tag to be considered correct at each time
+ step. This is useful when an observed output at a given time step is
+ consistent with more than one tag, and thus the log likelihood of that
+ observation must take into account all possible consistent tags.
+
+ Using one-hot vectors in tag_bitmap gives results identical to
+ crf_sequence_score.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
+ representing all active tags at each index for which to calculate the
+ unnormalized score.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params: A [num_tags, num_tags] transition matrix.
+ Returns:
+ sequence_scores: A [batch_size] vector of unnormalized sequence scores.
+ """
+
+ # If max_seq_len is 1, we skip the score calculation and simply gather the
+ # unary potentials of all active tags.
+ def _single_seq_fn():
+ filtered_inputs = array_ops.where(
+ tag_bitmap, inputs,
+ array_ops.fill(array_ops.shape(inputs), float("-inf")))
+ return math_ops.reduce_logsumexp(
+ filtered_inputs, axis=[1, 2], keepdims=False)
+
+ def _multi_seq_fn():
+ # Compute the logsumexp of all scores of sequences matching the given tags.
+ filtered_inputs = array_ops.where(
+ tag_bitmap, inputs,
+ array_ops.fill(array_ops.shape(inputs), float("-inf")))
+ return crf_log_norm(
+ inputs=filtered_inputs,
+ sequence_lengths=sequence_lengths,
+ transition_params=transition_params)
+
+ return utils.smart_cond(
+ pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
+ 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
+
+
def crf_log_norm(inputs, sequence_lengths, transition_params):
"""Computes the normalization for a CRF.