aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-21 22:04:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-21 22:08:42 -0700
commitf31939d24e3c544933b98ef48fac9ccac5679e05 (patch)
treea4633dec9f5384d88589fefd43989b74a0ef08a4 /tensorflow/contrib/crf
parent2279279fd15369e361a02fb09a1df41e08a34aae (diff)
Add function crf_multitag_sequence_score which enables calculating scores with more than one tag at each index.
PiperOrigin-RevId: 205551004
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--tensorflow/contrib/crf/__init__.py2
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py62
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py52
3 files changed, 109 insertions, 7 deletions
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py
index 046c509626..615e62b16f 100644
--- a/tensorflow/contrib/crf/__init__.py
+++ b/tensorflow/contrib/crf/__init__.py
@@ -20,6 +20,7 @@ See the @{$python/contrib.crf} guide.
@@crf_decode
@@crf_log_likelihood
@@crf_log_norm
+@@crf_multitag_sequence_score
@@crf_sequence_score
@@crf_unary_score
@@CrfDecodeBackwardRnnCell
@@ -36,6 +37,7 @@ from tensorflow.contrib.crf.python.ops.crf import crf_binary_score
from tensorflow.contrib.crf.python.ops.crf import crf_decode
from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood
from tensorflow.contrib.crf.python.ops.crf import crf_log_norm
+from tensorflow.contrib.crf.python.ops.crf import crf_multitag_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 74f2ec22ff..f56a973f6f 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -31,6 +31,15 @@ from tensorflow.python.platform import test
class CrfTest(test.TestCase):
+ def calculateSequenceScore(self, inputs, transition_params, tag_indices,
+ sequence_lengths):
+ expected_unary_score = sum(
+ inputs[i][tag_indices[i]] for i in range(sequence_lengths))
+ expected_binary_score = sum(
+ transition_params[tag_indices[i], tag_indices[i + 1]]
+ for i in range(sequence_lengths - 1))
+ return expected_unary_score + expected_binary_score
+
def testCrfSequenceScore(self):
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
@@ -60,14 +69,55 @@ class CrfTest(test.TestCase):
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
tf_sequence_score = sess.run(sequence_score)
- expected_unary_score = sum(inputs[i][tag_indices[i]]
- for i in range(sequence_lengths))
- expected_binary_score = sum(
- transition_params[tag_indices[i], tag_indices[i + 1]]
- for i in range(sequence_lengths - 1))
- expected_sequence_score = expected_unary_score + expected_binary_score
+ expected_sequence_score = self.calculateSequenceScore(
+ inputs, transition_params, tag_indices, sequence_lengths)
self.assertAllClose(tf_sequence_score, expected_sequence_score)
+ def testCrfMultiTagSequenceScore(self):
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ # Test both the length-1 and regular cases.
+ sequence_lengths_list = [
+ np.array(3, dtype=np.int32),
+ np.array(1, dtype=np.int32)
+ ]
+ inputs_list = [
+ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
+ dtype=np.float32),
+ np.array([[4, 5, -3]],
+ dtype=np.float32),
+ ]
+ tag_bitmap_list = [
+ np.array(
+ [[True, True, False], [True, False, True], [False, True, True],
+ [True, False, True]],
+ dtype=np.bool),
+ np.array([[True, True, False]], dtype=np.bool)
+ ]
+ for sequence_lengths, inputs, tag_bitmap in zip(
+ sequence_lengths_list, inputs_list, tag_bitmap_list):
+ with self.test_session() as sess:
+ sequence_score = crf.crf_multitag_sequence_score(
+ inputs=array_ops.expand_dims(inputs, 0),
+ tag_bitmap=array_ops.expand_dims(tag_bitmap, 0),
+ sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+ transition_params=constant_op.constant(transition_params))
+ sequence_score = array_ops.squeeze(sequence_score, [0])
+ tf_sum_sequence_score = sess.run(sequence_score)
+ all_indices_list = [
+ single_index_bitmap.nonzero()[0]
+ for single_index_bitmap in tag_bitmap[:sequence_lengths]
+ ]
+ expected_sequence_scores = [
+ self.calculateSequenceScore(inputs, transition_params, indices,
+ sequence_lengths)
+ for indices in itertools.product(*all_indices_list)
+ ]
+ expected_log_sum_exp_sequence_scores = np.logaddexp.reduce(
+ expected_sequence_scores)
+ self.assertAllClose(tf_sum_sequence_score,
+ expected_log_sum_exp_sequence_scores)
+
def testCrfUnaryScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 2d2cbdc199..8a7ff61bc8 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -67,7 +67,7 @@ __all__ = [
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
"viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
- "CrfDecodeBackwardRnnCell"
+ "CrfDecodeBackwardRnnCell", "crf_multitag_sequence_score"
]
@@ -114,6 +114,56 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
false_fn=_multi_seq_fn)
+def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths,
+ transition_params):
+ """Computes the unnormalized score of all tag sequences matching tag_bitmap.
+
+ tag_bitmap enables more than one tag to be considered correct at each time
+ step. This is useful when an observed output at a given time step is
+ consistent with more than one tag, and thus the log likelihood of that
+ observation must take into account all possible consistent tags.
+
+ Using one-hot vectors in tag_bitmap gives results identical to
+ crf_sequence_score.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor
+ representing all active tags at each index for which to calculate the
+ unnormalized score.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params: A [num_tags, num_tags] transition matrix.
+ Returns:
+ sequence_scores: A [batch_size] vector of unnormalized sequence scores.
+ """
+
+ # If max_seq_len is 1, we skip the score calculation and simply gather the
+ # unary potentials of all active tags.
+ def _single_seq_fn():
+ filtered_inputs = array_ops.where(
+ tag_bitmap, inputs,
+ array_ops.fill(array_ops.shape(inputs), float("-inf")))
+ return math_ops.reduce_logsumexp(
+ filtered_inputs, axis=[1, 2], keepdims=False)
+
+ def _multi_seq_fn():
+ # Compute the logsumexp of all scores of sequences matching the given tags.
+ filtered_inputs = array_ops.where(
+ tag_bitmap, inputs,
+ array_ops.fill(array_ops.shape(inputs), float("-inf")))
+ return crf_log_norm(
+ inputs=filtered_inputs,
+ sequence_lengths=sequence_lengths,
+ transition_params=transition_params)
+
+ return utils.smart_cond(
+ pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
+ 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
+
+
def crf_log_norm(inputs, sequence_lengths, transition_params):
"""Computes the normalization for a CRF.