aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-22 15:50:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-22 15:54:28 -0800
commitd4091eec522e41093e6e10601af79c75bee14c80 (patch)
treee7f2024d99f30d3cbe9616ad558a6a04c2cf1aa0 /tensorflow/contrib/crf
parenta64485dbb378d7ac6afc9082fd7176a957815a8c (diff)
Replaces custom _lengths_to_masks function with the official, more efficient sequence_mask function that supersedes it.
PiperOrigin-RevId: 179971521
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r--tensorflow/contrib/crf/__init__.py17
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py11
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py27
3 files changed, 14 insertions, 41 deletions
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py
index bc749339bd..046c509626 100644
--- a/tensorflow/contrib/crf/__init__.py
+++ b/tensorflow/contrib/crf/__init__.py
@@ -16,15 +16,15 @@
See the @{$python/contrib.crf} guide.
-@@crf_sequence_score
-@@crf_log_norm
-@@crf_log_likelihood
-@@crf_unary_score
@@crf_binary_score
@@crf_decode
-@@CrfForwardRnnCell
-@@CrfDecodeForwardRnnCell
+@@crf_log_likelihood
+@@crf_log_norm
+@@crf_sequence_score
+@@crf_unary_score
@@CrfDecodeBackwardRnnCell
+@@CrfDecodeForwardRnnCell
+@@CrfForwardRnnCell
@@viterbi_decode
"""
@@ -32,16 +32,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks
from tensorflow.contrib.crf.python.ops.crf import crf_binary_score
from tensorflow.contrib.crf.python.ops.crf import crf_decode
from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood
from tensorflow.contrib.crf.python.ops.crf import crf_log_norm
from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
-from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
-from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell
+from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell
+from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import viterbi_decode
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index b47fb426a1..721dc4d080 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -179,17 +179,6 @@ class CrfTest(test.TestCase):
tf_total_log_likelihood = sess.run(total_log_likelihood)
self.assertAllClose(tf_total_log_likelihood, 0.0)
- def testLengthsToMasks(self):
- with self.test_session() as sess:
- sequence_lengths = [4, 1, 8, 2]
- max_sequence_length = max(sequence_lengths)
- mask = crf._lengths_to_masks(sequence_lengths, max_sequence_length)
- tf_mask = sess.run(mask)
- self.assertEqual(len(tf_mask), len(sequence_lengths))
- for m, l in zip(tf_mask, sequence_lengths):
- self.assertAllEqual(m[:l], [1] * l)
- self.assertAllEqual(m[l:], [0] * (len(m) - l))
-
def testViterbiDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 7f5ae937b2..62708636c6 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -70,25 +70,6 @@ __all__ = [
]
-def _lengths_to_masks(lengths, max_length):
- """Creates a binary matrix that can be used to mask away padding.
-
- Args:
- lengths: A vector of integers representing lengths.
- max_length: An integer indicating the maximum length. All values in
- lengths should be less than max_length.
- Returns:
- masks: Masks that can be used to get rid of padding.
- """
- tiled_ranges = array_ops.tile(
- array_ops.expand_dims(math_ops.range(max_length), 0),
- [array_ops.shape(lengths)[0], 1])
- lengths = array_ops.expand_dims(lengths, 1)
- masks = math_ops.to_float(
- math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths))
- return masks
-
-
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
transition_params):
"""Computes the unnormalized score for a tag sequence.
@@ -234,7 +215,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs):
array_ops.gather(flattened_inputs, flattened_tag_indices),
[batch_size, max_seq_len])
- masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
+ masks = array_ops.sequence_mask(sequence_lengths,
+ maxlen=array_ops.shape(tag_indices)[1],
+ dtype=dtypes.float32)
unary_scores = math_ops.reduce_sum(unary_scores * masks, 1)
return unary_scores
@@ -268,7 +251,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params):
binary_scores = array_ops.gather(flattened_transition_params,
flattened_transition_indices)
- masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1])
+ masks = array_ops.sequence_mask(sequence_lengths,
+ maxlen=array_ops.shape(tag_indices)[1],
+ dtype=dtypes.float32)
truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
return binary_scores