diff options
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r-- | tensorflow/core/kernels/xent_op.cc | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index dc21cee3a8..0f8d027caa 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -67,10 +67,12 @@ class SoftmaxXentWithLogitsOp : public OpKernel { // Try to reuse the logits_in buffer for the backprop output. OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 1, logits_in.shape(), &back_out)); - functor::XentFunctor<Device, T> functor; - functor(context->eigen_device<Device>(), logits_in.matrix<T>(), - labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(), - back_out->matrix<T>()); + if (logits_in.dim_size(0) > 0) { + functor::XentFunctor<Device, T> functor; + functor(context->eigen_device<Device>(), logits_in.matrix<T>(), + labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(), + back_out->matrix<T>()); + } } }; |