path: root/tensorflow/contrib/crf
author     A. Unique TensorFlower <gardener@tensorflow.org>   2016-09-22 12:34:47 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2016-09-22 13:49:33 -0700
commit   34f7db27b8e75009cb2f131ca5e1bf7f8dc88cf2 (patch)
tree     90be260ac962b31f728fc7292c2d3ac669dfddf1 /tensorflow/contrib/crf
parent   dffe0fa45bcca5ecd2e0513193921deb5fb5eedf (diff)
Adds a linear-chain CRF layer to tensorflow/contrib.
This CRF is factored into unary scores for every tag, produced by the layer below, and transition scores between the tags of adjacent words. For any tagging task over word sequences, it can simply replace the softmax layer. Optimizing the negative of tf.contrib.crf.crf_log_likelihood maximizes the probability of the given tag sequence. The module also contains a Viterbi decoder that can be used outside of TensorFlow to extract the highest-scoring tag sequence at test time.

Change: 133993074
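A minimal sketch of the training pattern described above, mirroring the module docstring in crf.py below. The tensor names and shapes here are illustrative, not part of this commit; any network producing per-word tag scores can feed `unary_scores`.

```python
import numpy as np
import tensorflow as tf

# Illustrative shapes and random data (not part of this commit).
batch_size, max_len, num_tags = 8, 10, 5
unary_scores = tf.constant(
    np.random.rand(batch_size, max_len, num_tags).astype(np.float32))
gold_tags = tf.constant(
    np.random.randint(num_tags, size=[batch_size, max_len]).astype(np.int32))
sequence_lengths = tf.constant(np.full(batch_size, max_len, dtype=np.int32))

# Training: minimizing the negative log-likelihood maximizes P(gold_tags).
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
    unary_scores, gold_tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  np_scores, np_transitions, _ = session.run(
      [unary_scores, transition_params, train_op])

# Test time: decode outside of TensorFlow with the learned transition matrix.
viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
    np_scores[0], np_transitions)
```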
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--  tensorflow/contrib/crf/BUILD                             40
-rw-r--r--  tensorflow/contrib/crf/README.md                         76
-rw-r--r--  tensorflow/contrib/crf/__init__.py                       39
-rw-r--r--  tensorflow/contrib/crf/python/__init__.py                18
-rw-r--r--  tensorflow/contrib/crf/python/kernel_tests/crf_test.py  200
-rw-r--r--  tensorflow/contrib/crf/python/ops/__init__.py            18
-rw-r--r--  tensorflow/contrib/crf/python/ops/crf.py                311
7 files changed, 702 insertions, 0 deletions
diff --git a/tensorflow/contrib/crf/BUILD b/tensorflow/contrib/crf/BUILD
new file mode 100644
index 0000000000..33c1323b48
--- /dev/null
+++ b/tensorflow/contrib/crf/BUILD
@@ -0,0 +1,40 @@
+# Description:
+# Contains classes to construct a CRF layer
+# APIs here are meant to evolve over time.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+
+py_library(
+ name = "crf_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ srcs_version = "PY2AND3",
+)
+
+cuda_py_tests(
+ name = "crf_test",
+ srcs = ["python/kernel_tests/crf_test.py"],
+ additional_deps = [
+ ":crf_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/crf/README.md b/tensorflow/contrib/crf/README.md
new file mode 100644
index 0000000000..68d1101ecd
--- /dev/null
+++ b/tensorflow/contrib/crf/README.md
@@ -0,0 +1,76 @@
+# CRF
+
+The CRF module implements a linear-chain CRF layer for learning to predict tag sequences. This variant of the CRF is factored into unary potentials for every element in the sequence and binary potentials for every transition between output tags.
+
+### Usage
+
+Below is an example of the API, which trains a CRF on some random data. The linear layer in the example can be replaced by any neural network.
+
+
+```python
+import numpy as np
+import tensorflow as tf
+
+# Data settings.
+num_examples = 10
+num_words = 20
+num_features = 100
+num_tags = 5
+
+# Random features.
+x = np.random.rand(num_examples, num_words, num_features).astype(np.float32)
+
+# Random tag indices representing the gold sequence.
+y = np.random.randint(num_tags, size=[num_examples, num_words]).astype(np.int32)
+
+# All sequences in this example have the same length, but they can be variable in a real model.
+sequence_lengths = np.full(num_examples, num_words - 1, dtype=np.int32)
+
+# Train and evaluate the model.
+with tf.Graph().as_default():
+ with tf.Session() as session:
+ # Add the data to the TensorFlow graph.
+ x_t = tf.constant(x)
+ y_t = tf.constant(y)
+ sequence_lengths_t = tf.constant(sequence_lengths)
+
+ # Compute unary scores from a linear layer.
+ weights = tf.get_variable("weights", [num_features, num_tags])
+ matricized_x_t = tf.reshape(x_t, [-1, num_features])
+ matricized_unary_scores = tf.batch_matmul(matricized_x_t, weights)
+ unary_scores = tf.reshape(matricized_unary_scores,
+ [num_examples, num_words, num_tags])
+
+ # Compute the log-likelihood of the gold sequences and keep the transition
+ # params for inference at test time.
+ log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
+ unary_scores, y_t, sequence_lengths_t)
+
+ # Add a training op to tune the parameters.
+ loss = tf.reduce_mean(-log_likelihood)
+ train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
+
+ # Train for a fixed number of iterations.
+ session.run(tf.initialize_all_variables())
+ for i in range(1000):
+ tf_unary_scores, tf_transition_params, _ = session.run(
+ [unary_scores, transition_params, train_op])
+ if i % 100 == 0:
+ correct_labels = 0
+ total_labels = 0
+ for tf_unary_scores_, y_, sequence_length_ in zip(tf_unary_scores, y,
+ sequence_lengths):
+ # Remove padding from the scores and tag sequence.
+ tf_unary_scores_ = tf_unary_scores_[:sequence_length_]
+ y_ = y_[:sequence_length_]
+
+ # Compute the highest scoring sequence.
+ viterbi_sequence, _ = tf.contrib.crf.viterbi_decode(
+ tf_unary_scores_, tf_transition_params)
+
+ # Evaluate word-level accuracy.
+ correct_labels += np.sum(np.equal(viterbi_sequence, y_))
+ total_labels += sequence_length_
+ accuracy = 100.0 * correct_labels / float(total_labels)
+ print("Accuracy: %.2f%%" % accuracy)
+```
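The example above fixes every sequence to the same length; in a real model, as the comment in the example notes, lengths vary. Below is a minimal NumPy sketch of the usual preprocessing: pad a ragged batch to the longest sequence and record each true length, which the CRF ops use to mask the padded positions. The names and shapes are illustrative, not part of this module.

```python
import numpy as np

num_features = 100

# Hypothetical ragged batch: one [num_words_i, num_features] array per example.
ragged_batch = [np.random.rand(n, num_features).astype(np.float32)
                for n in (3, 7, 5)]

sequence_lengths = np.array([ex.shape[0] for ex in ragged_batch],
                            dtype=np.int32)
max_len = int(sequence_lengths.max())

# Zero-pad to a dense [batch, max_len, num_features] array; crf_log_likelihood
# ignores the padded positions based on sequence_lengths.
x = np.zeros((len(ragged_batch), max_len, num_features), dtype=np.float32)
for i, ex in enumerate(ragged_batch):
    x[i, :ex.shape[0], :] = ex
```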
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py
new file mode 100644
index 0000000000..195e8cd717
--- /dev/null
+++ b/tensorflow/contrib/crf/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linear-chain CRF layer.
+
+## This package provides functions for building a linear-chain CRF layer.
+
+@@crf_sequence_score
+@@crf_log_norm
+@@crf_log_likelihood
+@@crf_unary_score
+@@crf_binary_score
+@@CrfForwardRnnCell
+@@viterbi_decode
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks
+from tensorflow.contrib.crf.python.ops.crf import crf_binary_score
+from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood
+from tensorflow.contrib.crf.python.ops.crf import crf_log_norm
+from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
+from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
+from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
+from tensorflow.contrib.crf.python.ops.crf import viterbi_decode
diff --git a/tensorflow/contrib/crf/python/__init__.py b/tensorflow/contrib/crf/python/__init__.py
new file mode 100644
index 0000000000..8439848dd0
--- /dev/null
+++ b/tensorflow/contrib/crf/python/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linear-chain CRF."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
new file mode 100644
index 0000000000..539cabe620
--- /dev/null
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -0,0 +1,200 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for CRF."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+import tensorflow as tf
+
+
+class CrfTest(tf.test.TestCase):
+
+ def testCrfSequenceScore(self):
+ inputs = np.array(
+ [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
+ tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ sequence_lengths = np.array(3, dtype=np.int32)
+ with self.test_session() as sess:
+ sequence_score = tf.contrib.crf.crf_sequence_score(
+ inputs=tf.expand_dims(inputs, 0),
+ tag_indices=tf.expand_dims(tag_indices, 0),
+ sequence_lengths=tf.expand_dims(sequence_lengths, 0),
+ transition_params=tf.constant(transition_params))
+ sequence_score = tf.squeeze(sequence_score, [0])
+ tf_sequence_score = sess.run(sequence_score)
+ expected_unary_score = sum(inputs[i][tag_indices[i]]
+ for i in range(sequence_lengths))
+ expected_binary_score = sum(
+ transition_params[tag_indices[i], tag_indices[i + 1]]
+ for i in range(sequence_lengths - 1))
+ expected_sequence_score = expected_unary_score + expected_binary_score
+ self.assertAllClose(tf_sequence_score, expected_sequence_score)
+
+ def testCrfUnaryScore(self):
+ inputs = np.array(
+ [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
+ tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
+ sequence_lengths = np.array(3, dtype=np.int32)
+ with self.test_session() as sess:
+ unary_score = tf.contrib.crf.crf_unary_score(
+ tag_indices=tf.expand_dims(tag_indices, 0),
+ sequence_lengths=tf.expand_dims(sequence_lengths, 0),
+ inputs=tf.expand_dims(inputs, 0))
+ unary_score = tf.squeeze(unary_score, [0])
+ tf_unary_score = sess.run(unary_score)
+ expected_unary_score = sum(inputs[i][tag_indices[i]]
+ for i in range(sequence_lengths))
+ self.assertAllClose(tf_unary_score, expected_unary_score)
+
+ def testCrfBinaryScore(self):
+ tag_indices = np.array([1, 2, 1, 0], dtype=np.int32)
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ sequence_lengths = np.array(3, dtype=np.int32)
+ with self.test_session() as sess:
+ binary_score = tf.contrib.crf.crf_binary_score(
+ tag_indices=tf.expand_dims(tag_indices, 0),
+ sequence_lengths=tf.expand_dims(sequence_lengths, 0),
+ transition_params=tf.constant(transition_params))
+ binary_score = tf.squeeze(binary_score, [0])
+ tf_binary_score = sess.run(binary_score)
+ expected_binary_score = sum(
+ transition_params[tag_indices[i], tag_indices[i + 1]]
+ for i in range(sequence_lengths - 1))
+ self.assertAllClose(tf_binary_score, expected_binary_score)
+
+ def testCrfLogNorm(self):
+ inputs = np.array(
+ [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ num_words = inputs.shape[0]
+ num_tags = inputs.shape[1]
+ sequence_lengths = np.array(3, dtype=np.int32)
+ with self.test_session() as sess:
+ all_sequence_scores = []
+
+ # Compare the dynamic program with brute force computation.
+ for tag_indices in itertools.product(
+ range(num_tags), repeat=sequence_lengths):
+ tag_indices = list(tag_indices)
+ tag_indices.extend([0] * (num_words - sequence_lengths))
+ all_sequence_scores.append(
+ tf.contrib.crf.crf_sequence_score(
+ inputs=tf.expand_dims(inputs, 0),
+ tag_indices=tf.expand_dims(tag_indices, 0),
+ sequence_lengths=tf.expand_dims(sequence_lengths, 0),
+ transition_params=tf.constant(transition_params)))
+
+ brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores)
+ log_norm = tf.contrib.crf.crf_log_norm(
+ inputs=tf.expand_dims(inputs, 0),
+ sequence_lengths=tf.expand_dims(sequence_lengths, 0),
+ transition_params=tf.constant(transition_params))
+ log_norm = tf.squeeze(log_norm, [0])
+ tf_brute_force_log_norm, tf_log_norm = sess.run(
+ [brute_force_log_norm, log_norm])
+
+ self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
+
+ def testCrfLogLikelihood(self):
+ inputs = np.array(
+ [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ sequence_lengths = np.array(3, dtype=np.int32)
+ num_words = inputs.shape[0]
+ num_tags = inputs.shape[1]
+ with self.test_session() as sess:
+ all_sequence_log_likelihoods = []
+
+ # Make sure all probabilities sum to 1.
+ for tag_indices in itertools.product(
+ range(num_tags), repeat=sequence_lengths):
+ tag_indices = list(tag_indices)
+ tag_indices.extend([0] * (num_words - sequence_lengths))
+ sequence_log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
+ inputs=tf.expand_dims(inputs, 0),
+ tag_indices=tf.expand_dims(tag_indices, 0),
+ sequence_lengths=tf.expand_dims(sequence_lengths, 0),
+ transition_params=tf.constant(transition_params))
+ all_sequence_log_likelihoods.append(sequence_log_likelihood)
+ total_log_likelihood = tf.reduce_logsumexp(all_sequence_log_likelihoods)
+ tf_total_log_likelihood = sess.run(total_log_likelihood)
+ self.assertAllClose(tf_total_log_likelihood, 0.0)
+
+ def testLengthsToMasks(self):
+ with self.test_session() as sess:
+ sequence_lengths = [4, 1, 8, 2]
+ max_sequence_length = max(sequence_lengths)
+
+ mask = tf.contrib.crf._lengths_to_masks(sequence_lengths,
+ max_sequence_length)
+ tf_mask = sess.run(mask)
+ self.assertEqual(len(tf_mask), len(sequence_lengths))
+ for m, l in zip(tf_mask, sequence_lengths):
+ self.assertAllEqual(m[:l], [1] * l)
+ self.assertAllEqual(m[l:], [0] * (len(m) - l))
+
+ def testViterbiDecode(self):
+ inputs = np.array(
+ [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
+ transition_params = np.array(
+ [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
+ sequence_lengths = np.array(3, dtype=np.int32)
+ num_words = inputs.shape[0]
+ num_tags = inputs.shape[1]
+
+ with self.test_session() as sess:
+ all_sequence_scores = []
+ all_sequences = []
+
+ # Compare the dynamic program with brute force computation.
+ for tag_indices in itertools.product(
+ range(num_tags), repeat=sequence_lengths):
+ tag_indices = list(tag_indices)
+ tag_indices.extend([0] * (num_words - sequence_lengths))
+ all_sequences.append(tag_indices)
+ sequence_score = tf.contrib.crf.crf_sequence_score(
+ inputs=tf.expand_dims(inputs, 0),
+ tag_indices=tf.expand_dims(tag_indices, 0),
+ sequence_lengths=tf.expand_dims(sequence_lengths, 0),
+ transition_params=tf.constant(transition_params))
+ sequence_score = tf.squeeze(sequence_score, [0])
+ all_sequence_scores.append(sequence_score)
+
+ tf_all_sequence_scores = sess.run(all_sequence_scores)
+
+ expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
+ expected_max_sequence = all_sequences[expected_max_sequence_index]
+ expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
+
+ actual_max_sequence, actual_max_score = tf.contrib.crf.viterbi_decode(
+ inputs[:sequence_lengths], transition_params)
+
+ self.assertAllClose(actual_max_score, expected_max_score)
+ self.assertEqual(actual_max_sequence,
+ expected_max_sequence[:sequence_lengths])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/crf/python/ops/__init__.py b/tensorflow/contrib/crf/python/ops/__init__.py
new file mode 100644
index 0000000000..5ab8d7ac4a
--- /dev/null
+++ b/tensorflow/contrib/crf/python/ops/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for building a linear-chain CRF layer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
new file mode 100644
index 0000000000..fbbbc2d5c1
--- /dev/null
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -0,0 +1,311 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module for constructing a linear-chain CRF.
+
+The following snippet is an example of a CRF layer on top of a batched sequence
+of unary scores (logits for every word). This example also decodes the most
+likely sequence at test time:
+
+log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
+ unary_scores, gold_tags, sequence_lengths)
+loss = tf.reduce_mean(-log_likelihood)
+train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
+
+tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
+ [unary_scores, sequence_lengths, transition_params, train_op])
+for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
+ tf_sequence_lengths):
+  # Remove padding.
+  tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
+
+  # Compute the highest score and its tag sequence.
+  viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
+      tf_unary_scores_, tf_transition_params)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import variable_scope as vs
+
+__all__ = ["crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
+ "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
+ "viterbi_decode"]
+
+
+def _lengths_to_masks(lengths, max_length):
+ """Creates a binary matrix that can be used to mask away padding.
+
+ Args:
+ lengths: A vector of integers representing lengths.
+ max_length: An integer indicating the maximum length. All values in
+      lengths should be less than or equal to max_length.
+ Returns:
+ masks: Masks that can be used to get rid of padding.
+ """
+ tiled_ranges = array_ops.tile(
+ array_ops.expand_dims(math_ops.range(max_length), 0),
+ [array_ops.shape(lengths)[0], 1])
+ lengths = array_ops.expand_dims(lengths, 1)
+ masks = math_ops.to_float(
+ math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths))
+ return masks
+
+
+def crf_sequence_score(inputs, tag_indices, sequence_lengths,
+ transition_params):
+ """Computes the unnormalized score for a tag sequence.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
+ compute the unnormalized score.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params: A [num_tags, num_tags] transition matrix.
+ Returns:
+ sequence_scores: A [batch_size] vector of unnormalized sequence scores.
+ """
+ # Compute the scores of the given tag sequence.
+ unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
+ binary_scores = crf_binary_score(tag_indices, sequence_lengths,
+ transition_params)
+ sequence_scores = unary_scores + binary_scores
+ return sequence_scores
+
+
+def crf_log_norm(inputs, sequence_lengths, transition_params):
+ """Computes the normalization for a CRF.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params: A [num_tags, num_tags] transition matrix.
+ Returns:
+ log_norm: A [batch_size] vector of normalizers for a CRF.
+ """
+ # Split up the first and rest of the inputs in preparation for the forward
+ # algorithm.
+ first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
+ first_input = array_ops.squeeze(first_input, [1])
+ rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])
+
+ # Compute the alpha values in the forward algorithm in order to get the
+ # partition function.
+ forward_cell = CrfForwardRnnCell(transition_params)
+ _, alphas = rnn.dynamic_rnn(
+ cell=forward_cell,
+ inputs=rest_of_input,
+ sequence_length=sequence_lengths - 1,
+ initial_state=first_input,
+ dtype=dtypes.float32)
+ log_norm = math_ops.reduce_logsumexp(alphas, [1])
+ return log_norm
+
+
+def crf_log_likelihood(inputs,
+ tag_indices,
+ sequence_lengths,
+ transition_params=None):
+ """Computes the log-likehood of tag sequences in a CRF.
+
+ Args:
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
+ to use as input to the CRF layer.
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
+      compute the log-likelihood.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params: A [num_tags, num_tags] transition matrix, if available.
+ Returns:
+    log_likelihood: A [batch_size] vector containing the log-likelihood of each
+      given sequence of tag indices.
+ transition_params: A [num_tags, num_tags] transition matrix. This is either
+ provided by the caller or created in this function.
+ """
+ # Get shape information.
+ num_tags = inputs.get_shape()[2].value
+
+ # Get the transition matrix if not provided.
+ if transition_params is None:
+ transition_params = vs.get_variable("transitions", [num_tags, num_tags])
+
+ sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
+ transition_params)
+ log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
+
+ # Normalize the scores to get the log-likelihood.
+ log_likelihood = sequence_scores - log_norm
+ return log_likelihood, transition_params
+
+
+def crf_unary_score(tag_indices, sequence_lengths, inputs):
+ """Computes the unary scores of tag sequences.
+
+ Args:
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials.
+ Returns:
+ unary_scores: A [batch_size] vector of unary scores.
+ """
+ batch_size = array_ops.shape(inputs)[0]
+ max_seq_len = array_ops.shape(inputs)[1]
+ num_tags = array_ops.shape(inputs)[2]
+
+ flattened_inputs = array_ops.reshape(inputs, [-1])
+
+ offsets = array_ops.expand_dims(
+ math_ops.range(batch_size) * max_seq_len * num_tags, 1)
+ offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0)
+ flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1])
+
+ unary_scores = array_ops.reshape(
+ array_ops.gather(flattened_inputs, flattened_tag_indices),
+ [batch_size, max_seq_len])
+
+ masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
+
+ unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
+ return unary_scores
+
+
+def crf_binary_score(tag_indices, sequence_lengths, transition_params):
+ """Computes the binary scores of tag sequences.
+
+ Args:
+ tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
+ sequence_lengths: A [batch_size] vector of true sequence lengths.
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
+ Returns:
+ binary_scores: A [batch_size] vector of binary scores.
+ """
+ # Get shape information.
+ num_tags = transition_params.get_shape()[0]
+ num_transitions = array_ops.shape(tag_indices)[1] - 1
+
+ # Truncate by one on each side of the sequence to get the start and end
+ # indices of each transition.
+ start_tag_indices = array_ops.slice(tag_indices, [0, 0],
+ [-1, num_transitions])
+ end_tag_indices = array_ops.slice(tag_indices, [0, 1], [-1, num_transitions])
+
+ # Encode the indices in a flattened representation.
+ flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
+ flattened_transition_params = array_ops.reshape(transition_params, [-1])
+
+ # Get the binary scores based on the flattened representation.
+ binary_scores = array_ops.gather(flattened_transition_params,
+ flattened_transition_indices)
+
+ masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
+ truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
+ binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
+ return binary_scores
+
+
+class CrfForwardRnnCell(rnn_cell.RNNCell):
+ """Computes the alpha values in a linear-chain CRF.
+
+ See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
+ """
+
+ def __init__(self, transition_params):
+ """Initialize the CrfForwardRnnCell.
+
+ Args:
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
+        This matrix is expanded into a [1, num_tags, num_tags] tensor in
+        preparation for the broadcast summation occurring within the cell.
+ """
+ self._transition_params = array_ops.expand_dims(transition_params, 0)
+ self._num_tags = transition_params.get_shape()[0].value
+
+ @property
+ def state_size(self):
+ return self._num_tags
+
+ @property
+ def output_size(self):
+ return self._num_tags
+
+ def __call__(self, inputs, state, scope=None):
+ """Build the CrfForwardRnnCell.
+
+ Args:
+ inputs: A [batch_size, num_tags] matrix of unary potentials.
+ state: A [batch_size, num_tags] matrix containing the previous alpha
+ values.
+ scope: Unused variable scope of this cell.
+
+ Returns:
+      new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
+        containing the new alpha values.
+ """
+ state = array_ops.expand_dims(state, 2)
+
+    # This addition op broadcasts self._transition_params along the zeroth
+ # dimension and state along the second dimension. This performs the
+ # multiplication of previous alpha values and the current binary potentials
+ # in log space.
+ transition_scores = state + self._transition_params
+ new_alphas = inputs + math_ops.reduce_logsumexp(transition_scores, [1])
+
+    # Both the state and the output of this RNN cell contain the alpha values.
+ # The output value is currently unused and simply satisfies the RNN API.
+ # This could be useful in the future if we need to compute marginal
+ # probabilities, which would require the accumulated alpha values at every
+ # time step.
+ return new_alphas, new_alphas
+
+
+def viterbi_decode(score, transition_params):
+ """Decode the highest scoring sequence of tags outside of TensorFlow.
+
+ This should only be used at test time.
+
+ Args:
+ score: A [seq_len, num_tags] matrix of unary potentials.
+ transition_params: A [num_tags, num_tags] matrix of binary potentials.
+
+ Returns:
+ viterbi: A [seq_len] list of integers containing the highest scoring tag
+      indices.
+ viterbi_score: A float containing the score for the viterbi sequence.
+ """
+ trellis = np.zeros_like(score)
+ backpointers = np.zeros_like(score, dtype=np.int32)
+ trellis[0] = score[0]
+
+ for t in range(1, score.shape[0]):
+ v = np.expand_dims(trellis[t - 1], 1) + transition_params
+ trellis[t] = score[t] + np.max(v, 0)
+ backpointers[t] = np.argmax(v, 0)
+
+ viterbi = [np.argmax(trellis[-1])]
+ for bp in reversed(backpointers[1:]):
+ viterbi.append(bp[viterbi[-1]])
+ viterbi.reverse()
+
+ viterbi_score = np.max(trellis[-1])
+ return viterbi, viterbi_score