diff options
author | 2018-09-13 09:35:01 -0700 | |
---|---|---|
committer | 2018-09-13 09:42:05 -0700 | |
commit | 5ae1c93473ae690d4a7b9389b1219179cb2504a3 (patch) | |
tree | bfafbd3138b0c56e2dea1dc23947b9742e241d04 /tensorflow/contrib/lite/kernels/internal | |
parent | 88a7c5b98fc1ccb56134003ba3dc88a09385c0a7 (diff) |
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 212826308
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 688 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/types.h | 42 |
2 files changed, 473 insertions, 257 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 2c8e8f90e3..baed8f4993 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -260,16 +260,16 @@ inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) { return true; } -inline void AddBiasAndEvalActivationFunction(const float* bias_data, - const Dims<4>& bias_dims, - float* array_data, - const Dims<4>& array_dims, - float output_activation_min, - float output_activation_max) { +inline void AddBiasAndEvalActivationFunction(float output_activation_min, + float output_activation_max, + const RuntimeShape& bias_shape, + const float* bias_data, + const RuntimeShape& array_shape, + float* array_data) { #ifdef USE_NEON gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); - const int bias_size = FlatSize(bias_dims); - const int array_size = FlatSize(array_dims); + const int bias_size = bias_shape.FlatSize(); + const int array_size = array_shape.FlatSize(); TFLITE_DCHECK_EQ((array_size % bias_size), 0); float* array_ptr = array_data; float* array_end_ptr = array_ptr + array_size; @@ -319,8 +319,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data, } #else // not NEON gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); - const int bias_size = FlatSize(bias_dims); - const int array_size = FlatSize(array_dims); + const int bias_size = bias_shape.FlatSize(); + const int array_size = array_shape.FlatSize(); TFLITE_DCHECK_EQ((array_size % bias_size), 0); for (int array_offset = 0; array_offset < array_size; array_offset += bias_size) { @@ -333,6 +333,19 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data, #endif } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +inline void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims, + float output_activation_min, + float output_activation_max) { + AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max, + DimsToShape(bias_dims), bias_data, + DimsToShape(array_dims), array_data); +} + // Note: This to be converted to RuntimeShapes along with Conv. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> @@ -1672,12 +1685,16 @@ inline void ShuffledFullyConnected( } template <typename T> -inline void ExtractPatchIntoBufferColumn( - const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth, - int stride_width, int stride_height, int pad_width, int pad_height, - int in_width, int in_height, int in_depth, int single_buffer_length, - int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) { +inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w, + int h, int b, int kheight, int kwidth, + int stride_width, int stride_height, + int pad_width, int pad_height, + int in_width, int in_height, + int in_depth, int single_buffer_length, + int buffer_id, const T* in_data, + T* conv_buffer_data, uint8 zero_byte) { gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn"); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); // This chunk of code reshapes all the inputs corresponding to // output (b, h, w) to a column vector in conv_buffer(:, buffer_id). const int kwidth_times_indepth = kwidth * in_depth; @@ -1699,7 +1716,7 @@ inline void ExtractPatchIntoBufferColumn( const int output_row_offset = (buffer_id * single_buffer_length); int out_offset = output_row_offset + (h_offset * kwidth + w_offset) * in_depth; - int in_offset = Offset(input_dims, 0, iw_start, ih_start, b); + int in_offset = Offset(input_shape, b, ih_start, iw_start, 0); // Express all of the calculations as padding around the input patch. const int top_padding = h_offset; @@ -1713,7 +1730,7 @@ inline void ExtractPatchIntoBufferColumn( // patch that are off the edge of the input image. if (top_padding > 0) { const int top_row_elements = (top_padding * kwidth * in_depth); - memset(conv_buffer_data + output_row_offset, byte_zero, + memset(conv_buffer_data + output_row_offset, zero_byte, (top_row_elements * sizeof(T))); } @@ -1730,14 +1747,14 @@ inline void ExtractPatchIntoBufferColumn( for (int ih = ih_start; ih < ih_end; ++ih) { if (left_padding > 0) { const int left_start = (out_offset - (left_padding * in_depth)); - memset(conv_buffer_data + left_start, byte_zero, + memset(conv_buffer_data + left_start, zero_byte, (left_padding * in_depth * sizeof(T))); } memcpy(conv_buffer_data + out_offset, in_data + in_offset, single_row_num * sizeof(T)); if (right_padding > 0) { const int right_start = (out_offset + single_row_num); - memset(conv_buffer_data + right_start, byte_zero, + memset(conv_buffer_data + right_start, zero_byte, (right_padding * in_depth * sizeof(T))); } out_offset += kwidth_times_indepth; @@ -1752,61 +1769,64 @@ inline void ExtractPatchIntoBufferColumn( const int bottom_start = output_row_offset + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth); - memset(conv_buffer_data + bottom_start, byte_zero, + memset(conv_buffer_data + bottom_start, zero_byte, (bottom_row_elements * sizeof(T))); } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. template <typename T> -void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, - const Dims<4>& filter_dims, int stride_width, - int stride_height, int dilation_width_factor, - int dilation_height_factor, int pad_width, int pad_height, - const Dims<4>& output_dims, uint8 byte_zero, - T* im2col_data) { +inline void ExtractPatchIntoBufferColumn( + const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth, + int stride_width, int stride_height, int pad_width, int pad_height, + int in_width, int in_height, int in_depth, int single_buffer_length, + int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) { + ExtractPatchIntoBufferColumn( + DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width, + stride_height, pad_width, pad_height, in_width, in_height, in_depth, + single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte); +} + +template <typename T> +void DilatedIm2col(const ConvParams& params, uint8 zero_byte, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& filter_shape, + const RuntimeShape& output_shape, T* im2col_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + // For dilated convolution, the input pixels are not contiguous therefore we // can't use the same opitimizations as Im2Col(). Though note this code would // work fine for the non-dilated case too (though likely a bit slower). gemmlowp::ScopedProfilingLabel label("DilatedIm2col"); TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); TFLITE_DCHECK(im2col_data); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - MatchingArraySize(output_dims, 0, filter_dims, 3); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + MatchingDim(output_shape, 3, filter_shape, 0); // Construct the MxN sized im2col matrix. // The rows M, are sub-ordered B x H x W - Dims<4> row_dims; - row_dims.sizes[0] = output_width; - row_dims.sizes[1] = output_height; - row_dims.sizes[2] = batches; - row_dims.sizes[3] = 1; - ComputeStrides(&row_dims); - + const RuntimeShape row_shape({1, batches, output_height, output_width}); // The columns, N, are sub-ordered Kh x Kw x Din - Dims<4> col_dims; - col_dims.sizes[0] = input_depth; - col_dims.sizes[1] = filter_width; - col_dims.sizes[2] = filter_height; - col_dims.sizes[3] = 1; - ComputeStrides(&col_dims); - + const RuntimeShape col_shape({1, filter_height, filter_width, input_depth}); // Use dimensions M and N to construct dims for indexing directly into im2col - Dims<4> im2col_dims; - im2col_dims.sizes[0] = FlatSize(col_dims); - im2col_dims.sizes[1] = FlatSize(row_dims); - im2col_dims.sizes[2] = 1; - im2col_dims.sizes[3] = 1; - ComputeStrides(&im2col_dims); + const RuntimeShape im2col_shape( + {1, 1, row_shape.FlatSize(), col_shape.FlatSize()}); // Loop through the output rows (B x H x W) for (int batch = 0; batch < batches; ++batch) { @@ -1814,7 +1834,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, for (int out_x = 0; out_x < output_width; ++out_x) { // Each im2col row is an output pixel. Arrange the input data in this // row in an order we can conveniently multiply with the filter data. - int row_offset = Offset(row_dims, out_x, out_y, batch, 0); + int row_offset = Offset(row_shape, 0, batch, out_y, out_x); const int in_x_origin = (out_x * stride_width) - pad_width; const int in_y_origin = (out_y * stride_height) - pad_height; // Loop through all the pixels of the filter (Kh x Kw) @@ -1825,25 +1845,25 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, // Loop through all the filter pixels in this row. for (int filter_x = 0; filter_x < filter_width; ++filter_x) { const int in_x = in_x_origin + dilation_width_factor * filter_x; - int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0); + int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0); T* dst = im2col_data + - Offset(im2col_dims, col_offset, row_offset, 0, 0); + Offset(im2col_shape, 0, 0, row_offset, col_offset); if ((in_x >= 0) && (in_x < input_width)) { // Filter pixel is within the input, copy the input data. T const* src = - input_data + Offset(input_dims, 0, in_x, in_y, batch); + input_data + Offset(input_shape, batch, in_y, in_x, 0); memcpy(dst, src, input_depth * sizeof(T)); } else { // Filter pixel is outside the input, zero it out. - memset(dst, byte_zero, input_depth * sizeof(T)); + memset(dst, zero_byte, input_depth * sizeof(T)); } } } else { // Filter row is outside the input, zero out the entire filter row. - int col_offset = Offset(col_dims, 0, 0, filter_y, 0); - T* dst = - im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); - memset(dst, byte_zero, filter_width * input_depth * sizeof(T)); + int col_offset = Offset(col_shape, 0, filter_y, 0, 0); + T* dst = im2col_data + + Offset(im2col_shape, 0, 0, row_offset, col_offset); + memset(dst, zero_byte, filter_width * input_depth * sizeof(T)); } } } @@ -1851,21 +1871,49 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. template <typename T> -void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, - int stride_height, int pad_width, int pad_height, int kheight, - int kwidth, uint8 byte_zero, T* output_data, - const Dims<4>& output_dims) { +void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 zero_byte, + T* im2col_data) { + tflite::ConvParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + op_params.dilation_width_factor = dilation_width_factor; + op_params.dilation_height_factor = dilation_height_factor; + + DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data, + DimsToShape(filter_dims), DimsToShape(output_dims), + im2col_data); +} + +template <typename T> +void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { gemmlowp::ScopedProfilingLabel label("Im2col"); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); int buffer_id = 0; // Loop over the output nodes. @@ -1873,93 +1921,155 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, for (int h = 0; h < output_height; ++h) { for (int w = 0; w < output_width; ++w) { ExtractPatchIntoBufferColumn( - input_dims, w, h, b, kheight, kwidth, stride_width, stride_height, + input_shape, w, h, b, kheight, kwidth, stride_width, stride_height, pad_width, pad_height, input_width, input_height, input_depth, - output_depth, buffer_id, input_data, output_data, byte_zero); + output_depth, buffer_id, input_data, output_data, zero_byte); ++buffer_id; } } } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template <typename T> +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, + int stride_height, int pad_width, int pad_height, int kheight, + int kwidth, uint8 zero_byte, T* output_data, + const Dims<4>& output_dims) { + tflite::ConvParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + op_params.dilation_width_factor = 1; + op_params.dilation_height_factor = 1; + + Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims), + input_data, DimsToShape(output_dims), output_data); +} + // legacy, for compatibility with old checked-in code template <typename T> void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, int pad_width, int pad_height, int kheight, int kwidth, - uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + uint8 zero_byte, T* output_data, const Dims<4>& output_dims) { Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, - kwidth, byte_zero, output_data, output_dims); + kwidth, zero_byte, output_data, output_dims); } -inline void Conv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int dilation_width_factor, - int dilation_height_factor, int pad_width, int pad_height, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims, - float* im2col_data, const Dims<4>& im2col_dims) { +inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& filter_shape, + const float* filter_data, const RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& im2col_shape, + float* im2col_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + (void)im2col_data; - (void)im2col_dims; + (void)im2col_shape; gemmlowp::ScopedProfilingLabel label("Conv"); // NB: static_cast<float>(0x00000000h) == 0.0f const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; - const Dims<4>* gemm_input_dims = nullptr; - const int filter_width = ArraySize(filter_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); + const RuntimeShape* gemm_input_shape = nullptr; + const int filter_width = filter_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); const bool need_dilated_im2col = dilation_width_factor != 1 || dilation_height_factor != 1; const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; if (need_dilated_im2col) { - DilatedIm2col(input_data, input_dims, filter_dims, stride_width, - stride_height, dilation_width_factor, dilation_height_factor, - pad_width, pad_height, output_dims, float_zero_byte, - im2col_data); + DilatedIm2col(params, float_zero_byte, input_shape, input_data, + filter_shape, output_shape, im2col_data); gemm_input_data = im2col_data; - gemm_input_dims = &im2col_dims; + gemm_input_shape = &im2col_shape; } else if (need_im2col) { TFLITE_DCHECK(im2col_data); - Im2col(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_height, filter_width, float_zero_byte, - im2col_data, im2col_dims); + Im2col(params, filter_height, filter_width, float_zero_byte, input_shape, + input_data, im2col_shape, im2col_data); gemm_input_data = im2col_data; - gemm_input_dims = &im2col_dims; + gemm_input_shape = &im2col_shape; } else { // TODO(aselle): We need to make sure to not send im2col if it is not // needed. TFLITE_DCHECK(!im2col_data); gemm_input_data = input_data; - gemm_input_dims = &input_dims; + gemm_input_shape = &input_shape; } const auto im2col_matrix_map = - MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims); + MapAsMatrixWithLastDimAsRows(gemm_input_data, *gemm_input_shape); const auto filter_matrix_map = - MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape); auto output_matrix_map = - MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + MapAsMatrixWithLastDimAsRows(output_data, output_shape); Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); - AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, - output_dims, output_activation_min, - output_activation_max); + AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max, + bias_shape, bias_data, output_shape, + output_data); } -inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, - const int8_t* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float* scaling_factors_ptr, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims, - int8_t* im2col_data, const Dims<4>& im2col_dims) { - const int batch_size = input_dims.sizes[3]; - const int filter_width = ArraySize(filter_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + float* im2col_data, const Dims<4>& im2col_dims) { + tflite::ConvParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + op_params.dilation_width_factor = dilation_width_factor; + op_params.dilation_height_factor = dilation_height_factor; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims), + filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), + output_data, DimsToShape(im2col_dims), im2col_data); +} + +inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr, + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& filter_shape, + const int8_t* filter_data, + const RuntimeShape& bias_shape, const float* bias_data, + const RuntimeShape& output_shape, float* output_data, + const RuntimeShape& im2col_shape, int8_t* im2col_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4); + + const int batch_size = input_shape.Dims(0); + const int filter_width = filter_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); const int8_t* gemm_input_data = nullptr; int num_input; @@ -1970,25 +2080,22 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, TFLITE_DCHECK(im2col_data); // symmetric quantization assumes zero point of 0. const int input_zero_point = 0; - Im2col(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_height, filter_width, input_zero_point, - im2col_data, im2col_dims); + + Im2col(params, filter_height, filter_width, input_zero_point, input_shape, + input_data, im2col_shape, im2col_data); gemm_input_data = im2col_data; - num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] * - im2col_dims.sizes[2] * im2col_dims.sizes[3]; + num_input = im2col_shape.FlatSize(); } else { TFLITE_DCHECK(!im2col_data); gemm_input_data = input_data; - num_input = input_dims.sizes[0] * input_dims.sizes[1] * - input_dims.sizes[2] * input_dims.sizes[3]; + num_input = input_shape.FlatSize(); } // Flatten 4D matrices into 2D matrices for matrix multiplication. // Flatten so that each filter has its own row. - const int filter_rows = filter_dims.sizes[3]; - const int filter_cols = - filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int filter_rows = filter_shape.Dims(0); + const int filter_cols = FlatSizeSkipDim(filter_shape, 0); // In MatrixBatchVectorMultiplyAccumulate, each output value is the // dot product of one row of the first matrix with one row of the second @@ -1998,15 +2105,14 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, const int gemm_input_cols = filter_cols; const int gemm_input_rows = num_input / gemm_input_cols; - const int output_cols = output_dims.sizes[0]; - const int output_rows = - output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + const int output_cols = output_shape.Dims(3); + const int output_rows = FlatSizeSkipDim(output_shape, 3); TFLITE_DCHECK_EQ(output_cols, filter_rows); TFLITE_DCHECK_EQ(output_rows, gemm_input_rows); - TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols); - TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); - TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); - TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + TFLITE_DCHECK_EQ(bias_shape.Dims(3), output_cols); + TFLITE_DCHECK_EQ(bias_shape.Dims(2), 1); + TFLITE_DCHECK_EQ(bias_shape.Dims(1), 1); + TFLITE_DCHECK_EQ(bias_shape.Dims(0), 1); // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second // input matrix has its own scale factor. This code duplicates the scale @@ -2023,11 +2129,39 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data, /*result_stride=*/1); - AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, - output_dims, output_activation_min, - output_activation_max); + AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max, + bias_shape, bias_data, output_shape, + output_data); } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, + const int8_t* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* scaling_factors_ptr, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + int8_t* im2col_data, const Dims<4>& im2col_dims) { + tflite::ConvParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims), + input_data, DimsToShape(filter_dims), filter_data, + DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), + output_data, DimsToShape(im2col_dims), im2col_data); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. template <FusedActivationFunctionType Ac> void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, @@ -2045,6 +2179,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims, im2col_dims); } +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> void Conv(const float* input_data, const Dims<4>& input_dims, @@ -2061,6 +2196,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims, im2col_data, im2col_dims); } +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> void Conv(const float* input_data, const Dims<4>& input_dims, @@ -2074,27 +2210,33 @@ void Conv(const float* input_data, const Dims<4>& input_dims, output_dims, im2col_data, im2col_dims); } -inline void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int dilation_width_factor, - int dilation_height_factor, int pad_width, int pad_height, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims, - uint8* im2col_data, const Dims<4>& im2col_dims, - gemmlowp::GemmContext* gemm_context) { +inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& filter_shape, + const uint8* filter_data, const RuntimeShape& bias_shape, + const int32* bias_data, const RuntimeShape& output_shape, + uint8* output_data, const RuntimeShape& im2col_shape, + uint8* im2col_data, gemmlowp::GemmContext* gemm_context) { gemmlowp::ScopedProfilingLabel label("Conv/8bit"); - - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4); const uint8* gemm_input_data = nullptr; - const Dims<4>* gemm_input_dims = nullptr; - const int filter_width = ArraySize(filter_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); + const RuntimeShape* gemm_input_shape = nullptr; + const int filter_width = filter_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); const bool need_dilated_im2col = dilation_width_factor != 1 || dilation_height_factor != 1; const bool need_im2col = stride_width != 1 || stride_height != 1 || @@ -2104,53 +2246,47 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, const int input_zero_point = -input_offset; TFLITE_DCHECK_GE(input_zero_point, 0); TFLITE_DCHECK_LE(input_zero_point, 255); - DilatedIm2col(input_data, input_dims, filter_dims, stride_width, - stride_height, dilation_width_factor, dilation_height_factor, - pad_width, pad_height, output_dims, input_zero_point, - im2col_data); + DilatedIm2col(params, input_zero_point, input_shape, input_data, + filter_shape, output_shape, im2col_data); gemm_input_data = im2col_data; - gemm_input_dims = &im2col_dims; + gemm_input_shape = &im2col_shape; } else if (need_im2col) { TFLITE_DCHECK(im2col_data); const int input_zero_point = -input_offset; TFLITE_DCHECK_GE(input_zero_point, 0); TFLITE_DCHECK_LE(input_zero_point, 255); - Im2col(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_height, filter_width, input_zero_point, - im2col_data, im2col_dims); + Im2col(params, filter_height, filter_width, input_zero_point, input_shape, + input_data, im2col_shape, im2col_data); gemm_input_data = im2col_data; - gemm_input_dims = &im2col_dims; + gemm_input_shape = &im2col_shape; } else { TFLITE_DCHECK(!im2col_data); gemm_input_data = input_data; - gemm_input_dims = &input_dims; + gemm_input_shape = &input_shape; } - const int gemm_input_rows = gemm_input_dims->sizes[0]; + const int gemm_input_rows = gemm_input_shape->Dims(3); // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784). // The root cause has not yet been identified though. Same applies below for // the other calls commented out. This is a partial rollback of cl/196819423. - // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0); - const int gemm_input_cols = gemm_input_dims->sizes[1] * - gemm_input_dims->sizes[2] * - gemm_input_dims->sizes[3]; - const int filter_rows = filter_dims.sizes[3]; + // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3); + const int gemm_input_cols = gemm_input_shape->Dims(0) * + gemm_input_shape->Dims(1) * + gemm_input_shape->Dims(2); + const int filter_rows = filter_shape.Dims(0); // See b/79927784. - // const int filter_cols = FlatSizeSkipDim(filter_dims, 3); + // const int filter_cols = FlatSizeSkipDim(filter_shape, 0); const int filter_cols = - filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; - const int output_rows = output_dims.sizes[0]; + filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3); + const int output_rows = output_shape.Dims(3); // See b/79927784. - // const int output_cols = FlatSizeSkipDim(output_dims, 0); + // const int output_cols = FlatSizeSkipDim(output_shape, 3); const int output_cols = - output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2); TFLITE_DCHECK_EQ(output_rows, filter_rows); TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); - TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); - TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); - TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); - TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix( filter_data, filter_rows, filter_cols); gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix( @@ -2166,6 +2302,43 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, input_offset, output_pipeline); } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, + uint8* im2col_data, const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + tflite::ConvParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + op_params.dilation_width_factor = dilation_width_factor; + op_params.dilation_height_factor = dilation_height_factor; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + + Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims), + filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims), + output_data, DimsToShape(im2col_dims), im2col_data, gemm_context); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. inline void Conv(const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, @@ -2184,6 +2357,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, im2col_data, im2col_dims, gemm_context); } +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> inline void Conv(const uint8* input_data, const Dims<4>& input_dims, @@ -2213,6 +2387,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, im2col_data, im2col_dims, gemm_context); } +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> void Conv(const uint8* input_data, const Dims<4>& input_dims, @@ -2236,13 +2411,14 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, im2col_data, im2col_dims, gemm_context); } +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac, typename T> void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, int pad_width, int pad_height, int kheight, int kwidth, - uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + uint8 zero_byte, T* output_data, const Dims<4>& output_dims) { Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, - kwidth, byte_zero, output_data, output_dims); + kwidth, zero_byte, output_data, output_dims); } // legacy, for compatibility with old checked-in code @@ -2266,6 +2442,7 @@ void ConvAsGemm(const float* input_data, const Dims<4>& input_dims, output_dims); } +// TODO(b/80418076): Move to legacy ops file, update invocations. // legacy, for compatibility with old checked-in code template <FusedActivationFunctionType Ac> void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, @@ -5832,58 +6009,45 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data, } template <typename T> -void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, - const Dims<4>& filter_dims, int stride_width, - int stride_height, int pad_width, int pad_height, - const Dims<4>& output_dims, uint8 zero_byte, - T* im2col_data) { +void TransposeIm2col(const ConvParams& params, uint8 zero_byte, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& filter_shape, + const RuntimeShape& output_shape, T* im2col_data) { gemmlowp::ScopedProfilingLabel label("TransposeIm2col"); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); TFLITE_DCHECK(im2col_data); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 0); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + MatchingDim(output_shape, 3, filter_shape, 3); // output_depth // Construct the MxN sized im2col matrix. // The rows M, are sub-ordered B x H x W - Dims<4> row_dims; - row_dims.sizes[0] = output_width; - row_dims.sizes[1] = output_height; - row_dims.sizes[2] = batches; - row_dims.sizes[3] = 1; - ComputeStrides(&row_dims); - + const RuntimeShape row_shape({1, batches, output_height, output_width}); // The columns, N, are sub-ordered Kh x Kw x Din - Dims<4> col_dims; - col_dims.sizes[0] = input_depth; - col_dims.sizes[1] = filter_width; - col_dims.sizes[2] = filter_height; - col_dims.sizes[3] = 1; - ComputeStrides(&col_dims); - + const RuntimeShape col_shape({1, filter_height, filter_width, input_depth}); // Use dimensions M and N to construct dims for indexing directly into im2col - Dims<4> im2col_dims; - im2col_dims.sizes[0] = FlatSize(col_dims); - im2col_dims.sizes[1] = FlatSize(row_dims); - im2col_dims.sizes[2] = 1; - im2col_dims.sizes[3] = 1; - ComputeStrides(&im2col_dims); + const RuntimeShape im2col_shape( + {1, 1, row_shape.FlatSize(), col_shape.FlatSize()}); // Build the im2col matrix by looping through all the input pixels, // computing their influence on the output, rather than looping through all // the output pixels. We therefore must initialize the im2col array to zero. // This is potentially inefficient because we subsequently overwrite bytes // set here. However, in practice memset is very fast and costs negligible. - memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T)); + memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T)); // Loop through the output batches for (int batch = 0; batch < batches; ++batch) { @@ -5903,11 +6067,11 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, if ((out_x >= 0) && (out_x < output_width)) { // Copy the input elements of this pixel T const* src = - input_data + Offset(input_dims, 0, in_x, in_y, batch); + input_data + Offset(input_shape, batch, in_y, in_x, 0); + int row_offset = Offset(row_shape, 0, batch, out_y, out_x); + int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0); T* dst = im2col_data + - Offset(im2col_dims, - Offset(col_dims, 0, filter_x, filter_y, 0), - Offset(row_dims, out_x, out_y, batch, 0), 0, 0); + Offset(im2col_shape, 0, 0, row_offset, col_offset); memcpy(dst, src, input_depth * sizeof(T)); } } @@ -5918,31 +6082,71 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, } } -inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template <typename T> +void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 zero_byte, + T* im2col_data) { + tflite::ConvParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + + TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data, + DimsToShape(filter_dims), DimsToShape(output_dims), + im2col_data); +} + +inline void TransposeConv( + const ConvParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& filter_shape, + const float* filter_data, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) { gemmlowp::ScopedProfilingLabel label("TransposeConv"); // Note we could use transposed weights with forward conv for unstrided // cases. But we are already getting good performance with this code as-is. TFLITE_DCHECK(im2col_data); - TransposeIm2col(input_data, input_dims, filter_dims, stride_width, - stride_height, pad_width, pad_height, output_dims, 0, - im2col_data); + TransposeIm2col(params, 0, input_shape, input_data, filter_shape, + output_shape, im2col_data); const auto im2col_matrix_map = - MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims); + MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape); const auto filter_matrix_map = - MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape); auto output_matrix_map = - MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + MapAsMatrixWithLastDimAsRows(output_data, output_shape); Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + tflite::ConvParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + + TransposeConv(op_params, DimsToShape(input_dims), input_data, + DimsToShape(filter_dims), filter_data, DimsToShape(output_dims), + output_data, DimsToShape(im2col_dims), im2col_data); +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index c4c7cf3842..023707d466 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -26,8 +26,8 @@ enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; enum class PaddingType : uint8 { kNone, kSame, kValid }; struct PaddingValues { - int8 width; - int8 height; + int16 width; + int16 height; }; // This enumeration allows for non-default formats for the weights array @@ -734,10 +734,10 @@ struct ConvParams { PaddingType padding_type; PaddingValues padding_values; // TODO(starka): This was just "stride", so check that width+height is OK. - int8 stride_width; - int8 stride_height; - int8 dilation_width_factor; - int8 dilation_height_factor; + int16 stride_width; + int16 stride_height; + int16 dilation_width_factor; + int16 dilation_height_factor; // uint8 inference params. // TODO(b/65838351): Use smaller types if appropriate. int32 input_offset; @@ -745,8 +745,12 @@ struct ConvParams { int32 output_offset; int32 output_multiplier; int output_shift; - int32 output_activation_min; - int32 output_activation_max; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; }; struct DepthToSpaceParams { @@ -756,8 +760,8 @@ struct DepthToSpaceParams { struct DepthwiseParams { PaddingType padding_type; PaddingValues padding_values; - int8 stride; - int8 depth_multiplier; + int16 stride; + int16 depth_multiplier; // uint8 inference params. // TODO(b/65838351): Use smaller types if appropriate. int32 input_offset; @@ -765,8 +769,12 @@ struct DepthwiseParams { int32 output_offset; int32 output_multiplier; int output_shift; - int32 output_activation_min; - int32 output_activation_max; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; }; struct DequantizationParams { @@ -787,13 +795,17 @@ struct FullyConnectedParams { int32 output_offset; int32 output_multiplier; int output_shift; - int32 output_activation_min; - int32 output_activation_max; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; FullyConnectedWeightsFormat weights_format; }; struct GatherParams { - int8 input_rank; + int16 input_rank; int16 axis; }; |