diff options
Diffstat (limited to 'tensorflow/core/kernels/segment_reduction_ops.h')
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops.h | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index 8ed990a1e0..ee09c213b7 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -26,6 +26,17 @@ namespace tensorflow { class OpKernelContext; namespace functor { +// BaseFunctor for definition of UnsorteSegmentReductionOp +// for usage without templates. +template <typename Device, typename T, typename Index> +struct UnsortedSegmentBaseFunctor{ + virtual ~UnsortedSegmentBaseFunctor(){} + virtual void operator()(OpKernelContext* ctx, const Device& d, + const Index output_rows, const TensorShape& segment_ids_shape, + typename TTypes<Index>::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes<T, 2>::Tensor output){}; +}; // Functor for UnsortedSegmentSumOp. // 'output_rows': the number of output segments (unique segment ids in @@ -37,7 +48,7 @@ namespace functor { // 'data': input data tensor. // 'output': output reshaped to {output_rows, output.size/output_rows} template <typename Device, typename T, typename Index> -struct UnsortedSegmentSumFunctor { +struct UnsortedSegmentSumFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> { void operator()(OpKernelContext* ctx, const Device& d, const Index output_rows, const TensorShape& segment_ids_shape, typename TTypes<Index>::ConstFlat segment_ids, @@ -45,6 +56,23 @@ struct UnsortedSegmentSumFunctor { typename TTypes<T, 2>::Tensor output); }; +// Functor for UnsortedSegmentMaxOp. +// 'output_rows': the number of output segments (unique segment ids in +// 'segment_ids'). +// 'segment_ids_shape': shape of 'segment_ids' tensor. +// 'segment_ids': unsorted map from input to output segment ids at which to +// perform segment sum operation. +// 'data_size': size of input data tensor. +// 'data': input data tensor. +// 'output': output reshaped to {output_rows, output.size/output_rows} +template <typename Device, typename T, typename Index> +struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> { + void operator()(OpKernelContext* ctx, const Device& d, + const Index output_rows, const TensorShape& segment_ids_shape, + typename TTypes<Index>::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes<T, 2>::Tensor output); +}; } // namespace functor } // namespace tensorflow |