diff options
author | 2018-08-24 14:15:10 -0700 | |
---|---|---|
committer | 2018-08-24 14:23:44 -0700 | |
commit | 2bfd7f4ac7f627cd63c2e723fbcfd74e2daaee4b (patch) | |
tree | ebc3949eb3b1c86b9e252247f48cd3444f9059bc /tensorflow/contrib/lite/kernels/internal | |
parent | 0b17e5d00b11ee84ec9454e3913d0605b57be4ab (diff) |
Quantize mean operator for uint8.
PiperOrigin-RevId: 210154945
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 3492a6c2f9..ff77f61191 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4224,6 +4224,70 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, } } +// Computes the mean of elements across dimensions given in axis. +// It does so in two stages, first calculates the sum of elements along the axis +// then divides it by the number of element in axis for quantized values. +template <typename T, typename U> +inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale, + const int* input_dims, const int input_num_dims, + T* output_data, int32 output_zero_point, float output_scale, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis, U* temp_sum) { + // Reset output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = static_cast<size_t>(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits<size_t>::max() / current) { + return false; + } + num_outputs *= current; + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = T(); + temp_sum[idx] = U(); + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, temp_sum)) { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + U num_elements_in_axis = 1; + for (int idx = 0; idx < num_resolved_axis; ++idx) { + size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]); + // Overflow prevention. + if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) { + return false; + } + num_elements_in_axis *= current; + } + + if (num_elements_in_axis > 0) { + const float scale = input_scale / output_scale; + const float bias = -input_zero_point * scale; + for (size_t idx = 0; idx < num_outputs; ++idx) { + float float_mean = static_cast<float>(temp_sum[idx]) / + static_cast<float>(num_elements_in_axis); + + // Convert to float value. + output_data[idx] = + static_cast<T>(round(float_mean * scale + bias)) + output_zero_point; + } + } + return true; +} + template <typename T> void Minimum(const RuntimeShape& input1_shape, const T* input1_data, const T* input2_data, const RuntimeShape& output_shape, |