aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-22 12:30:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-22 12:36:14 -0800
commit8752c973150df64374f96d516aafa664de410dce (patch)
tree0dfd7f39447a85fdf50bb4c17c7970791f17573e /tensorflow/contrib/crf
parentd9b3ed25816f98e8ad11d3ecb20c1fc0ed0f4166 (diff)
Fix functionality in crf_sequence_score(), crf_log_norm(), and crf_decode() for when input has max_seq_len = 1. This can happen in single-example inference.
PiperOrigin-RevId: 176688502
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py224
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py163
2 files changed, 242 insertions, 145 deletions
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 964ec75441..b47fb426a1 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -32,27 +32,41 @@ from tensorflow.python.platform import test
class CrfTest(test.TestCase):
def testCrfSequenceScore(self):
- inputs = np.array(
- [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
- tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
- sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
- sequence_score = crf.crf_sequence_score(
- inputs=array_ops.expand_dims(inputs, 0),
- tag_indices=array_ops.expand_dims(tag_indices, 0),
- sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
- transition_params=constant_op.constant(transition_params))
- sequence_score = array_ops.squeeze(sequence_score, [0])
- tf_sequence_score = sess.run(sequence_score)
- expected_unary_score = sum(inputs[i][tag_indices[i]]
- for i in range(sequence_lengths))
- expected_binary_score = sum(
- transition_params[tag_indices[i], tag_indices[i + 1]]
- for i in range(sequence_lengths - 1))
- expected_sequence_score = expected_unary_score + expected_binary_score
- self.assertAllClose(tf_sequence_score, expected_sequence_score)
+ # Test both the length-1 and regular cases.
+ sequence_lengths_list = [
+ np.array(3, dtype=np.int32),
+ np.array(1, dtype=np.int32)
+ ]
+ inputs_list = [
+ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
+ dtype=np.float32),
+ np.array([[4, 5, -3]],
+ dtype=np.float32),
+ ]
+ tag_indices_list = [
+ np.array([1, 2, 1, 0], dtype=np.int32),
+ np.array([1], dtype=np.int32)
+ ]
+ for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
+ inputs_list,
+ tag_indices_list):
+ with self.test_session() as sess:
+ sequence_score = crf.crf_sequence_score(
+ inputs=array_ops.expand_dims(inputs, 0),
+ tag_indices=array_ops.expand_dims(tag_indices, 0),
+ sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+ transition_params=constant_op.constant(transition_params))
+ sequence_score = array_ops.squeeze(sequence_score, [0])
+ tf_sequence_score = sess.run(sequence_score)
+ expected_unary_score = sum(inputs[i][tag_indices[i]]
+ for i in range(sequence_lengths))
+ expected_binary_score = sum(
+ transition_params[tag_indices[i], tag_indices[i + 1]]
+ for i in range(sequence_lengths - 1))
+ expected_sequence_score = expected_unary_score + expected_binary_score
+ self.assertAllClose(tf_sequence_score, expected_sequence_score)
def testCrfUnaryScore(self):
inputs = np.array(
@@ -89,38 +103,54 @@ class CrfTest(test.TestCase):
self.assertAllClose(tf_binary_score, expected_binary_score)
def testCrfLogNorm(self):
- inputs = np.array(
- [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
- num_words = inputs.shape[0]
- num_tags = inputs.shape[1]
- sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
- all_sequence_scores = []
-
- # Compare the dynamic program with brute force computation.
- for tag_indices in itertools.product(
- range(num_tags), repeat=sequence_lengths):
- tag_indices = list(tag_indices)
- tag_indices.extend([0] * (num_words - sequence_lengths))
- all_sequence_scores.append(
- crf.crf_sequence_score(
- inputs=array_ops.expand_dims(inputs, 0),
- tag_indices=array_ops.expand_dims(tag_indices, 0),
- sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
- transition_params=constant_op.constant(transition_params)))
-
- brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)
- log_norm = crf.crf_log_norm(
- inputs=array_ops.expand_dims(inputs, 0),
- sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
- transition_params=constant_op.constant(transition_params))
- log_norm = array_ops.squeeze(log_norm, [0])
- tf_brute_force_log_norm, tf_log_norm = sess.run(
- [brute_force_log_norm, log_norm])
+ # Test both the length-1 and regular cases.
+ sequence_lengths_list = [
+ np.array(3, dtype=np.int32),
+ np.array(1, dtype=np.int32)
+ ]
+ inputs_list = [
+ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
+ dtype=np.float32),
+ np.array([[3, -1, 3]],
+ dtype=np.float32),
+ ]
+ tag_indices_list = [
+ np.array([1, 2, 1, 0], dtype=np.int32),
+ np.array([2], dtype=np.int32)
+ ]
+
+ for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
+ inputs_list,
+ tag_indices_list):
+ num_words = inputs.shape[0]
+ num_tags = inputs.shape[1]
+ with self.test_session() as sess:
+ all_sequence_scores = []
+
+ # Compare the dynamic program with brute force computation.
+ for tag_indices in itertools.product(
+ range(num_tags), repeat=sequence_lengths):
+ tag_indices = list(tag_indices)
+ tag_indices.extend([0] * (num_words - sequence_lengths))
+ all_sequence_scores.append(
+ crf.crf_sequence_score(
+ inputs=array_ops.expand_dims(inputs, 0),
+ tag_indices=array_ops.expand_dims(tag_indices, 0),
+ sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+ transition_params=constant_op.constant(transition_params)))
+
+ brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)
+ log_norm = crf.crf_log_norm(
+ inputs=array_ops.expand_dims(inputs, 0),
+ sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+ transition_params=constant_op.constant(transition_params))
+ log_norm = array_ops.squeeze(log_norm, [0])
+ tf_brute_force_log_norm, tf_log_norm = sess.run(
+ [brute_force_log_norm, log_norm])
- self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
+ self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
def testCrfLogLikelihood(self):
inputs = np.array(
@@ -201,50 +231,66 @@ class CrfTest(test.TestCase):
expected_max_sequence[:sequence_lengths])
def testCrfDecode(self):
- inputs = np.array(
- [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
- sequence_lengths = np.array(3, dtype=np.int32)
- num_words = inputs.shape[0]
- num_tags = inputs.shape[1]
+ # Test both the length-1 and regular cases.
+ sequence_lengths_list = [
+ np.array(3, dtype=np.int32),
+ np.array(1, dtype=np.int32)
+ ]
+ inputs_list = [
+ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
+ dtype=np.float32),
+ np.array([[-1, 2, 1]],
+ dtype=np.float32),
+ ]
+ tag_indices_list = [
+ np.array([1, 2, 1, 0], dtype=np.int32),
+ np.array([2], dtype=np.int32)
+ ]
+
+ for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
+ inputs_list,
+ tag_indices_list):
+ num_words = inputs.shape[0]
+ num_tags = inputs.shape[1]
- with self.test_session() as sess:
- all_sequence_scores = []
- all_sequences = []
-
- # Compare the dynamic program with brute force computation.
- for tag_indices in itertools.product(
- range(num_tags), repeat=sequence_lengths):
- tag_indices = list(tag_indices)
- tag_indices.extend([0] * (num_words - sequence_lengths))
- all_sequences.append(tag_indices)
- sequence_score = crf.crf_sequence_score(
- inputs=array_ops.expand_dims(inputs, 0),
- tag_indices=array_ops.expand_dims(tag_indices, 0),
- sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
- transition_params=constant_op.constant(transition_params))
- sequence_score = array_ops.squeeze(sequence_score, [0])
- all_sequence_scores.append(sequence_score)
-
- tf_all_sequence_scores = sess.run(all_sequence_scores)
-
- expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
- expected_max_sequence = all_sequences[expected_max_sequence_index]
- expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
-
- actual_max_sequence, actual_max_score = crf.crf_decode(
- array_ops.expand_dims(inputs, 0),
- constant_op.constant(transition_params),
- array_ops.expand_dims(sequence_lengths, 0))
- actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
- actual_max_score = array_ops.squeeze(actual_max_score, [0])
- tf_actual_max_sequence, tf_actual_max_score = sess.run(
- [actual_max_sequence, actual_max_score])
-
- self.assertAllClose(tf_actual_max_score, expected_max_score)
- self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
- expected_max_sequence[:sequence_lengths])
+ with self.test_session() as sess:
+ all_sequence_scores = []
+ all_sequences = []
+
+ # Compare the dynamic program with brute force computation.
+ for tag_indices in itertools.product(
+ range(num_tags), repeat=sequence_lengths):
+ tag_indices = list(tag_indices)
+ tag_indices.extend([0] * (num_words - sequence_lengths))
+ all_sequences.append(tag_indices)
+ sequence_score = crf.crf_sequence_score(
+ inputs=array_ops.expand_dims(inputs, 0),
+ tag_indices=array_ops.expand_dims(tag_indices, 0),
+ sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
+ transition_params=constant_op.constant(transition_params))
+ sequence_score = array_ops.squeeze(sequence_score, [0])
+ all_sequence_scores.append(sequence_score)
+
+ tf_all_sequence_scores = sess.run(all_sequence_scores)
+
+ expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
+ expected_max_sequence = all_sequences[expected_max_sequence_index]
+ expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
+
+ actual_max_sequence, actual_max_score = crf.crf_decode(
+ array_ops.expand_dims(inputs, 0),
+ constant_op.constant(transition_params),
+ array_ops.expand_dims(sequence_lengths, 0))
+ actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
+ actual_max_score = array_ops.squeeze(actual_max_score, [0])
+ tf_actual_max_sequence, tf_actual_max_score = sess.run(
+ [actual_max_sequence, actual_max_score])
+
+ self.assertAllClose(tf_actual_max_score, expected_max_score)
+ self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
+ expected_max_sequence[:sequence_lengths])
if __name__ == "__main__":
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 4282be5ec8..ca384226d4 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -53,7 +53,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
+from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
@@ -101,12 +103,29 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
Returns:
sequence_scores: A [batch_size] vector of unnormalized sequence scores.
"""
- # Compute the scores of the given tag sequence.
- unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
- binary_scores = crf_binary_score(tag_indices, sequence_lengths,
- transition_params)
- sequence_scores = unary_scores + binary_scores
- return sequence_scores
+ # If max_seq_len is 1, we skip the score calculation and simply gather the
+ # unary potentials of the single tag.
+ def _single_seq_fn():
+ batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
+ example_inds = array_ops.reshape(
+ math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
+ return array_ops.gather_nd(
+ array_ops.squeeze(inputs, [1]),
+ array_ops.concat([example_inds, tag_indices], axis=1))
+
+ def _multi_seq_fn():
+ # Compute the scores of the given tag sequence.
+ unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
+ binary_scores = crf_binary_score(tag_indices, sequence_lengths,
+ transition_params)
+ sequence_scores = unary_scores + binary_scores
+ return sequence_scores
+
+ return utils.smart_cond(
+ pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
+ 1),
+ fn1=_single_seq_fn,
+ fn2=_multi_seq_fn)
def crf_log_norm(inputs, sequence_lengths, transition_params):
@@ -124,19 +143,32 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
# algorithm.
first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
first_input = array_ops.squeeze(first_input, [1])
- rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])
- # Compute the alpha values in the forward algorithm in order to get the
- # partition function.
- forward_cell = CrfForwardRnnCell(transition_params)
- _, alphas = rnn.dynamic_rnn(
- cell=forward_cell,
- inputs=rest_of_input,
- sequence_length=sequence_lengths - 1,
- initial_state=first_input,
- dtype=dtypes.float32)
- log_norm = math_ops.reduce_logsumexp(alphas, [1])
- return log_norm
+ # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over
+ # the "initial state" (the unary potentials).
+ def _single_seq_fn():
+ return math_ops.reduce_logsumexp(first_input, [1])
+
+ def _multi_seq_fn():
+ """Forward computation of alpha values."""
+ rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])
+
+ # Compute the alpha values in the forward algorithm in order to get the
+ # partition function.
+ forward_cell = CrfForwardRnnCell(transition_params)
+ _, alphas = rnn.dynamic_rnn(
+ cell=forward_cell,
+ inputs=rest_of_input,
+ sequence_length=sequence_lengths - 1,
+ initial_state=first_input,
+ dtype=dtypes.float32)
+ log_norm = math_ops.reduce_logsumexp(alphas, [1])
+ return log_norm
+
+ max_seq_len = array_ops.shape(inputs)[1]
+ return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
def crf_log_likelihood(inputs,
@@ -440,41 +472,60 @@ def crf_decode(potentials, transition_params, sequence_length):
Contains the highest scoring tag indices.
best_score: A [batch_size] tensor, containing the score of decode_tags.
"""
- # For simplicity, in shape comments, denote:
- # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
- num_tags = potentials.get_shape()[2].value
-
- # Computes forward decoding. Get last score and backpointers.
- crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
- initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
- initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
- inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
- backpointers, last_score = rnn.dynamic_rnn(
- crf_fwd_cell,
- inputs=inputs,
- sequence_length=sequence_length - 1,
- initial_state=initial_state,
- time_major=False,
- dtype=dtypes.int32) # [B, T - 1, O], [B, O]
- backpointers = gen_array_ops.reverse_sequence(
- backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O]
-
- # Computes backward decoding. Extract tag indices from backpointers.
- crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
- initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
- dtype=dtypes.int32) # [B]
- initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
- decode_tags, _ = rnn.dynamic_rnn(
- crf_bwd_cell,
- inputs=backpointers,
- sequence_length=sequence_length - 1,
- initial_state=initial_state,
- time_major=False,
- dtype=dtypes.int32) # [B, T - 1, 1]
- decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
- decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T]
- decode_tags = gen_array_ops.reverse_sequence(
- decode_tags, sequence_length, seq_dim=1) # [B, T]
-
- best_score = math_ops.reduce_max(last_score, axis=1) # [B]
- return decode_tags, best_score
+ # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag
+ # and the max activation.
+ def _single_seq_fn():
+ squeezed_potentials = array_ops.squeeze(potentials, [1])
+ decode_tags = array_ops.expand_dims(
+ math_ops.argmax(squeezed_potentials, axis=1), 1)
+ best_score = math_ops.reduce_max(squeezed_potentials, axis=1)
+ return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score
+
+ def _multi_seq_fn():
+ """Decoding of highest scoring sequence."""
+
+ # For simplicity, in shape comments, denote:
+ # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
+ num_tags = potentials.get_shape()[2].value
+
+ # Computes forward decoding. Get last score and backpointers.
+ crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
+ initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
+ initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
+ inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
+ backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O]
+ crf_fwd_cell,
+ inputs=inputs,
+ sequence_length=sequence_length - 1,
+ initial_state=initial_state,
+ time_major=False,
+ dtype=dtypes.int32)
+ backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O]
+ backpointers, sequence_length - 1, seq_dim=1)
+
+ # Computes backward decoding. Extract tag indices from backpointers.
+ crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
+ initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B]
+ dtype=dtypes.int32)
+ initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
+ decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1]
+ crf_bwd_cell,
+ inputs=backpointers,
+ sequence_length=sequence_length - 1,
+ initial_state=initial_state,
+ time_major=False,
+ dtype=dtypes.int32)
+ decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
+ decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T]
+ axis=1)
+ decode_tags = gen_array_ops.reverse_sequence( # [B, T]
+ decode_tags, sequence_length, seq_dim=1)
+
+ best_score = math_ops.reduce_max(last_score, axis=1) # [B]
+ return decode_tags, best_score
+
+ return utils.smart_cond(
+ pred=math_ops.equal(
+ potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
+ fn1=_single_seq_fn,
+ fn2=_multi_seq_fn)