aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/ctc_loss_op.cc7
-rw-r--r--tensorflow/core/ops/ctc_ops.cc4
-rw-r--r--tensorflow/core/util/ctc/ctc_loss_calculator.h69
-rw-r--r--tensorflow/python/ops/ctc_ops.py13
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.pbtxt2
5 files changed, 64 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc
index 05d0169b11..426382edec 100644
--- a/tensorflow/core/kernels/ctc_loss_op.cc
+++ b/tensorflow/core/kernels/ctc_loss_op.cc
@@ -42,6 +42,8 @@ class CTCLossOp : public OpKernel {
&preprocess_collapse_repeated_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
+ &ignore_longer_outputs_than_inputs_));
}
void Compute(OpKernelContext* ctx) override {
@@ -150,12 +152,15 @@ class CTCLossOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
seq_len_t, labels_t, input_list_t,
preprocess_collapse_repeated_, ctc_merge_repeated_,
- &loss_t, &gradient_list_t, &workers));
+ ignore_longer_outputs_than_inputs_, &loss_t,
+ &gradient_list_t, &workers));
}
private:
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
+ bool ignore_longer_outputs_than_inputs_;
+
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
};
diff --git a/tensorflow/core/ops/ctc_ops.cc b/tensorflow/core/ops/ctc_ops.cc
index c94ce577c0..3d8c533935 100644
--- a/tensorflow/core/ops/ctc_ops.cc
+++ b/tensorflow/core/ops/ctc_ops.cc
@@ -31,6 +31,7 @@ REGISTER_OP("CTCLoss")
.Input("sequence_length: int32")
.Attr("preprocess_collapse_repeated: bool = false")
.Attr("ctc_merge_repeated: bool = true")
+ .Attr("ignore_longer_outputs_than_inputs: bool = false")
.Output("loss: float")
.Output("gradient: float")
.SetShapeFn([](InferenceContext* c) {
@@ -75,6 +76,9 @@ preprocess_collapse_repeated: Scalar, if true then repeated labels are
ctc_merge_repeated: Scalar. If set to false, *during* CTC calculation
repeated non-blank labels will not be merged and are interpreted as
individual labels. This is a simplified version of CTC.
+ignore_longer_outputs_than_inputs: Scalar. If set to true, during CTC
+  calculation, items that have longer output sequences than input sequences
+  are skipped, and a zero gradient is returned for those items.
loss: A vector (batch) containing log-probabilities.
gradient: The gradient of `loss`. 3-D, shape:
`(max_time x batch_size x num_classes)`.
diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.h b/tensorflow/core/util/ctc/ctc_loss_calculator.h
index eacadd65af..567bad38c3 100644
--- a/tensorflow/core/util/ctc/ctc_loss_calculator.h
+++ b/tensorflow/core/util/ctc/ctc_loss_calculator.h
@@ -65,7 +65,8 @@ class CTCLossCalculator {
Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels,
const std::vector<MatrixIn>& inputs,
bool preprocess_collapse_repeated,
- bool ctc_merge_repeated, VectorOut* loss,
+ bool ctc_merge_repeated,
+ bool ignore_longer_outputs_than_inputs, VectorOut* loss,
std::vector<MatrixOut>* gradients,
DeviceBase::CpuWorkerThreads* workers = nullptr) const;
@@ -90,7 +91,8 @@ class CTCLossCalculator {
// batch. Return value:
// max_{b in batch_size} l_primes[b].size()
template <typename Vector>
- Status PopulateLPrimes(bool preprocess_collapse_repeated, int batch_size,
+ Status PopulateLPrimes(bool preprocess_collapse_repeated,
+ bool ignore_longer_outputs_than_inputs, int batch_size,
int num_classes, const Vector& seq_len,
const LabelSequences& labels, size_t* max_u_prime,
LabelSequences* l_primes) const;
@@ -108,7 +110,8 @@ template <typename VectorIn, typename VectorOut, typename MatrixIn,
Status CTCLossCalculator::CalculateLoss(
const VectorIn& seq_len, const LabelSequences& labels,
const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated,
- bool ctc_merge_repeated, VectorOut* loss, std::vector<MatrixOut>* gradients,
+ bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs,
+ VectorOut* loss, std::vector<MatrixOut>* gradients,
DeviceBase::CpuWorkerThreads* workers) const {
auto num_time_steps = inputs.size();
@@ -155,20 +158,31 @@ Status CTCLossCalculator::CalculateLoss(
// and calculate the maximum necessary allocation size.
LabelSequences l_primes(batch_size);
size_t max_u_prime = 0;
- Status l_p_ret =
- PopulateLPrimes(preprocess_collapse_repeated, batch_size, num_classes,
- seq_len, labels, &max_u_prime, &l_primes);
+ Status l_p_ret = PopulateLPrimes(
+ preprocess_collapse_repeated, ignore_longer_outputs_than_inputs,
+ batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes);
if (!l_p_ret.ok()) {
return l_p_ret;
}
// Process each item in a batch in parallel, using at most kMaxThreads.
- auto ComputeLossAndGradients = [this, num_classes, &l_primes, &seq_len,
- &inputs, requires_backprop,
- ctc_merge_repeated, &loss, &gradients](
- int64 start_row, int64 limit_row) {
+ auto ComputeLossAndGradients = [this, num_classes, &labels, &l_primes,
+ &seq_len, &inputs, requires_backprop,
+ ctc_merge_repeated,
+ ignore_longer_outputs_than_inputs, &loss,
+ &gradients](int64 start_row,
+ int64 limit_row) {
for (int b = start_row; b < limit_row; b++) {
- if (seq_len(b) == 0) {
+ // Return zero gradient for empty sequences or sequences with labels
+ // longer than input, which is not supported by CTC.
+ if (seq_len(b) == 0 ||
+ (ignore_longer_outputs_than_inputs &&
+ labels[b].size() > seq_len(b) - this->output_delay_)) {
+ VLOG(1) << "The sequence length is either zero or shorter than the "
+ "target output (CTC works only with shorter target sequence "
+ "than input sequence). You can turn this into a warning by "
+ "using the flag ignore_longer_outputs_than_inputs - "
+ << b << ": " << str_util::Join(labels[b], " ");
continue;
}
@@ -263,12 +277,11 @@ Status CTCLossCalculator::CalculateLoss(
}
template <typename Vector>
-Status CTCLossCalculator::PopulateLPrimes(bool preprocess_collapse_repeated,
- int batch_size, int num_classes,
- const Vector& seq_len,
- const LabelSequences& labels,
- size_t* max_u_prime,
- LabelSequences* l_primes) const {
+Status CTCLossCalculator::PopulateLPrimes(
+ bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs,
+ int batch_size, int num_classes, const Vector& seq_len,
+ const LabelSequences& labels, size_t* max_u_prime,
+ LabelSequences* l_primes) const {
// labels is a Label array of size batch_size
if (labels.size() != batch_size) {
return errors::InvalidArgument("labels.size() != batch_size: ",
@@ -311,9 +324,6 @@ Status CTCLossCalculator::PopulateLPrimes(bool preprocess_collapse_repeated,
}
}
- // Make sure there is enough time to output the target indices.
- int time = seq_len(b) - output_delay_;
- int required_time = label.size();
for (int l_i : l) {
if (l_i < 0) {
return errors::InvalidArgument(
@@ -325,14 +335,19 @@ Status CTCLossCalculator::PopulateLPrimes(bool preprocess_collapse_repeated,
num_classes, ", batch: ", b, " labels: ", str_util::Join(l, ","));
}
}
- if (required_time > time) {
- return errors::InvalidArgument(
- "Not enough time for target transition sequence ("
- "required: ",
- required_time, ", available: ", time,
- "), skipping data instance in batch: ", b);
+ if (!ignore_longer_outputs_than_inputs) {
+ // Make sure there is enough time to output the target indices.
+ int time = seq_len(b) - output_delay_;
+ int required_time = label.size();
+ if (required_time > time) {
+ return errors::InvalidArgument(
+ "Not enough time for target transition sequence ("
+ "required: ",
+ required_time, ", available: ", time, ")", b,
+ "You can turn this error into a warning by using the flag "
+ "ignore_longer_outputs_than_inputs");
+ }
}
-
// Target indices with blanks before each index and a blank at the end.
// Length U' = 2U + 1.
// Convert l to l_prime
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 69edaa2c40..4ea4d9ed2d 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -30,7 +30,8 @@ from tensorflow.python.ops.nn_grad import _BroadcastMul
# pylint: disable=protected-access, invalid-name
def ctc_loss(labels, inputs, sequence_length,
preprocess_collapse_repeated=False,
- ctc_merge_repeated=True, time_major=True):
+ ctc_merge_repeated=True,
+ ignore_longer_outputs_than_inputs=False, time_major=True):
"""Computes the CTC (Connectionist Temporal Classification) Loss.
This op implements the CTC loss as presented in the article:
@@ -94,6 +95,11 @@ def ctc_loss(labels, inputs, sequence_length,
Untested. Very likely will not learn to output repeated classes.
+  The `ignore_longer_outputs_than_inputs` option allows specifying the behavior
+  of the CTCLoss when dealing with sequences that have longer outputs than
+  inputs. If true, the CTCLoss will simply return zero gradient for those
+  items, otherwise an InvalidArgument error is returned, stopping training.
+
Args:
labels: An `int32` `SparseTensor`.
`labels.indices[i, :] == [b, t]` means `labels.values[i]` stores
@@ -111,6 +117,8 @@ def ctc_loss(labels, inputs, sequence_length,
preprocess_collapse_repeated: Boolean. Default: False.
If True, repeated labels are collapsed prior to the CTC calculation.
ctc_merge_repeated: Boolean. Default: True.
+ ignore_longer_outputs_than_inputs: Boolean. Default: False.
+ If True, sequences with longer outputs than inputs will be ignored.
time_major: The shape format of the `inputs` Tensors.
If True, these `Tensors` must be shaped `[max_time, batch_size, num_classes]`.
If False, these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
@@ -140,7 +148,8 @@ def ctc_loss(labels, inputs, sequence_length,
labels.values,
sequence_length,
preprocess_collapse_repeated=preprocess_collapse_repeated,
- ctc_merge_repeated=ctc_merge_repeated)
+ ctc_merge_repeated=ctc_merge_repeated,
+ ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs)
return loss
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index 3a448798b2..b1b60fbdcb 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -90,7 +90,7 @@ tf_module {
}
member_method {
name: "ctc_loss"
- argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'True\'], "
+ argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], "
}
member_method {
name: "depthwise_conv2d"