diff options
Diffstat (limited to 'tensorflow/contrib/crf/python/kernel_tests/crf_test.py')
-rw-r--r-- | tensorflow/contrib/crf/python/kernel_tests/crf_test.py | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 448bcafffe..9174c5eb98 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -23,6 +23,7 @@ import itertools import numpy as np from tensorflow.contrib.crf.python.ops import crf +from tensorflow.python.framework import dtypes from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -199,6 +200,52 @@ class CrfTest(test.TestCase): self.assertEqual(actual_max_sequence, expected_max_sequence[:sequence_lengths]) + def testCrfDecode(self): + inputs = np.array( + [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + with self.test_session() as sess: + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = sess.run(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] + + actual_max_sequence, actual_max_score = crf.crf_decode( + array_ops.expand_dims(inputs, 0), + constant_op.constant(transition_params), + array_ops.expand_dims(sequence_lengths, 0)) + actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) + actual_max_score = array_ops.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = sess.run( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) + if __name__ == "__main__": test.main() |