diff options
author | Martin Wicke <wicke@google.com> | 2016-06-03 14:23:52 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-03 15:28:30 -0700 |
commit | 60796d7c0d401e5e7b7a139f165e78ce778583be (patch) | |
tree | 2ccef0ac177be5e467e403ccc4e18cb195f4ef90 /tensorflow/core/kernels/segment_reduction_ops.cc | |
parent | 349072f401952f0aba5240160b1ad6bf9a64bf17 (diff) |
Merge changes from github.
Change: 124012080
Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops.cc | 88 |
1 files changed, 43 insertions, 45 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 76a97e782b..09eed40bd0 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -172,50 +172,48 @@ class SegmentReductionOp : public OpKernel { } }; -#define REGISTER_CPU_KERNELS(type, index_type) \ - REGISTER_KERNEL_BUILDER( \ - Name("SegmentSum") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<index_type>("Tindices"), \ - SegmentReductionOp<CPUDevice, type, index_type, \ - Eigen::internal::SumReducer<type>>); \ - REGISTER_KERNEL_BUILDER( \ - Name("SegmentMean") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<index_type>("Tindices"), \ - SegmentReductionOp<CPUDevice, type, index_type, \ - Eigen::internal::MeanReducer<type>>); \ - REGISTER_KERNEL_BUILDER( \ - Name("SegmentProd") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<index_type>("Tindices"), \ - SegmentReductionOp<CPUDevice, type, index_type, \ - Eigen::internal::ProdReducer<type>>); \ - REGISTER_KERNEL_BUILDER( \ - Name("SegmentMin") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<index_type>("Tindices"), \ - SegmentReductionOp<CPUDevice, type, index_type, \ - Eigen::internal::MinReducer<type>>); \ - REGISTER_KERNEL_BUILDER( \ - Name("SegmentMax") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .TypeConstraint<index_type>("Tindices"), \ - SegmentReductionOp<CPUDevice, type, index_type, \ - Eigen::internal::MaxReducer<type>>); - -#define REGISTER_CPU_KERNELS_ALL(type) \ - REGISTER_CPU_KERNELS(type, int32); \ - REGISTER_CPU_KERNELS(type, int64); - -TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS_ALL); -#undef REGISTER_CPU_KERNELS -#undef REGISTER_CPU_KERNELS_ALL +#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type) \ + REGISTER_KERNEL_BUILDER( \ + Name(name) \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<index_type>("Tindices"), \ + SegmentReductionOp<CPUDevice, type, index_type, functor>) + +#define REGISTER_REAL_CPU_KERNELS(type, index_type) \ + REGISTER_CPU_KERNEL_SEGMENT( \ + "SegmentSum", Eigen::internal::SumReducer<type>, type, index_type); \ + REGISTER_CPU_KERNEL_SEGMENT( \ + "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type); \ + REGISTER_CPU_KERNEL_SEGMENT( \ + "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type); \ + REGISTER_CPU_KERNEL_SEGMENT( \ + "SegmentMin", Eigen::internal::MinReducer<type>, type, index_type); \ + REGISTER_CPU_KERNEL_SEGMENT( \ + "SegmentMax", Eigen::internal::MaxReducer<type>, type, index_type) + +#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type) \ + REGISTER_CPU_KERNEL_SEGMENT( \ + "SegmentSum", Eigen::internal::SumReducer<type>, type, index_type); \ + REGISTER_CPU_KERNEL_SEGMENT( \ + "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type) + +#define REGISTER_REAL_CPU_KERNELS_ALL(type) \ + REGISTER_REAL_CPU_KERNELS(type, int32); \ + REGISTER_REAL_CPU_KERNELS(type, int64) + +#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \ + REGISTER_COMPLEX_CPU_KERNELS(type, int32); \ + REGISTER_COMPLEX_CPU_KERNELS(type, int64) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL); +REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64); +REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128); +#undef REGISTER_CPU_KERNEL_SEGMENT +#undef REGISTER_REAL_CPU_KERNELS +#undef REGISTER_COMPLEX_CPU_KERNELS +#undef REGISTER_REAL_CPU_KERNELS_ALL +#undef REGISTER_COMPLEX_CPU_KERNELS_ALL // Similar to SegmentReductionOp but can handle unsorted segment definitions and // specifying size of output. @@ -285,7 +283,7 @@ class UnsortedSegmentSumOp : public OpKernel { REGISTER_CPU_UNSORTED_KERNELS(type, int32); \ REGISTER_CPU_UNSORTED_KERNELS(type, int64); -TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL); +TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL); #undef REGISTER_CPU_UNSORTED_KERNELS #undef REGISTER_CPU_UNSORTED_KERNELS_ALL |