aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf/python/ops/crf.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/crf/python/ops/crf.py')
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 7166e38b28..c8adb0369b 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -360,8 +360,8 @@ class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
scope: Unused variable scope of this cell.
Returns:
- backpointers: [batch_size, num_tags], containing backpointers.
- new_state: [batch_size, num_tags], containing new score values.
+ backpointers: A [batch_size, num_tags] matrix of backpointers.
+ new_state: A [batch_size, num_tags] matrix of new score values.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
@@ -385,7 +385,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
"""Initialize the CrfDecodeBackwardRnnCell.
Args:
- num_tags
+ num_tags: An integer.
"""
self._num_tags = num_tags
@@ -401,8 +401,9 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
"""Build the CrfDecodeBackwardRnnCell.
Args:
- inputs: [batch_size, num_tags], backpointer of next step (in time order).
- state: [batch_size, 1], next position's tag index.
+ inputs: A [batch_size, num_tags] matrix of
+ backpointer of next step (in time order).
+ state: A [batch_size, 1] matrix of tag index of next step.
scope: Unused variable scope of this cell.
Returns:
@@ -426,16 +427,16 @@ def crf_decode(potentials, transition_params, sequence_length):
This is a function for tensor.
Args:
- potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
+ potentials: A [batch_size, max_seq_len, num_tags] tensor of
unary potentials.
- transition_params: A [num_tags, num_tags] tensor, matrix of
+ transition_params: A [num_tags, num_tags] matrix of
binary potentials.
- sequence_length: A [batch_size] tensor, containing sequence lengths.
+ sequence_length: A [batch_size] vector of true sequence lengths.
Returns:
- decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
+ decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
Contains the highest scoring tag indicies.
- best_score: A [batch_size] tensor, containing the score of decode_tags.
+ best_score: A [batch_size] vector, containing the score of `decode_tags`.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).