aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/segment_reduction_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops.h')
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h30
1 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 8ed990a1e0..ee09c213b7 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -26,6 +26,17 @@ namespace tensorflow {
class OpKernelContext;
namespace functor {
+// BaseFunctor for definition of UnsorteSegmentReductionOp
+// for usage without templates.
+template <typename Device, typename T, typename Index>
+struct UnsortedSegmentBaseFunctor{
+ virtual ~UnsortedSegmentBaseFunctor(){}
+ virtual void operator()(OpKernelContext* ctx, const Device& d,
+ const Index output_rows, const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output){};
+};
// Functor for UnsortedSegmentSumOp.
// 'output_rows': the number of output segments (unique segment ids in
@@ -37,7 +48,7 @@ namespace functor {
// 'data': input data tensor.
// 'output': output reshaped to {output_rows, output.size/output_rows}
template <typename Device, typename T, typename Index>
-struct UnsortedSegmentSumFunctor {
+struct UnsortedSegmentSumFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
void operator()(OpKernelContext* ctx, const Device& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
@@ -45,6 +56,23 @@ struct UnsortedSegmentSumFunctor {
typename TTypes<T, 2>::Tensor output);
};
+// Functor for UnsortedSegmentMaxOp.
+// 'output_rows': the number of output segments (unique segment ids in
+// 'segment_ids').
+// 'segment_ids_shape': shape of 'segment_ids' tensor.
+// 'segment_ids': unsorted map from input to output segment ids at which to
+// perform segment sum operation.
+// 'data_size': size of input data tensor.
+// 'data': input data tensor.
+// 'output': output reshaped to {output_rows, output.size/output_rows}
+template <typename Device, typename T, typename Index>
+struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
+ void operator()(OpKernelContext* ctx, const Device& d,
+ const Index output_rows, const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output);
+};
} // namespace functor
} // namespace tensorflow