aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/crf/python/kernel_tests/crf_test.py')
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 448bcafffe..9174c5eb98 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -23,6 +23,7 @@ import itertools
import numpy as np
from tensorflow.contrib.crf.python.ops import crf
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -199,6 +200,52 @@ class CrfTest(test.TestCase):
self.assertEqual(actual_max_sequence,
expected_max_sequence[:sequence_lengths])
+ def testCrfDecode(self):
+ inputs = np.array(
+ [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ sequence_lengths = np.array(3, dtype=np.int32)
+ num_words = inputs.shape[0]
+ num_tags = inputs.shape[1]
+
+ with self.test_session() as sess:
+ all_sequence_scores = []
+ all_sequences = []
+
+ # Compare the dynamic program with brute force computation.
+ for tag_indices in itertools.product(
+ range(num_tags), repeat=sequence_lengths):
+ tag_indices = list(tag_indices)
+ tag_indices.extend([0] * (num_words - sequence_lengths))
+ all_sequences.append(tag_indices)
+ sequence_score = crf.crf_sequence_score(
+ inputs=array_ops.expand_dims(inputs, 0),
+ tag_indices=array_ops.expand_dims(tag_indices, 0),
+ sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+ transition_params=constant_op.constant(transition_params))
+ sequence_score = array_ops.squeeze(sequence_score, [0])
+ all_sequence_scores.append(sequence_score)
+
+ tf_all_sequence_scores = sess.run(all_sequence_scores)
+
+ expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
+ expected_max_sequence = all_sequences[expected_max_sequence_index]
+ expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
+
+ actual_max_sequence, actual_max_score = crf.crf_decode(
+ array_ops.expand_dims(inputs, 0),
+ constant_op.constant(transition_params),
+ array_ops.expand_dims(sequence_lengths, 0))
+ actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
+ actual_max_score = array_ops.squeeze(actual_max_score, [0])
+ tf_actual_max_sequence, tf_actual_max_score = sess.run(
+ [actual_max_sequence, actual_max_score])
+
+ self.assertAllClose(tf_actual_max_score, expected_max_score)
+ self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
+ expected_max_sequence[:sequence_lengths])
+
if __name__ == "__main__":
test.main()