diff options
-rw-r--r-- | tensorflow/core/kernels/ctc_loss_op.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/ops/ctc_ops.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/util/ctc/ctc_loss_calculator.h | 69 | ||||
-rw-r--r-- | tensorflow/python/ops/ctc_ops.py | 13 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.nn.pbtxt | 2 |
5 files changed, 64 insertions, 31 deletions
diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index 05d0169b11..426382edec 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -42,6 +42,8 @@ class CTCLossOp : public OpKernel { &preprocess_collapse_repeated_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs", + &ignore_longer_outputs_than_inputs_)); } void Compute(OpKernelContext* ctx) override { @@ -150,12 +152,15 @@ class CTCLossOp : public OpKernel { OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss( seq_len_t, labels_t, input_list_t, preprocess_collapse_repeated_, ctc_merge_repeated_, - &loss_t, &gradient_list_t, &workers)); + ignore_longer_outputs_than_inputs_, &loss_t, + &gradient_list_t, &workers)); } private: bool preprocess_collapse_repeated_; bool ctc_merge_repeated_; + bool ignore_longer_outputs_than_inputs_; + TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp); }; diff --git a/tensorflow/core/ops/ctc_ops.cc b/tensorflow/core/ops/ctc_ops.cc index c94ce577c0..3d8c533935 100644 --- a/tensorflow/core/ops/ctc_ops.cc +++ b/tensorflow/core/ops/ctc_ops.cc @@ -31,6 +31,7 @@ REGISTER_OP("CTCLoss") .Input("sequence_length: int32") .Attr("preprocess_collapse_repeated: bool = false") .Attr("ctc_merge_repeated: bool = true") + .Attr("ignore_longer_outputs_than_inputs: bool = false") .Output("loss: float") .Output("gradient: float") .SetShapeFn([](InferenceContext* c) { @@ -75,6 +76,9 @@ preprocess_collapse_repeated: Scalar, if true then repeated labels are ctc_merge_repeated: Scalar. If set to false, *during* CTC calculation repeated non-blank labels will not be merged and are interpreted as individual labels. This is a simplified version of CTC. +ignore_longer_outputs_than_inputs: Scalar. 
If set to true, during CTC + calculation, items that have longer output sequences than input sequences + are ignored by returning zero-gradient for those items. loss: A vector (batch) containing log-probabilities. gradient: The gradient of `loss`. 3-D, shape: `(max_time x batch_size x num_classes)`. diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.h b/tensorflow/core/util/ctc/ctc_loss_calculator.h index eacadd65af..567bad38c3 100644 --- a/tensorflow/core/util/ctc/ctc_loss_calculator.h +++ b/tensorflow/core/util/ctc/ctc_loss_calculator.h @@ -65,7 +65,8 @@ class CTCLossCalculator { Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels, const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated, - bool ctc_merge_repeated, VectorOut* loss, + bool ctc_merge_repeated, + bool ignore_longer_outputs_than_inputs, VectorOut* loss, std::vector<MatrixOut>* gradients, DeviceBase::CpuWorkerThreads* workers = nullptr) const; @@ -90,7 +91,8 @@ class CTCLossCalculator { // batch. 
Return value: // max_{b in batch_size} l_primes[b].size() template <typename Vector> - Status PopulateLPrimes(bool preprocess_collapse_repeated, int batch_size, + Status PopulateLPrimes(bool preprocess_collapse_repeated, + bool ignore_longer_outputs_than_inputs, int batch_size, int num_classes, const Vector& seq_len, const LabelSequences& labels, size_t* max_u_prime, LabelSequences* l_primes) const; @@ -108,7 +110,8 @@ template <typename VectorIn, typename VectorOut, typename MatrixIn, Status CTCLossCalculator::CalculateLoss( const VectorIn& seq_len, const LabelSequences& labels, const std::vector<MatrixIn>& inputs, bool preprocess_collapse_repeated, - bool ctc_merge_repeated, VectorOut* loss, std::vector<MatrixOut>* gradients, + bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs, + VectorOut* loss, std::vector<MatrixOut>* gradients, DeviceBase::CpuWorkerThreads* workers) const { auto num_time_steps = inputs.size(); @@ -155,20 +158,31 @@ Status CTCLossCalculator::CalculateLoss( // and calculate the maximum necessary allocation size. LabelSequences l_primes(batch_size); size_t max_u_prime = 0; - Status l_p_ret = - PopulateLPrimes(preprocess_collapse_repeated, batch_size, num_classes, - seq_len, labels, &max_u_prime, &l_primes); + Status l_p_ret = PopulateLPrimes( + preprocess_collapse_repeated, ignore_longer_outputs_than_inputs, + batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes); if (!l_p_ret.ok()) { return l_p_ret; } // Process each item in a batch in parallel, using at most kMaxThreads. 
- auto ComputeLossAndGradients = [this, num_classes, &l_primes, &seq_len, - &inputs, requires_backprop, - ctc_merge_repeated, &loss, &gradients]( - int64 start_row, int64 limit_row) { + auto ComputeLossAndGradients = [this, num_classes, &labels, &l_primes, + &seq_len, &inputs, requires_backprop, + ctc_merge_repeated, + ignore_longer_outputs_than_inputs, &loss, + &gradients](int64 start_row, + int64 limit_row) { for (int b = start_row; b < limit_row; b++) { - if (seq_len(b) == 0) { + // Return zero gradient for empty sequences or sequences with labels + // longer than input, which is not supported by CTC. + if (seq_len(b) == 0 || + (ignore_longer_outputs_than_inputs && + labels[b].size() > seq_len(b) - this->output_delay_)) { + VLOG(1) << "The sequence length is either zero or shorter than the " + "target output (CTC works only with shorter target sequence " + "than input sequence). You can turn this into a warning by " + "using the flag ignore_longer_outputs_than_inputs - " + << b << ": " << str_util::Join(labels[b], " "); continue; } @@ -263,12 +277,11 @@ Status CTCLossCalculator::CalculateLoss( } template <typename Vector> -Status CTCLossCalculator::PopulateLPrimes(bool preprocess_collapse_repeated, - int batch_size, int num_classes, - const Vector& seq_len, - const LabelSequences& labels, - size_t* max_u_prime, - LabelSequences* l_primes) const { +Status CTCLossCalculator::PopulateLPrimes( + bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs, + int batch_size, int num_classes, const Vector& seq_len, + const LabelSequences& labels, size_t* max_u_prime, + LabelSequences* l_primes) const { // labels is a Label array of size batch_size if (labels.size() != batch_size) { return errors::InvalidArgument("labels.size() != batch_size: ", @@ -311,9 +324,6 @@ Status CTCLossCalculator::PopulateLPrimes(bool preprocess_collapse_repeated, } } - // Make sure there is enough time to output the target indices. 
- int time = seq_len(b) - output_delay_; - int required_time = label.size(); for (int l_i : l) { if (l_i < 0) { return errors::InvalidArgument( @@ -325,14 +335,19 @@ Status CTCLossCalculator::PopulateLPrimes(bool preprocess_collapse_repeated, num_classes, ", batch: ", b, " labels: ", str_util::Join(l, ",")); } } - if (required_time > time) { - return errors::InvalidArgument( - "Not enough time for target transition sequence (" - "required: ", - required_time, ", available: ", time, - "), skipping data instance in batch: ", b); + if (!ignore_longer_outputs_than_inputs) { + // Make sure there is enough time to output the target indices. + int time = seq_len(b) - output_delay_; + int required_time = label.size(); + if (required_time > time) { + return errors::InvalidArgument( + "Not enough time for target transition sequence (" + "required: ", + required_time, ", available: ", time, ")", b, + "You can turn this error into a warning by using the flag " + "ignore_longer_outputs_than_inputs"); + } } - // Target indices with blanks before each index and a blank at the end. // Length U' = 2U + 1. // Convert l to l_prime diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 69edaa2c40..4ea4d9ed2d 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -30,7 +30,8 @@ from tensorflow.python.ops.nn_grad import _BroadcastMul # pylint: disable=protected-access, invalid-name def ctc_loss(labels, inputs, sequence_length, preprocess_collapse_repeated=False, - ctc_merge_repeated=True, time_major=True): + ctc_merge_repeated=True, + ignore_longer_outputs_than_inputs=False, time_major=True): """Computes the CTC (Connectionist Temporal Classification) Loss. This op implements the CTC loss as presented in the article: @@ -94,6 +95,11 @@ def ctc_loss(labels, inputs, sequence_length, Untested. Very likely will not learn to output repeated classes. 
+ The `ignore_longer_outputs_than_inputs` option allows specifying the behavior + of the CTCLoss when dealing with sequences that have longer outputs than + inputs. If true, the CTCLoss will simply return zero gradient for those + items; otherwise an InvalidArgument error is returned, stopping training. + Args: labels: An `int32` `SparseTensor`. `labels.indices[i, :] == [b, t]` means `labels.values[i]` stores @@ -111,6 +117,8 @@ def ctc_loss(labels, inputs, sequence_length, preprocess_collapse_repeated: Boolean. Default: False. If True, repeated labels are collapsed prior to the CTC calculation. ctc_merge_repeated: Boolean. Default: True. + ignore_longer_outputs_than_inputs: Boolean. Default: False. + If True, sequences with longer outputs than inputs will be ignored. time_major: The shape format of the `inputs` Tensors. If True, these `Tensors` must be shaped `[max_time, batch_size, num_classes]`. If False, these `Tensors` must be shaped `[batch_size, max_time, num_classes]`. @@ -140,7 +148,8 @@ def ctc_loss(labels, inputs, sequence_length, labels.values, sequence_length, preprocess_collapse_repeated=preprocess_collapse_repeated, - ctc_merge_repeated=ctc_merge_repeated) + ctc_merge_repeated=ctc_merge_repeated, + ignore_longer_outputs_than_inputs=ignore_longer_outputs_than_inputs) return loss diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt index 3a448798b2..b1b60fbdcb 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt @@ -90,7 +90,7 @@ tf_module { } member_method { name: "ctc_loss" - argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'True\'], " + argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', 
\'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], " } member_method { name: "depthwise_conv2d" |