aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/crf/python/kernel_tests/crf_test.py')
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py62
1 files changed, 56 insertions, 6 deletions
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 74f2ec22ff..f56a973f6f 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -31,6 +31,15 @@ from tensorflow.python.platform import test
class CrfTest(test.TestCase):
+ def calculateSequenceScore(self, inputs, transition_params, tag_indices,
+ sequence_lengths):
+ expected_unary_score = sum(
+ inputs[i][tag_indices[i]] for i in range(sequence_lengths))
+ expected_binary_score = sum(
+ transition_params[tag_indices[i], tag_indices[i + 1]]
+ for i in range(sequence_lengths - 1))
+ return expected_unary_score + expected_binary_score
+
def testCrfSequenceScore(self):
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
@@ -60,14 +69,55 @@ class CrfTest(test.TestCase):
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
tf_sequence_score = sess.run(sequence_score)
- expected_unary_score = sum(inputs[i][tag_indices[i]]
- for i in range(sequence_lengths))
- expected_binary_score = sum(
- transition_params[tag_indices[i], tag_indices[i + 1]]
- for i in range(sequence_lengths - 1))
- expected_sequence_score = expected_unary_score + expected_binary_score
+ expected_sequence_score = self.calculateSequenceScore(
+ inputs, transition_params, tag_indices, sequence_lengths)
self.assertAllClose(tf_sequence_score, expected_sequence_score)
+ def testCrfMultiTagSequenceScore(self):
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ # Test both the length-1 and regular cases.
+ sequence_lengths_list = [
+ np.array(3, dtype=np.int32),
+ np.array(1, dtype=np.int32)
+ ]
+ inputs_list = [
+ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
+ dtype=np.float32),
+ np.array([[4, 5, -3]],
+ dtype=np.float32),
+ ]
+ tag_bitmap_list = [
+ np.array(
+ [[True, True, False], [True, False, True], [False, True, True],
+ [True, False, True]],
+ dtype=np.bool),
+ np.array([[True, True, False]], dtype=np.bool)
+ ]
+ for sequence_lengths, inputs, tag_bitmap in zip(
+ sequence_lengths_list, inputs_list, tag_bitmap_list):
+ with self.test_session() as sess:
+ sequence_score = crf.crf_multitag_sequence_score(
+ inputs=array_ops.expand_dims(inputs, 0),
+ tag_bitmap=array_ops.expand_dims(tag_bitmap, 0),
+ sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+ transition_params=constant_op.constant(transition_params))
+ sequence_score = array_ops.squeeze(sequence_score, [0])
+ tf_sum_sequence_score = sess.run(sequence_score)
+ all_indices_list = [
+ single_index_bitmap.nonzero()[0]
+ for single_index_bitmap in tag_bitmap[:sequence_lengths]
+ ]
+ expected_sequence_scores = [
+ self.calculateSequenceScore(inputs, transition_params, indices,
+ sequence_lengths)
+ for indices in itertools.product(*all_indices_list)
+ ]
+ expected_log_sum_exp_sequence_scores = np.logaddexp.reduce(
+ expected_sequence_scores)
+ self.assertAllClose(tf_sum_sequence_score,
+ expected_log_sum_exp_sequence_scores)
+
def testCrfUnaryScore(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)