diff options
author | 2017-11-22 12:30:20 -0800 | |
---|---|---|
committer | 2017-11-22 12:36:14 -0800 | |
commit | 8752c973150df64374f96d516aafa664de410dce (patch) | |
tree | 0dfd7f39447a85fdf50bb4c17c7970791f17573e /tensorflow/contrib/crf | |
parent | d9b3ed25816f98e8ad11d3ecb20c1fc0ed0f4166 (diff) |
Fix crf_sequence_score(), crf_log_norm(), and crf_decode() for inputs whose max_seq_len is 1. This can happen in single-example inference.
PiperOrigin-RevId: 176688502
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r-- | tensorflow/contrib/crf/python/kernel_tests/crf_test.py | 224 | ||||
-rw-r--r-- | tensorflow/contrib/crf/python/ops/crf.py | 163 |
2 files changed, 242 insertions, 145 deletions
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 964ec75441..b47fb426a1 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -32,27 +32,41 @@ from tensorflow.python.platform import test class CrfTest(test.TestCase): def testCrfSequenceScore(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - tf_sequence_score = sess.run(sequence_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - expected_sequence_score = expected_unary_score + expected_binary_score - self.assertAllClose(tf_sequence_score, expected_sequence_score) + # Test both the length-1 and regular cases. 
+ sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32) + ] + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + with self.test_session() as sess: + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + tf_sequence_score = sess.run(sequence_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + expected_sequence_score = expected_unary_score + expected_binary_score + self.assertAllClose(tf_sequence_score, expected_sequence_score) def testCrfUnaryScore(self): inputs = np.array( @@ -89,38 +103,54 @@ class CrfTest(test.TestCase): self.assertAllClose(tf_binary_score, expected_binary_score) def testCrfLogNorm(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - all_sequence_scores = [] - - # Compare the dynamic program with brute force computation. 
- for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequence_scores.append( - crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params))) - - brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) - log_norm = crf.crf_log_norm( - inputs=array_ops.expand_dims(inputs, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - log_norm = array_ops.squeeze(log_norm, [0]) - tf_brute_force_log_norm, tf_log_norm = sess.run( - [brute_force_log_norm, log_norm]) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[3, -1, 3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + with self.test_session() as sess: + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. 
+ for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params))) + + brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) + log_norm = crf.crf_log_norm( + inputs=array_ops.expand_dims(inputs, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + log_norm = array_ops.squeeze(log_norm, [0]) + tf_brute_force_log_norm, tf_log_norm = sess.run( + [brute_force_log_norm, log_norm]) - self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) def testCrfLogLikelihood(self): inputs = np.array( @@ -201,50 +231,66 @@ class CrfTest(test.TestCase): expected_max_sequence[:sequence_lengths]) def testCrfDecode(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] + # Test both the length-1 and regular cases. 
+ sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[-1, 2, 1]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] - with self.test_session() as sess: - all_sequence_scores = [] - all_sequences = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequences.append(tag_indices) - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - all_sequence_scores.append(sequence_score) - - tf_all_sequence_scores = sess.run(all_sequence_scores) - - expected_max_sequence_index = np.argmax(tf_all_sequence_scores) - expected_max_sequence = all_sequences[expected_max_sequence_index] - expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] - - actual_max_sequence, actual_max_score = crf.crf_decode( - array_ops.expand_dims(inputs, 0), - constant_op.constant(transition_params), - array_ops.expand_dims(sequence_lengths, 0)) - actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) - actual_max_score = array_ops.squeeze(actual_max_score, [0]) - tf_actual_max_sequence, tf_actual_max_score = sess.run( - [actual_max_sequence, actual_max_score]) - - self.assertAllClose(tf_actual_max_score, expected_max_score) - 
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), - expected_max_sequence[:sequence_lengths]) + with self.test_session() as sess: + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = sess.run(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] + + actual_max_sequence, actual_max_score = crf.crf_decode( + array_ops.expand_dims(inputs, 0), + constant_op.constant(transition_params), + array_ops.expand_dims(sequence_lengths, 0)) + actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) + actual_max_score = array_ops.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = sess.run( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) if __name__ == "__main__": diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 4282be5ec8..ca384226d4 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -53,7 +53,9 @@ from __future__ import 
print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn @@ -101,12 +103,29 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, Returns: sequence_scores: A [batch_size] vector of unnormalized sequence scores. """ - # Compute the scores of the given tag sequence. - unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) - binary_scores = crf_binary_score(tag_indices, sequence_lengths, - transition_params) - sequence_scores = unary_scores + binary_scores - return sequence_scores + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] + example_inds = array_ops.reshape( + math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) + return array_ops.gather_nd( + array_ops.squeeze(inputs, [1]), + array_ops.concat([example_inds, tag_indices], axis=1)) + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. + unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score(tag_indices, sequence_lengths, + transition_params) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], + 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) def crf_log_norm(inputs, sequence_lengths, transition_params): @@ -124,19 +143,32 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # algorithm. 
first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1]) first_input = array_ops.squeeze(first_input, [1]) - rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) - # Compute the alpha values in the forward algorithm in order to get the - # partition function. - forward_cell = CrfForwardRnnCell(transition_params) - _, alphas = rnn.dynamic_rnn( - cell=forward_cell, - inputs=rest_of_input, - sequence_length=sequence_lengths - 1, - initial_state=first_input, - dtype=dtypes.float32) - log_norm = math_ops.reduce_logsumexp(alphas, [1]) - return log_norm + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over + # the "initial state" (the unary potentials). + def _single_seq_fn(): + return math_ops.reduce_logsumexp(first_input, [1]) + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) + + # Compute the alpha values in the forward algorithm in order to get the + # partition function. + forward_cell = CrfForwardRnnCell(transition_params) + _, alphas = rnn.dynamic_rnn( + cell=forward_cell, + inputs=rest_of_input, + sequence_length=sequence_lengths - 1, + initial_state=first_input, + dtype=dtypes.float32) + log_norm = math_ops.reduce_logsumexp(alphas, [1]) + return log_norm + + max_seq_len = array_ops.shape(inputs)[1] + return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_likelihood(inputs, @@ -440,41 +472,60 @@ def crf_decode(potentials, transition_params, sequence_length): Contains the highest scoring tag indices. best_score: A [batch_size] tensor, containing the score of decode_tags. """ - # For simplicity, in shape comments, denote: - # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). - num_tags = potentials.get_shape()[2].value - - # Computes forward decoding. Get last score and backpointers. 
- crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) - initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) - initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] - inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - backpointers, last_score = rnn.dynamic_rnn( - crf_fwd_cell, - inputs=inputs, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, O], [B, O] - backpointers = gen_array_ops.reverse_sequence( - backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O] - - # Computes backward decoding. Extract tag indices from backpointers. - crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) - initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), - dtype=dtypes.int32) # [B] - initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] - decode_tags, _ = rnn.dynamic_rnn( - crf_bwd_cell, - inputs=backpointers, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, 1] - decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] - decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T] - decode_tags = gen_array_ops.reverse_sequence( - decode_tags, sequence_length, seq_dim=1) # [B, T] - - best_score = math_ops.reduce_max(last_score, axis=1) # [B] - return decode_tags, best_score + # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag + # and the max activation. 
+ def _single_seq_fn(): + squeezed_potentials = array_ops.squeeze(potentials, [1]) + decode_tags = array_ops.expand_dims( + math_ops.argmax(squeezed_potentials, axis=1), 1) + best_score = math_ops.reduce_max(squeezed_potentials, axis=1) + return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score + + def _multi_seq_fn(): + """Decoding of highest scoring sequence.""" + + # For simplicity, in shape comments, denote: + # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). + num_tags = potentials.get_shape()[2].value + + # Computes forward decoding. Get last score and backpointers. + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] + inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] + crf_fwd_cell, + inputs=inputs, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O] + backpointers, sequence_length - 1, seq_dim=1) + + # Computes backward decoding. Extract tag indices from backpointers. 
+ crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) + initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B] + dtype=dtypes.int32) + initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] + decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1] + crf_bwd_cell, + inputs=backpointers, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] + decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T] + axis=1) + decode_tags = gen_array_ops.reverse_sequence( # [B, T] + decode_tags, sequence_length, seq_dim=1) + + best_score = math_ops.reduce_max(last_score, axis=1) # [B] + return decode_tags, best_score + + return utils.smart_cond( + pred=math_ops.equal( + potentials.shape[1].value or array_ops.shape(potentials)[1], 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) |