aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf/python/ops/crf.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/crf/python/ops/crf.py')
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py171
1 files changed, 168 insertions, 3 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index a19c70717a..7166e38b28 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -16,13 +16,24 @@
The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most
-likely sequence at test time:
+likely sequence at test time. There are two ways to do decoding: one
+is using crf_decode to do decoding in TensorFlow, and the other is using
+viterbi_decode in NumPy.
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
unary_scores, gold_tags, sequence_lengths)
+
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
+# Decoding in TensorFlow.
+viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
+ unary_scores, transition_params, sequence_lengths)
+
+tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
+ [viterbi_sequence, viterbi_score, train_op])
+
+# Decoding in NumPy.
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
[unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
@@ -31,7 +42,7 @@ for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
# Compute the highest score and its tag sequence.
-viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
+tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
tf_unary_scores_, tf_transition_params)
"""
@@ -43,6 +54,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
@@ -50,7 +62,9 @@ from tensorflow.python.ops import variable_scope as vs
__all__ = [
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
- "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode"
+ "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
+ "viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
+ "CrfDecodeBackwardRnnCell"
]
@@ -310,3 +324,154 @@ def viterbi_decode(score, transition_params):
viterbi_score = np.max(trellis[-1])
return viterbi, viterbi_score
+
+
+class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
+ """Computes the forward decoding in a linear-chain CRF.
+ """
+
+ def __init__(self, transition_params):
+ """Initialize the CrfDecodeForwardRnnCell.
+
+ Args:
+ transition_params: A [num_tags, num_tags] matrix of binary
+ potentials. This matrix is expanded into a
+ [1, num_tags, num_tags] in preparation for the broadcast
+ summation occurring within the cell.
+ """
+ self._transition_params = array_ops.expand_dims(transition_params, 0)
+ self._num_tags = transition_params.get_shape()[0].value
+
+ @property
+ def state_size(self):
+ return self._num_tags
+
+ @property
+ def output_size(self):
+ return self._num_tags
+
+ def __call__(self, inputs, state, scope=None):
+ """Build the CrfDecodeForwardRnnCell.
+
+ Args:
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
+ state: A [batch_size, num_tags] matrix containing the previous step's
+ score values.
+ scope: Unused variable scope of this cell.
+
+ Returns:
+ backpointers: [batch_size, num_tags], containing backpointers.
+ new_state: [batch_size, num_tags], containing new score values.
+ """
+ # For simplicity, in shape comments, denote:
+    # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
+ state = array_ops.expand_dims(state, 2) # [B, O, 1]
+
+    # This addition op broadcasts self._transition_params along the zeroth
+    # dimension and state along the second dimension.
+ # [B, O, 1] + [1, O, O] -> [B, O, O]
+ transition_scores = state + self._transition_params # [B, O, O]
+ new_state = inputs + math_ops.reduce_max(transition_scores, [1]) # [B, O]
+ backpointers = math_ops.argmax(transition_scores, 1)
+ backpointers = math_ops.cast(backpointers, dtype=dtypes.int32) # [B, O]
+ return backpointers, new_state
+
+
+class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
+ """Computes backward decoding in a linear-chain CRF.
+ """
+
+ def __init__(self, num_tags):
+ """Initialize the CrfDecodeBackwardRnnCell.
+
+ Args:
+      num_tags: An integer. The number of tags.
+ """
+ self._num_tags = num_tags
+
+ @property
+ def state_size(self):
+ return 1
+
+ @property
+ def output_size(self):
+ return 1
+
+ def __call__(self, inputs, state, scope=None):
+ """Build the CrfDecodeBackwardRnnCell.
+
+ Args:
+ inputs: [batch_size, num_tags], backpointer of next step (in time order).
+ state: [batch_size, 1], next position's tag index.
+ scope: Unused variable scope of this cell.
+
+ Returns:
+      new_tags, new_tags: A pair of [batch_size, 1]
+        tensors containing the new tag indices.
+ """
+ state = array_ops.squeeze(state, axis=[1]) # [B]
+ batch_size = array_ops.shape(inputs)[0]
+ b_indices = math_ops.range(batch_size) # [B]
+ indices = array_ops.stack([b_indices, state], axis=1) # [B, 2]
+ new_tags = array_ops.expand_dims(
+ gen_array_ops.gather_nd(inputs, indices), # [B]
+ axis=-1) # [B, 1]
+
+ return new_tags, new_tags
+
+
+def crf_decode(potentials, transition_params, sequence_length):
+ """Decode the highest scoring sequence of tags in TensorFlow.
+
+  This is the tensor-based counterpart of viterbi_decode, which operates on NumPy arrays.
+
+ Args:
+ potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
+ unary potentials.
+ transition_params: A [num_tags, num_tags] tensor, matrix of
+ binary potentials.
+ sequence_length: A [batch_size] tensor, containing sequence lengths.
+
+ Returns:
+ decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
+      Contains the highest scoring tag indices.
+ best_score: A [batch_size] tensor, containing the score of decode_tags.
+ """
+ # For simplicity, in shape comments, denote:
+  # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
+ num_tags = potentials.get_shape()[2].value
+
+ # Computes forward decoding. Get last score and backpointers.
+ crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
+ initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
+ initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
+ inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
+ backpointers, last_score = rnn.dynamic_rnn(
+ crf_fwd_cell,
+ inputs=inputs,
+ sequence_length=sequence_length - 1,
+ initial_state=initial_state,
+ time_major=False,
+ dtype=dtypes.int32) # [B, T - 1, O], [B, O]
+ backpointers = gen_array_ops.reverse_sequence(
+ backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O]
+
+ # Computes backward decoding. Extract tag indices from backpointers.
+ crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
+ initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
+ dtype=dtypes.int32) # [B]
+ initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
+ decode_tags, _ = rnn.dynamic_rnn(
+ crf_bwd_cell,
+ inputs=backpointers,
+ sequence_length=sequence_length - 1,
+ initial_state=initial_state,
+ time_major=False,
+ dtype=dtypes.int32) # [B, T - 1, 1]
+ decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
+ decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T]
+ decode_tags = gen_array_ops.reverse_sequence(
+ decode_tags, sequence_length, seq_dim=1) # [B, T]
+
+ best_score = math_ops.reduce_max(last_score, axis=1) # [B]
+ return decode_tags, best_score