diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-05-11 11:38:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-11 11:52:25 -0700 |
commit | be15e9eb12a2c61f9b0fef98e94967e64af1f6a1 (patch) | |
tree | 2ead3730deef7541637a2bf91dc45721c12c5136 /tensorflow/core/kernels/ctc_loss_op.cc | |
parent | 3e4c2f6408beaecd3ab3bf8dd199bc5a8d150361 (diff) |
Add option to CTCLoss to skip items with shorter labels than inputs (returns 0-gradient).
The default behavior is preserved, namely to return an InvalidArgument which stops
training when such an invalid item is encountered.
PiperOrigin-RevId: 155774918
Diffstat (limited to 'tensorflow/core/kernels/ctc_loss_op.cc')
-rw-r--r-- | tensorflow/core/kernels/ctc_loss_op.cc | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index 05d0169b11..426382edec 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -42,6 +42,8 @@ class CTCLossOp : public OpKernel { &preprocess_collapse_repeated_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs", + &ignore_longer_outputs_than_inputs_)); } void Compute(OpKernelContext* ctx) override { @@ -150,12 +152,15 @@ class CTCLossOp : public OpKernel { OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss( seq_len_t, labels_t, input_list_t, preprocess_collapse_repeated_, ctc_merge_repeated_, - &loss_t, &gradient_list_t, &workers)); + ignore_longer_outputs_than_inputs_, &loss_t, + &gradient_list_t, &workers)); } private: bool preprocess_collapse_repeated_; bool ctc_merge_repeated_; + bool ignore_longer_outputs_than_inputs_; + TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp); }; |