diff options
Diffstat (limited to 'tensorflow/core/kernels/sparse_xent_op.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_xent_op.cc | 52 |
1 files changed, 28 insertions, 24 deletions
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 34411c9bbb..48124d20af 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -35,38 +35,42 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    const Tensor& logits_in = context->input(0);
-    const Tensor& labels_in = context->input(1);
-    OP_REQUIRES(context, logits_in.shape().dim_size(0) == labels_in.NumElements(),
+    const Tensor& logits = context->input(0);
+    const Tensor& labels = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
+                errors::InvalidArgument("logits must be 2-D, but got shape ",
+                                        logits.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
+                errors::InvalidArgument("labels must be 1-D, but got shape ",
+                                        labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
                 errors::InvalidArgument(
-                    "logits first dimension must match labels size. logits shape=",
-                    logits_in.shape().DebugString(), " labels shape=",
-                    labels_in.shape().DebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
-                errors::InvalidArgument("logits must be 2-dimensional"));
-    // As we already tested that both inputs have the same shape no need to
-    // check that "labels" is a matrix too.
-
-    // loss is 1-D (one per example), and size is batch_size.
+                    "logits and labels must have the same first dimension, "
+                    "got logits shape ",
+                    logits.shape().DebugString(), " and labels shape ",
+                    labels.shape().DebugString()));
+    OP_REQUIRES(context, logits.dim_size(1) > 0,
+                errors::InvalidArgument(
+                    "Must have at least one class, but got logits shape ",
+                    logits.shape().DebugString()));
 
     Tensor scratch;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::value,
-                                        TensorShape({logits_in.dim_size(0)}),
-                                        &scratch));
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                   labels.shape(), &scratch));
 
     Tensor* loss_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+                   context->allocate_output(0, labels.shape(), &loss_out));
     Tensor* back_out = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(1, logits_in.shape(), &back_out));
-
-    functor::SparseXentFunctor<Device, T, Index> functor;
-    functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
-            labels_in.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
-            back_out->matrix<T>());
+                   context->allocate_output(1, logits.shape(), &back_out));
+
+    if (logits.dim_size(0) > 0) {
+      functor::SparseXentFunctor<Device, T, Index> functor;
+      functor(context->eigen_device<Device>(), logits.matrix<T>(),
+              labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
+              back_out->matrix<T>());
+    }
   }
 };
 