aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_xent_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/sparse_xent_op.cc')
-rw-r--r--tensorflow/core/kernels/sparse_xent_op.cc52
1 files changed, 28 insertions, 24 deletions
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 34411c9bbb..48124d20af 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -35,38 +35,42 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
- const Tensor& logits_in = context->input(0);
- const Tensor& labels_in = context->input(1);
- OP_REQUIRES(context, logits_in.shape().dim_size(0) == labels_in.NumElements(),
+ const Tensor& logits = context->input(0);
+ const Tensor& labels = context->input(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
+ errors::InvalidArgument("logits must be 2-D, but got shape ",
+ logits.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
+ errors::InvalidArgument("labels must be 1-D, but got shape ",
+ labels.shape().DebugString()));
+ OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
errors::InvalidArgument(
- "logits first dimension must match labels size. logits shape=",
- logits_in.shape().DebugString(), " labels shape=",
- labels_in.shape().DebugString()));
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
- errors::InvalidArgument("logits must be 2-dimensional"));
- // As we already tested that both inputs have the same shape no need to
- // check that "labels" is a matrix too.
-
- // loss is 1-D (one per example), and size is batch_size.
+ "logits and labels must have the same first dimension, "
+ "got logits shape ",
+ logits.shape().DebugString(), " and labels shape ",
+ labels.shape().DebugString()));
+ OP_REQUIRES(context, logits.dim_size(1) > 0,
+ errors::InvalidArgument(
+ "Must have at least one class, but got logits shape ",
+ logits.shape().DebugString()));
Tensor scratch;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({logits_in.dim_size(0)}),
- &scratch));
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+ labels.shape(), &scratch));
Tensor* loss_out = nullptr;
OP_REQUIRES_OK(context,
- context->allocate_output(
- 0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+ context->allocate_output(0, labels.shape(), &loss_out));
Tensor* back_out = nullptr;
OP_REQUIRES_OK(context,
- context->allocate_output(1, logits_in.shape(), &back_out));
-
- functor::SparseXentFunctor<Device, T, Index> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- labels_in.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
- back_out->matrix<T>());
+ context->allocate_output(1, logits.shape(), &back_out));
+
+ if (logits.dim_size(0) > 0) {
+ functor::SparseXentFunctor<Device, T, Index> functor;
+ functor(context->eigen_device<Device>(), logits.matrix<T>(),
+ labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ }
}
};