// See docs in ../ops/nn_ops.cc. #define EIGEN_USE_THREADS #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/public/tensor_shape.h" #include "tensorflow/core/kernels/xent_op.h" #include "tensorflow/core/public/tensor.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template class SoftmaxXentWithLogitsOp : public OpKernel { public: explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { const Tensor& logits_in = context->input(0); const Tensor& labels_in = context->input(1); OP_REQUIRES(context, logits_in.IsSameSize(labels_in), errors::InvalidArgument( "logits and labels must be same size: logits_size=", logits_in.shape().DebugString(), " labels_size=", labels_in.shape().DebugString())); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()), errors::InvalidArgument("logits must be 2-dimensional")); // As we already tested that both inputs have the same shape no need to // check that "labels" is a matrix too. // loss is 1-D (one per example), and size is batch_size. Tensor scratch; OP_REQUIRES_OK( context, context->allocate_temp(DataTypeToEnum::value, TensorShape({logits_in.dim_size(0), 1}), &scratch)); Tensor* loss_out = nullptr; OP_REQUIRES_OK(context, context->allocate_output( 0, TensorShape({logits_in.dim_size(0)}), &loss_out)); Tensor* back_out = nullptr; OP_REQUIRES_OK(context, context->allocate_output(1, logits_in.shape(), &back_out)); functor::XentFunctor functor; functor(context->eigen_device(), logits_in.matrix(), labels_in.matrix(), scratch.matrix(), loss_out->vec(), back_out->matrix()); } }; // Partial specialization for a CPUDevice, that uses the Eigen implementation // from XentEigenImpl. namespace functor { template struct XentFunctor { void operator()(const CPUDevice& d, typename TTypes::ConstMatrix logits, typename TTypes::ConstMatrix labels, typename TTypes::Matrix scratch, typename TTypes::Vec loss, typename TTypes::Matrix backprop) { XentEigenImpl::Compute(d, logits, labels, scratch, loss, backprop); } }; } // namespace functor REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") .Device(DEVICE_CPU) .TypeConstraint("T"), SoftmaxXentWithLogitsOp); REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") .Device(DEVICE_CPU) .TypeConstraint("T"), SoftmaxXentWithLogitsOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") .Device(DEVICE_GPU) .TypeConstraint("T"), SoftmaxXentWithLogitsOp); #endif // GOOGLE_CUDA } // namespace tensorflow