diff options
Diffstat (limited to 'tensorflow/core/kernels/xent_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/xent_op_gpu.cu.cc | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/xent_op_gpu.cu.cc b/tensorflow/core/kernels/xent_op_gpu.cu.cc index 05ee7da490..2c0c0b3a02 100644 --- a/tensorflow/core/kernels/xent_op_gpu.cu.cc +++ b/tensorflow/core/kernels/xent_op_gpu.cu.cc @@ -31,12 +31,17 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { template <typename T> struct XentFunctor<GPUDevice, T> { - void operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix logits, + void operator()(const GPUDevice &d, + const Eigen::DSizes<Eigen::DenseIndex, 2> &shape, + const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast, + const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast, + typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels, typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss, typename TTypes<T>::Matrix backprop) { - XentEigenImpl<GPUDevice, T>::Compute(d, logits, labels, scratch, loss, + XentEigenImpl<GPUDevice, T>::Compute(d, shape, logits_bcast, labels_bcast, + logits, labels, scratch, loss, backprop); } }; |