aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_functor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor.cc')
-rw-r--r--tensorflow/core/kernels/scatter_functor.cc27
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/scatter_functor.cc b/tensorflow/core/kernels/scatter_functor.cc
index 7eba82899f..cf5408123f 100644
--- a/tensorflow/core/kernels/scatter_functor.cc
+++ b/tensorflow/core/kernels/scatter_functor.cc
@@ -26,21 +26,30 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPECS_OP(T, Index, op) \
- template <> \
- Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
- OpKernelContext* c, const GPUDevice& d, \
- typename TTypes<T>::Matrix params, \
- typename TTypes<T>::ConstMatrix updates, \
- typename TTypes<Index>::ConstFlat indices); \
- extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
+#define DECLARE_GPU_SPECS_OP(T, Index, op) \
+ template <> \
+ Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
+ OpKernelContext* c, const GPUDevice& d, \
+ typename TTypes<T>::Matrix params, \
+ typename TTypes<T>::ConstMatrix updates, \
+ typename TTypes<Index>::ConstFlat indices); \
+ extern template struct ScatterFunctor<GPUDevice, T, Index, op>; \
+ template <> \
+ Index ScatterScalarFunctor<GPUDevice, T, Index, op>::operator()( \
+ OpKernelContext* c, const GPUDevice& d, \
+ typename TTypes<T>::Matrix params, \
+ const typename TTypes<T>::ConstScalar update, \
+ typename TTypes<Index>::ConstFlat indices); \
+ extern template struct ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \