aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/ctc_loss_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-11 11:38:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 11:52:25 -0700
commitbe15e9eb12a2c61f9b0fef98e94967e64af1f6a1 (patch)
tree2ead3730deef7541637a2bf91dc45721c12c5136 /tensorflow/core/kernels/ctc_loss_op.cc
parent3e4c2f6408beaecd3ab3bf8dd199bc5a8d150361 (diff)
Add option to CTCLoss to skip items with shorter labels than inputs (returns 0-gradient).
The default behavior is preserved, namely to return an InvalidArgument which stops training when such an invalid item is encountered. PiperOrigin-RevId: 155774918
Diffstat (limited to 'tensorflow/core/kernels/ctc_loss_op.cc')
-rw-r--r--tensorflow/core/kernels/ctc_loss_op.cc7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc
index 05d0169b11..426382edec 100644
--- a/tensorflow/core/kernels/ctc_loss_op.cc
+++ b/tensorflow/core/kernels/ctc_loss_op.cc
@@ -42,6 +42,8 @@ class CTCLossOp : public OpKernel {
&preprocess_collapse_repeated_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
+ &ignore_longer_outputs_than_inputs_));
}
void Compute(OpKernelContext* ctx) override {
@@ -150,12 +152,15 @@ class CTCLossOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
seq_len_t, labels_t, input_list_t,
preprocess_collapse_repeated_, ctc_merge_repeated_,
- &loss_t, &gradient_list_t, &workers));
+ ignore_longer_outputs_than_inputs_, &loss_t,
+ &gradient_list_t, &workers));
}
private:
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
+ bool ignore_longer_outputs_than_inputs_;
+
TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
};