diff options
author | 2017-12-22 15:50:19 -0800 | |
---|---|---|
committer | 2017-12-22 15:54:28 -0800 | |
commit | d4091eec522e41093e6e10601af79c75bee14c80 (patch) | |
tree | e7f2024d99f30d3cbe9616ad558a6a04c2cf1aa0 /tensorflow/contrib/crf | |
parent | a64485dbb378d7ac6afc9082fd7176a957815a8c (diff) |
Replaces custom _lengths_to_masks function with the official, more efficient sequence_mask function that supersedes it.
PiperOrigin-RevId: 179971521
Diffstat (limited to 'tensorflow/contrib/crf')
-rw-r--r-- | tensorflow/contrib/crf/__init__.py | 17 | ||||
-rw-r--r-- | tensorflow/contrib/crf/python/kernel_tests/crf_test.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/crf/python/ops/crf.py | 27 |
3 files changed, 14 insertions, 41 deletions
diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index bc749339bd..046c509626 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -16,15 +16,15 @@ See the @{$python/contrib.crf} guide. -@@crf_sequence_score -@@crf_log_norm -@@crf_log_likelihood -@@crf_unary_score @@crf_binary_score @@crf_decode -@@CrfForwardRnnCell -@@CrfDecodeForwardRnnCell +@@crf_log_likelihood +@@crf_log_norm +@@crf_sequence_score +@@crf_unary_score @@CrfDecodeBackwardRnnCell +@@CrfDecodeForwardRnnCell +@@CrfForwardRnnCell @@viterbi_decode """ @@ -32,16 +32,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks from tensorflow.contrib.crf.python.ops.crf import crf_binary_score from tensorflow.contrib.crf.python.ops.crf import crf_decode from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood from tensorflow.contrib.crf.python.ops.crf import crf_log_norm from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_unary_score -from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell -from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import viterbi_decode from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index b47fb426a1..721dc4d080 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -179,17 +179,6 @@ class CrfTest(test.TestCase): tf_total_log_likelihood = sess.run(total_log_likelihood) self.assertAllClose(tf_total_log_likelihood, 0.0) - def testLengthsToMasks(self): - with self.test_session() as sess: - sequence_lengths = [4, 1, 8, 2] - max_sequence_length = max(sequence_lengths) - mask = crf._lengths_to_masks(sequence_lengths, max_sequence_length) - tf_mask = sess.run(mask) - self.assertEqual(len(tf_mask), len(sequence_lengths)) - for m, l in zip(tf_mask, sequence_lengths): - self.assertAllEqual(m[:l], [1] * l) - self.assertAllEqual(m[l:], [0] * (len(m) - l)) - def testViterbiDecode(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 7f5ae937b2..62708636c6 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -70,25 +70,6 @@ __all__ = [ ] -def _lengths_to_masks(lengths, max_length): - """Creates a binary matrix that can be used to mask away padding. - - Args: - lengths: A vector of integers representing lengths. - max_length: An integer indicating the maximum length. All values in - lengths should be less than max_length. - Returns: - masks: Masks that can be used to get rid of padding. - """ - tiled_ranges = array_ops.tile( - array_ops.expand_dims(math_ops.range(max_length), 0), - [array_ops.shape(lengths)[0], 1]) - lengths = array_ops.expand_dims(lengths, 1) - masks = math_ops.to_float( - math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths)) - return masks - - def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -234,7 +215,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): array_ops.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len]) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) unary_scores = math_ops.reduce_sum(unary_scores * masks, 1) return unary_scores @@ -268,7 +251,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): binary_scores = array_ops.gather(flattened_transition_params, flattened_transition_indices) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1]) binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1) return binary_scores |