aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 11:54:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 12:03:30 -0700
commit81d90de884ad6005e57e0d7d333e8476659d00c2 (patch)
tree86944a1eb9122dc5758aa9ccb2a3d077c9788e71 /tensorflow/contrib/lite/kernels/internal
parent9158b1b83a0128fc41bfccd80fe26d8231fe958b (diff)
Support reduce_min
PiperOrigin-RevId: 209634537
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h96
1 files changed, 56 insertions, 40 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 556049d8a6..2ebc6084be 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3896,15 +3896,18 @@ inline bool InitTensorDataForReduce(const int* dims, const int num_dims,
return true;
}
-// Computes the sum of elements across dimensions given in axis.
+// Computes the generic value (i.e., sum/max/min/prod) of elements across
+// dimensions given in axis. It needs to pass in init_value and reducer.
template <typename T>
-inline bool Sum(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis) {
+inline bool ReduceGeneric(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int64_t num_axis_dimensions,
+ bool keep_dims, int* temp_index, int* resolved_axis,
+ T init_value,
+ T reducer(const T current, const T in)) {
// Reset output data.
- if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(0),
+ if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
output_data)) {
return false;
}
@@ -3916,9 +3919,25 @@ inline bool Sum(const T* input_data, const int* input_dims,
return false;
}
- return ReduceSumImpl<T, T>(input_data, input_dims, output_dims,
- input_num_dims, output_num_dims, resolved_axis,
- num_resolved_axis, temp_index, output_data);
+ return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, reducer, output_data);
+}
+
+// Computes the sum of elements across dimensions given in axis.
+template <typename T>
+inline bool Sum(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis) {
+ T init_value = static_cast<T>(0);
+
+ auto reducer = [](const T current, const T in) -> T { return current + in; };
+ return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
+ output_dims, output_num_dims, axis,
+ num_axis_dimensions, keep_dims, temp_index,
+ resolved_axis, init_value, reducer);
}
// Computes the max of elements across dimensions given in axis.
@@ -3929,25 +3948,32 @@ inline bool ReduceMax(const T* input_data, const int* input_dims,
const int* axis, const int64_t num_axis_dimensions,
bool keep_dims, int* temp_index, int* resolved_axis) {
T init_value = std::numeric_limits<T>::lowest();
- // Reset output data.
- if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
- output_data)) {
- return false;
- }
-
- // Resolve axis.
- int num_resolved_axis = 0;
- if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
- &num_resolved_axis)) {
- return false;
- }
auto reducer = [](const T current, const T in) -> T {
return (in > current) ? in : current;
};
- return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
- output_num_dims, resolved_axis, num_resolved_axis,
- temp_index, reducer, output_data);
+ return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
+ output_dims, output_num_dims, axis,
+ num_axis_dimensions, keep_dims, temp_index,
+ resolved_axis, init_value, reducer);
+}
+
+// Computes the min of elements across dimensions given in axis.
+template <typename T>
+inline bool ReduceMin(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int64_t num_axis_dimensions,
+ bool keep_dims, int* temp_index, int* resolved_axis) {
+ T init_value = std::numeric_limits<T>::max();
+
+ auto reducer = [](const T current, const T in) -> T {
+ return (in < current) ? in : current;
+ };
+ return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
+ output_dims, output_num_dims, axis,
+ num_axis_dimensions, keep_dims, temp_index,
+ resolved_axis, init_value, reducer);
}
// Computes the prod of elements across dimensions given in axis.
@@ -3957,23 +3983,13 @@ inline bool ReduceProd(const T* input_data, const int* input_dims,
const int* output_dims, const int output_num_dims,
const int* axis, const int64_t num_axis_dimensions,
bool keep_dims, int* temp_index, int* resolved_axis) {
- // Reset output data.
- if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast<T>(1),
- output_data)) {
- return false;
- }
-
- // Resolve axis.
- int num_resolved_axis = 0;
- if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
- &num_resolved_axis)) {
- return false;
- }
+ T init_value = static_cast<T>(1);
auto reducer = [](const T current, const T in) -> T { return in * current; };
- return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
- output_num_dims, resolved_axis, num_resolved_axis,
- temp_index, reducer, output_data);
+ return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
+ output_dims, output_num_dims, axis,
+ num_axis_dimensions, keep_dims, temp_index,
+ resolved_axis, init_value, reducer);
}
// Computes the mean of elements across dimensions given in axis.