aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/xent_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r--tensorflow/core/kernels/xent_op.cc10
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index dc21cee3a8..0f8d027caa 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -67,10 +67,12 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
// Try to reuse the logits_in buffer for the backprop output.
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 1, logits_in.shape(), &back_out));
- functor::XentFunctor<Device, T> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
- back_out->matrix<T>());
+ if (logits_in.dim_size(0) > 0) {
+ functor::XentFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+ labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ }
}
};