diff options
author | 2018-08-21 11:54:42 -0700 | |
---|---|---|
committer | 2018-08-21 12:03:30 -0700 | |
commit | 81d90de884ad6005e57e0d7d333e8476659d00c2 (patch) | |
tree | 86944a1eb9122dc5758aa9ccb2a3d077c9788e71 /tensorflow/contrib/lite/kernels/internal | |
parent | 9158b1b83a0128fc41bfccd80fe26d8231fe958b (diff) |
Support reduce_min
PiperOrigin-RevId: 209634537
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 96 |
1 files changed, 56 insertions, 40 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 556049d8a6..2ebc6084be 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3896,15 +3896,18 @@ inline bool InitTensorDataForReduce(const int* dims, const int num_dims, return true; } -// Computes the sum of elements across dimensions given in axis. +// Computes the generic value (i.e., sum/max/min/prod) of elements across +// dimensions given in axis. It needs to pass in init_value and reducer. template <typename T> -inline bool Sum(const T* input_data, const int* input_dims, - const int input_num_dims, T* output_data, - const int* output_dims, const int output_num_dims, - const int* axis, const int num_axis_dimensions, bool keep_dims, - int* temp_index, int* resolved_axis) { +inline bool ReduceGeneric(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis, + T init_value, + T reducer(const T current, const T in)) { // Reset output data. - if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(0), + if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value, output_data)) { return false; } @@ -3916,9 +3919,25 @@ inline bool Sum(const T* input_data, const int* input_dims, return false; } - return ReduceSumImpl<T, T>(input_data, input_dims, output_dims, - input_num_dims, output_num_dims, resolved_axis, - num_resolved_axis, temp_index, output_data); + return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, reducer, output_data); +} + +// Computes the sum of elements across dimensions given in axis. +template <typename T> +inline bool Sum(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis) { + T init_value = static_cast<T>(0); + + auto reducer = [](const T current, const T in) -> T { return current + in; }; + return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data, + output_dims, output_num_dims, axis, + num_axis_dimensions, keep_dims, temp_index, + resolved_axis, init_value, reducer); } // Computes the max of elements across dimensions given in axis. @@ -3929,25 +3948,32 @@ inline bool ReduceMax(const T* input_data, const int* input_dims, const int* axis, const int64_t num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis) { T init_value = std::numeric_limits<T>::lowest(); - // Reset output data. - if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value, - output_data)) { - return false; - } - - // Resolve axis. - int num_resolved_axis = 0; - if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, - &num_resolved_axis)) { - return false; - } auto reducer = [](const T current, const T in) -> T { return (in > current) ? in : current; }; - return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims, - output_num_dims, resolved_axis, num_resolved_axis, - temp_index, reducer, output_data); + return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data, + output_dims, output_num_dims, axis, + num_axis_dimensions, keep_dims, temp_index, + resolved_axis, init_value, reducer); +} + +// Computes the min of elements across dimensions given in axis. +template <typename T> +inline bool ReduceMin(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis) { + T init_value = std::numeric_limits<T>::max(); + + auto reducer = [](const T current, const T in) -> T { + return (in < current) ? in : current; + }; + return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data, + output_dims, output_num_dims, axis, + num_axis_dimensions, keep_dims, temp_index, + resolved_axis, init_value, reducer); } // Computes the prod of elements across dimensions given in axis. @@ -3957,23 +3983,13 @@ inline bool ReduceProd(const T* input_data, const int* input_dims, const int* output_dims, const int output_num_dims, const int* axis, const int64_t num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis) { - // Reset output data. - if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(1), - output_data)) { - return false; - } - - // Resolve axis. - int num_resolved_axis = 0; - if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, - &num_resolved_axis)) { - return false; - } + T init_value = static_cast<T>(1); auto reducer = [](const T current, const T in) -> T { return in * current; }; - return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims, - output_num_dims, resolved_axis, num_resolved_axis, - temp_index, reducer, output_data); + return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data, + output_dims, output_num_dims, axis, + num_axis_dimensions, keep_dims, temp_index, + resolved_axis, init_value, reducer); } // Computes the mean of elements across dimensions given in axis. |