aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/crf/python/ops/crf.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/crf/python/ops/crf.py')
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py19
1 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 1612c75179..4282be5ec8 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -363,8 +363,8 @@ class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
scope: Unused variable scope of this cell.
Returns:
- backpointers: A [batch_size, num_tags] matrix of backpointers.
- new_state: A [batch_size, num_tags] matrix of new score values.
+ backpointers: [batch_size, num_tags], containing backpointers.
+ new_state: [batch_size, num_tags], containing new score values.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
@@ -404,9 +404,8 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
"""Build the CrfDecodeBackwardRnnCell.
Args:
- inputs: A [batch_size, num_tags] matrix of
- backpointer of next step (in time order).
- state: A [batch_size, 1] matrix of tag index of next step.
+ inputs: [batch_size, num_tags], backpointer of next step (in time order).
+ state: [batch_size, 1], next position's tag index.
scope: Unused variable scope of this cell.
Returns:
@@ -430,16 +429,16 @@ def crf_decode(potentials, transition_params, sequence_length):
This is a function for tensor.
Args:
- potentials: A [batch_size, max_seq_len, num_tags] tensor of
+ potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
unary potentials.
- transition_params: A [num_tags, num_tags] matrix of
+ transition_params: A [num_tags, num_tags] tensor, matrix of
binary potentials.
- sequence_length: A [batch_size] vector of true sequence lengths.
+ sequence_length: A [batch_size] tensor, containing sequence lengths.
Returns:
- decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
+ decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
Contains the highest scoring tag indices.
- best_score: A [batch_size] vector, containing the score of `decode_tags`.
+ best_score: A [batch_size] tensor, containing the score of decode_tags.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).