aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/segment_reduction_ops.cc
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2016-06-03 14:23:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-03 15:28:30 -0700
commit60796d7c0d401e5e7b7a139f165e78ce778583be (patch)
tree2ccef0ac177be5e467e403ccc4e18cb195f4ef90 /tensorflow/core/kernels/segment_reduction_ops.cc
parent349072f401952f0aba5240160b1ad6bf9a64bf17 (diff)
Merge changes from github.
Change: 124012080
Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops.cc')
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc88
1 files changed, 43 insertions, 45 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 76a97e782b..09eed40bd0 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -172,50 +172,48 @@ class SegmentReductionOp : public OpKernel {
}
};
-#define REGISTER_CPU_KERNELS(type, index_type) \
- REGISTER_KERNEL_BUILDER( \
- Name("SegmentSum") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- SegmentReductionOp<CPUDevice, type, index_type, \
- Eigen::internal::SumReducer<type>>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SegmentMean") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- SegmentReductionOp<CPUDevice, type, index_type, \
- Eigen::internal::MeanReducer<type>>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SegmentProd") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- SegmentReductionOp<CPUDevice, type, index_type, \
- Eigen::internal::ProdReducer<type>>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SegmentMin") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- SegmentReductionOp<CPUDevice, type, index_type, \
- Eigen::internal::MinReducer<type>>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SegmentMax") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<index_type>("Tindices"), \
- SegmentReductionOp<CPUDevice, type, index_type, \
- Eigen::internal::MaxReducer<type>>);
-
-#define REGISTER_CPU_KERNELS_ALL(type) \
- REGISTER_CPU_KERNELS(type, int32); \
- REGISTER_CPU_KERNELS(type, int64);
-
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS_ALL);
-#undef REGISTER_CPU_KERNELS
-#undef REGISTER_CPU_KERNELS_ALL
+#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(name) \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionOp<CPUDevice, type, index_type, functor>)
+
+#define REGISTER_REAL_CPU_KERNELS(type, index_type) \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentSum", Eigen::internal::SumReducer<type>, type, index_type); \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type); \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type); \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentMin", Eigen::internal::MinReducer<type>, type, index_type); \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentMax", Eigen::internal::MaxReducer<type>, type, index_type)
+
+#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type) \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentSum", Eigen::internal::SumReducer<type>, type, index_type); \
+ REGISTER_CPU_KERNEL_SEGMENT( \
+ "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type)
+
+#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
+ REGISTER_REAL_CPU_KERNELS(type, int32); \
+ REGISTER_REAL_CPU_KERNELS(type, int64)
+
+#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
+ REGISTER_COMPLEX_CPU_KERNELS(type, int32); \
+ REGISTER_COMPLEX_CPU_KERNELS(type, int64)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
+REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
+REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
+#undef REGISTER_CPU_KERNEL_SEGMENT
+#undef REGISTER_REAL_CPU_KERNELS
+#undef REGISTER_COMPLEX_CPU_KERNELS
+#undef REGISTER_REAL_CPU_KERNELS_ALL
+#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
// Similar to SegmentReductionOp but can handle unsorted segment definitions and
// specifying size of output.
@@ -285,7 +283,7 @@ class UnsortedSegmentSumOp : public OpKernel {
REGISTER_CPU_UNSORTED_KERNELS(type, int32); \
REGISTER_CPU_UNSORTED_KERNELS(type, int64);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL);
+TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL);
#undef REGISTER_CPU_UNSORTED_KERNELS
#undef REGISTER_CPU_UNSORTED_KERNELS_ALL