// GPU implementation of the softmax cross-entropy ("xent") functor.
//
// This translation unit only contributes code when built with CUDA support
// (GOOGLE_CUDA is set).  It supplies the GPUDevice partial specialization of
// functor::XentFunctor and explicitly instantiates it for float.
//
// NOTE: each preprocessor directive must start its own line, and the `//`
// comments below must not share a line with code; otherwise the `#if`
// expression is malformed and the trailing `#endif` is commented out.
#if GOOGLE_CUDA

// Ask Eigen to enable its GPU device support before any Eigen header is
// pulled in (the headers below are expected to include Eigen).
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/xent_op.h"

#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/port.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Partial specialization of XentFunctor for GPUDevice.  It does no work of
// its own: it forwards every argument unchanged to the shared Eigen
// implementation, XentEigenImpl<GPUDevice, T>::Compute.
//
// Parameters (names as declared here; the exact shape/semantics contract is
// presumably defined by the primary XentFunctor / XentEigenImpl in
// xent_op.h; verify there):
//   d        - the GPU device the Eigen expressions are evaluated on.
//   logits   - read-only input matrix.
//   labels   - read-only input matrix.
//   scratch  - writable matrix handed to Compute as temporary workspace.
//   loss     - writable output vector.
//   backprop - writable output matrix.
template <typename T>
struct XentFunctor<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<T>::ConstMatrix labels,
                  typename TTypes<T>::Matrix scratch,
                  typename TTypes<T>::Vec loss,
                  typename TTypes<T>::Matrix backprop) {
    XentEigenImpl<GPUDevice, T>::Compute(d, logits, labels, scratch, loss,
                                          backprop);
  }
};

}  // end namespace functor

// Explicitly instantiate the GPU implementation for float so the kernel
// registration code (compiled in another translation unit) can link to it.
template struct functor::XentFunctor<GPUDevice, float>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA