diff options
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_functor.cc | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/scatter_functor.cc b/tensorflow/core/kernels/scatter_functor.cc index 7eba82899f..cf5408123f 100644 --- a/tensorflow/core/kernels/scatter_functor.cc +++ b/tensorflow/core/kernels/scatter_functor.cc @@ -26,21 +26,30 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { // Forward declarations of the functor specializations for GPU. -#define DECLARE_GPU_SPECS_OP(T, Index, op) \ - template <> \ - Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \ - OpKernelContext* c, const GPUDevice& d, \ - typename TTypes<T>::Matrix params, \ - typename TTypes<T>::ConstMatrix updates, \ - typename TTypes<Index>::ConstFlat indices); \ - extern template struct ScatterFunctor<GPUDevice, T, Index, op>; +#define DECLARE_GPU_SPECS_OP(T, Index, op) \ + template <> \ + Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes<T>::Matrix params, \ + typename TTypes<T>::ConstMatrix updates, \ + typename TTypes<Index>::ConstFlat indices); \ + extern template struct ScatterFunctor<GPUDevice, T, Index, op>; \ + template <> \ + Index ScatterScalarFunctor<GPUDevice, T, Index, op>::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes<T>::Matrix params, \ + const typename TTypes<T>::ConstScalar update, \ + typename TTypes<Index>::ConstFlat indices); \ + extern template struct ScatterScalarFunctor<GPUDevice, T, Index, op>; #define DECLARE_GPU_SPECS_INDEX(T, Index) \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DECLARE_GPU_SPECS(T) \ DECLARE_GPU_SPECS_INDEX(T, int32); \ |