author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-13 09:35:01 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-13 09:42:05 -0700
commit     5ae1c93473ae690d4a7b9389b1219179cb2504a3 (patch)
tree       bfafbd3138b0c56e2dea1dc23947b9742e241d04 /tensorflow/contrib/lite/kernels/internal
parent     88a7c5b98fc1ccb56134003ba3dc88a09385c0a7 (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 212826308
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h  688
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h                      42
2 files changed, 473 insertions, 257 deletions
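
The conversion follows one pattern throughout: Dims<4> describes a tensor with innermost-first sizes plus explicit strides, while RuntimeShape lists dimensions in NHWC order, so every Offset() call site below reverses its argument order (compare Offset(input_dims, 0, iw_start, ih_start, b) with Offset(input_shape, b, ih_start, iw_start, 0)). A minimal sketch of the two conventions, using simplified stand-ins for the real types in types.h:

    #include <cassert>

    struct Dims4 {              // stand-in for Dims<4>: sizes are innermost-first
      int sizes[4];             // {depth, width, height, batch}
      int strides[4];
    };

    int Offset(const Dims4& dims, int c, int w, int h, int b) {
      return c * dims.strides[0] + w * dims.strides[1] +
             h * dims.strides[2] + b * dims.strides[3];
    }

    struct Shape4 {             // stand-in for RuntimeShape: dims are NHWC
      int dims[4];              // {batch, height, width, depth}
    };

    int Offset(const Shape4& shape, int b, int h, int w, int c) {
      return ((b * shape.dims[1] + h) * shape.dims[2] + w) * shape.dims[3] + c;
    }

    int main() {
      // The same packed 2x4x6x8 NHWC tensor described both ways.
      Dims4 dims = {{8, 6, 4, 2}, {1, 8, 48, 192}};
      Shape4 shape = {{2, 4, 6, 8}};
      // Same element, arguments reversed: (c, w, h, b) vs. (b, h, w, c).
      assert(Offset(dims, 3, 5, 2, 1) == Offset(shape, 1, 2, 5, 3));
      return 0;
    }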
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 2c8e8f90e3..baed8f4993 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -260,16 +260,16 @@ inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
return true;
}
-inline void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims,
- float output_activation_min,
- float output_activation_max) {
+inline void AddBiasAndEvalActivationFunction(float output_activation_min,
+ float output_activation_max,
+ const RuntimeShape& bias_shape,
+ const float* bias_data,
+ const RuntimeShape& array_shape,
+ float* array_data) {
#ifdef USE_NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
float* array_ptr = array_data;
float* array_end_ptr = array_ptr + array_size;
@@ -319,8 +319,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
}
#else // not NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
for (int array_offset = 0; array_offset < array_size;
array_offset += bias_size) {
@@ -333,6 +333,19 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
#endif
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(array_dims), array_data);
+}
+
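
The new overload moves the activation bounds to the front and orders each tensor as shape-then-data, inputs before outputs. A hypothetical call site as a sketch (bias_data and array_data are assumed to point at valid float buffers; the bias flat size must divide the array flat size, per the DCHECK above):

    const RuntimeShape bias_shape({1, 1, 1, 64});
    const RuntimeShape array_shape({1, 28, 28, 64});
    AddBiasAndEvalActivationFunction(/*output_activation_min=*/0.0f,
                                     /*output_activation_max=*/6.0f,  // ReLU6
                                     bias_shape, bias_data,
                                     array_shape, array_data);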
// Note: This is to be converted to RuntimeShapes along with Conv.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
@@ -1672,12 +1685,16 @@ inline void ShuffledFullyConnected(
}
template <typename T>
-inline void ExtractPatchIntoBufferColumn(
- const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int in_width, int in_height, int in_depth, int single_buffer_length,
- int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
+inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
+ int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height,
+ int pad_width, int pad_height,
+ int in_width, int in_height,
+ int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data,
+ T* conv_buffer_data, uint8 zero_byte) {
gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
// This chunk of code reshapes all the inputs corresponding to
// output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
const int kwidth_times_indepth = kwidth * in_depth;
@@ -1699,7 +1716,7 @@ inline void ExtractPatchIntoBufferColumn(
const int output_row_offset = (buffer_id * single_buffer_length);
int out_offset =
output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
- int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
+ int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
// Express all of the calculations as padding around the input patch.
const int top_padding = h_offset;
@@ -1713,7 +1730,7 @@ inline void ExtractPatchIntoBufferColumn(
// patch that are off the edge of the input image.
if (top_padding > 0) {
const int top_row_elements = (top_padding * kwidth * in_depth);
- memset(conv_buffer_data + output_row_offset, byte_zero,
+ memset(conv_buffer_data + output_row_offset, zero_byte,
(top_row_elements * sizeof(T)));
}
@@ -1730,14 +1747,14 @@ inline void ExtractPatchIntoBufferColumn(
for (int ih = ih_start; ih < ih_end; ++ih) {
if (left_padding > 0) {
const int left_start = (out_offset - (left_padding * in_depth));
- memset(conv_buffer_data + left_start, byte_zero,
+ memset(conv_buffer_data + left_start, zero_byte,
(left_padding * in_depth * sizeof(T)));
}
memcpy(conv_buffer_data + out_offset, in_data + in_offset,
single_row_num * sizeof(T));
if (right_padding > 0) {
const int right_start = (out_offset + single_row_num);
- memset(conv_buffer_data + right_start, byte_zero,
+ memset(conv_buffer_data + right_start, zero_byte,
(right_padding * in_depth * sizeof(T)));
}
out_offset += kwidth_times_indepth;
@@ -1752,61 +1769,64 @@ inline void ExtractPatchIntoBufferColumn(
const int bottom_start =
output_row_offset +
((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
- memset(conv_buffer_data + bottom_start, byte_zero,
+ memset(conv_buffer_data + bottom_start, zero_byte,
(bottom_row_elements * sizeof(T)));
}
}
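
In short, each extracted patch is assembled as zero padding around a core of copied input rows; a sketch of the layout written for one output pixel:

    // One im2col column (kheight x kwidth x in_depth elements):
    //   [ top_padding rows, memset to zero_byte                    ]
    //   [ per in-bounds row: left pad | memcpy'd input | right pad ]
    //   [ bottom_padding rows, memset to zero_byte                 ]
    // Only the in-bounds region costs a memcpy; all padding is bulk memset.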
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T>
-void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 byte_zero,
- T* im2col_data) {
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
+ ExtractPatchIntoBufferColumn(
+ DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
+ stride_height, pad_width, pad_height, in_width, in_height, in_depth,
+ single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
+}
+
+template <typename T>
+void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
// For dilated convolution, the input pixels are not contiguous, therefore we
// can't use the same optimizations as Im2Col(). Though note this code would
// work fine for the non-dilated case too (though likely a bit slower).
gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
TFLITE_DCHECK(im2col_data);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- MatchingArraySize(output_dims, 0, filter_dims, 3);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 0);
// Construct the MxN sized im2col matrix.
// The rows, M, are sub-ordered B x H x W
- Dims<4> row_dims;
- row_dims.sizes[0] = output_width;
- row_dims.sizes[1] = output_height;
- row_dims.sizes[2] = batches;
- row_dims.sizes[3] = 1;
- ComputeStrides(&row_dims);
-
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
// The columns, N, are sub-ordered Kh x Kw x Din
- Dims<4> col_dims;
- col_dims.sizes[0] = input_depth;
- col_dims.sizes[1] = filter_width;
- col_dims.sizes[2] = filter_height;
- col_dims.sizes[3] = 1;
- ComputeStrides(&col_dims);
-
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
// Use dimensions M and N to construct dims for indexing directly into im2col
- Dims<4> im2col_dims;
- im2col_dims.sizes[0] = FlatSize(col_dims);
- im2col_dims.sizes[1] = FlatSize(row_dims);
- im2col_dims.sizes[2] = 1;
- im2col_dims.sizes[3] = 1;
- ComputeStrides(&im2col_dims);
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
// Loop through the output rows (B x H x W)
for (int batch = 0; batch < batches; ++batch) {
@@ -1814,7 +1834,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
for (int out_x = 0; out_x < output_width; ++out_x) {
// Each im2col row is an output pixel. Arrange the input data in this
// row in an order we can conveniently multiply with the filter data.
- int row_offset = Offset(row_dims, out_x, out_y, batch, 0);
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
const int in_x_origin = (out_x * stride_width) - pad_width;
const int in_y_origin = (out_y * stride_height) - pad_height;
// Loop through all the pixels of the filter (Kh x Kw)
@@ -1825,25 +1845,25 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
// Loop through all the filter pixels in this row.
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int in_x = in_x_origin + dilation_width_factor * filter_x;
- int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0);
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
T* dst = im2col_data +
- Offset(im2col_dims, col_offset, row_offset, 0, 0);
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
if ((in_x >= 0) && (in_x < input_width)) {
// Filter pixel is within the input, copy the input data.
T const* src =
- input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
memcpy(dst, src, input_depth * sizeof(T));
} else {
// Filter pixel is outside the input, zero it out.
- memset(dst, byte_zero, input_depth * sizeof(T));
+ memset(dst, zero_byte, input_depth * sizeof(T));
}
}
} else {
// Filter row is outside the input, zero out the entire filter row.
- int col_offset = Offset(col_dims, 0, 0, filter_y, 0);
- T* dst =
- im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0);
- memset(dst, byte_zero, filter_width * input_depth * sizeof(T));
+ int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
+ T* dst = im2col_data +
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
+ memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
}
}
}
@@ -1851,21 +1871,49 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
}
}
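
The three shapes above encode the im2col matrix directly: M = B*H*W rows and N = Kh*Kw*Din columns, and the {1, 1, M, N} shape makes Offset() degenerate to plain row-major indexing. A small worked instance with hypothetical sizes:

    const RuntimeShape row_shape({1, 2, 3, 4});   // 1 x B x H x W, M = 24
    const RuntimeShape col_shape({1, 2, 2, 5});   // 1 x Kh x Kw x Din, N = 20
    const RuntimeShape im2col_shape(
        {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});  // 1 x 1 x 24 x 20
    // Offset(im2col_shape, 0, 0, row_offset, col_offset)
    //     == row_offset * 20 + col_offset for these sizes.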
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int kheight,
- int kwidth, uint8 byte_zero, T* output_data,
- const Dims<4>& output_dims) {
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+
+ DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+template <typename T>
+void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Im2col");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
int buffer_id = 0;
// Loop over the output nodes.
@@ -1873,93 +1921,155 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
ExtractPatchIntoBufferColumn(
- input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
+ input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
pad_width, pad_height, input_width, input_height, input_depth,
- output_depth, buffer_id, input_data, output_data, byte_zero);
+ output_depth, buffer_id, input_data, output_data, zero_byte);
++buffer_id;
}
}
}
}
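
Note that buffer_id advances once per output position (b, h, w) and each patch occupies single_buffer_length contiguous elements, so the result behaves as a (B*H*W) x single_buffer_length row-major matrix. A sketch with hypothetical sizes (for Conv's use of Im2col, output_depth here equals kheight * kwidth * input_depth):

    const int kheight = 2, kwidth = 2, input_depth = 8;
    const int single_buffer_length = kheight * kwidth * input_depth;  // 32
    // The patch for output pixel (b, h, w) starts at
    //   buffer_id * single_buffer_length,
    // with buffer_id = (b * output_height + h) * output_width + w.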
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 zero_byte, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+
+ Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
+ input_data, DimsToShape(output_dims), output_data);
+}
+
// legacy, for compatibility with old checked-in code
template <typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
+ kwidth, zero_byte, output_data, output_dims);
}
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
(void)im2col_data;
- (void)im2col_dims;
+ (void)im2col_shape;
gemmlowp::ScopedProfilingLabel label("Conv");
// NB: static_cast<float>(0x00000000h) == 0.0f
const uint8 float_zero_byte = 0x00;
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_dilated_im2col =
dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_dilated_im2col) {
- DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, dilation_width_factor, dilation_height_factor,
- pad_width, pad_height, output_dims, float_zero_byte,
- im2col_data);
+ DilatedIm2col(params, float_zero_byte, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, float_zero_byte,
- im2col_data, im2col_dims);
+ Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
// TODO(aselle): We need to make sure to not send im2col if it is not
// needed.
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
const auto im2col_matrix_map =
- MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
+ MapAsMatrixWithLastDimAsRows(gemm_input_data, *gemm_input_shape);
const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
}
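
The remapped matrix views keep the GEMM dimensions consistent under the NHWC layout; the bookkeeping, as a sketch (assuming the MapAsMatrix* helpers behave as their names suggest):

    // filter_shape = {out_depth, Kh, Kw, Din}
    //   MapAsMatrixWithFirstDimAsCols -> (Kh*Kw*Din) x out_depth,
    //   transposed inside the Gemm call to out_depth x (Kh*Kw*Din).
    // gemm input (input or im2col), last dimension Kh*Kw*Din
    //   MapAsMatrixWithLastDimAsRows -> (Kh*Kw*Din) x (B*H*W).
    // output_shape = {B, H, W, out_depth}
    //   MapAsMatrixWithLastDimAsRows -> out_depth x (B*H*W).
    // So Gemm(filter^T, gemm_input) lands exactly in the output map.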
-inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
- const int8_t* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* scaling_factors_ptr,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- int8_t* im2col_data, const Dims<4>& im2col_dims) {
- const int batch_size = input_dims.sizes[3];
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
+ const RuntimeShape& input_shape,
+ const int8_t* input_data,
+ const RuntimeShape& filter_shape,
+ const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const float* bias_data,
+ const RuntimeShape& output_shape, float* output_data,
+ const RuntimeShape& im2col_shape, int8_t* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
+
+ const int batch_size = input_shape.Dims(0);
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const int8_t* gemm_input_data = nullptr;
int num_input;
@@ -1970,25 +2080,22 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK(im2col_data);
// symmetric quantization assumes zero point of 0.
const int input_zero_point = 0;
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, input_zero_point,
- im2col_data, im2col_dims);
+
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] *
- im2col_dims.sizes[2] * im2col_dims.sizes[3];
+ num_input = im2col_shape.FlatSize();
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- num_input = input_dims.sizes[0] * input_dims.sizes[1] *
- input_dims.sizes[2] * input_dims.sizes[3];
+ num_input = input_shape.FlatSize();
}
// Flatten 4D matrices into 2D matrices for matrix multiplication.
// Flatten so that each filter has its own row.
- const int filter_rows = filter_dims.sizes[3];
- const int filter_cols =
- filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+ const int filter_rows = filter_shape.Dims(0);
+ const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
// In MatrixBatchVectorMultiplyAccumulate, each output value is the
// dot product of one row of the first matrix with one row of the second
@@ -1998,15 +2105,14 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
const int gemm_input_cols = filter_cols;
const int gemm_input_rows = num_input / gemm_input_cols;
- const int output_cols = output_dims.sizes[0];
- const int output_rows =
- output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ const int output_cols = output_shape.Dims(3);
+ const int output_rows = FlatSizeSkipDim(output_shape, 3);
TFLITE_DCHECK_EQ(output_cols, filter_rows);
TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(3), output_cols);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(2), 1);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(1), 1);
+ TFLITE_DCHECK_EQ(bias_shape.Dims(0), 1);
// MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
// input matrix has its own scale factor. This code duplicates the scale
@@ -2023,11 +2129,39 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
/*result_stride=*/1);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
}
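
FlatSizeSkipDim(shape, d) is the flat size with dimension d left out, i.e. the product of the remaining dimensions. A quick sketch of the sizes involved:

    // For a filter of shape {64, 3, 3, 8} = (out_depth, Kh, Kw, Din):
    //   filter_rows = filter_shape.Dims(0);              // 64
    //   filter_cols = FlatSizeSkipDim(filter_shape, 0);  // 3*3*8 = 72
    // Each of the 64 filters becomes one GEMM row of 72 weights, and
    // num_input / filter_cols recovers gemm_input_rows, which the DCHECKs
    // above require to equal output_rows.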
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
+ input_data, DimsToShape(filter_dims), filter_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
@@ -2045,6 +2179,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2061,6 +2196,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2074,27 +2210,33 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
output_dims, im2col_data, im2col_dims);
}
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, const RuntimeShape& im2col_shape,
+ uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
-
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
const uint8* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_dilated_im2col =
dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
@@ -2104,53 +2246,47 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
TFLITE_DCHECK_LE(input_zero_point, 255);
- DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, dilation_width_factor, dilation_height_factor,
- pad_width, pad_height, output_dims, input_zero_point,
- im2col_data);
+ DilatedIm2col(params, input_zero_point, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
TFLITE_DCHECK_LE(input_zero_point, 255);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, input_zero_point,
- im2col_data, im2col_dims);
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
- const int gemm_input_rows = gemm_input_dims->sizes[0];
+ const int gemm_input_rows = gemm_input_shape->Dims(3);
// Using FlatSizeSkipDim causes a segfault in some contexts (see b/79927784).
// The root cause has not yet been identified though. Same applies below for
// the other calls commented out. This is a partial rollback of cl/196819423.
- // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
- const int gemm_input_cols = gemm_input_dims->sizes[1] *
- gemm_input_dims->sizes[2] *
- gemm_input_dims->sizes[3];
- const int filter_rows = filter_dims.sizes[3];
+ // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
+ const int gemm_input_cols = gemm_input_shape->Dims(0) *
+ gemm_input_shape->Dims(1) *
+ gemm_input_shape->Dims(2);
+ const int filter_rows = filter_shape.Dims(0);
// See b/79927784.
- // const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
const int filter_cols =
- filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
- const int output_rows = output_dims.sizes[0];
+ filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
+ const int output_rows = output_shape.Dims(3);
// See b/79927784.
- // const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ // const int output_cols = FlatSizeSkipDim(output_shape, 3);
const int output_cols =
- output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
TFLITE_DCHECK_EQ(output_rows, filter_rows);
TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
filter_data, filter_rows, filter_cols);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
@@ -2166,6 +2302,43 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
input_offset, output_pipeline);
}
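
The hand-written products are FlatSizeSkipDim spelled out, kept manual only because of b/79927784. Note the skipped axis flips from 0 to 3 relative to the Dims<4> code, since RuntimeShape dimension 3 (depth) is what Dims<4> indexed as dimension 0. The equivalences, as a sketch:

    // gemm_input_cols == FlatSizeSkipDim(*gemm_input_shape, 3)
    //                 == Dims(0) * Dims(1) * Dims(2)   // B * H * W
    // filter_cols     == FlatSizeSkipDim(filter_shape, 0)
    //                 == Dims(1) * Dims(2) * Dims(3)   // Kh * Kw * Din
    // output_cols     == FlatSizeSkipDim(output_shape, 3)
    //                 == Dims(0) * Dims(1) * Dims(2)   // B * H * W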
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
@@ -2184,6 +2357,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2213,6 +2387,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2236,13 +2411,14 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac, typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
+ kwidth, zero_byte, output_data, output_dims);
}
// legacy, for compatibility with old checked-in code
@@ -2266,6 +2442,7 @@ void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
output_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
@@ -5832,58 +6009,45 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
template <typename T>
-void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 zero_byte,
- T* im2col_data) {
+void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK(im2col_data);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 0);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 3); // output_depth
// Construct the MxN sized im2col matrix.
// The rows, M, are sub-ordered B x H x W
- Dims<4> row_dims;
- row_dims.sizes[0] = output_width;
- row_dims.sizes[1] = output_height;
- row_dims.sizes[2] = batches;
- row_dims.sizes[3] = 1;
- ComputeStrides(&row_dims);
-
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
// The columns, N, are sub-ordered Kh x Kw x Din
- Dims<4> col_dims;
- col_dims.sizes[0] = input_depth;
- col_dims.sizes[1] = filter_width;
- col_dims.sizes[2] = filter_height;
- col_dims.sizes[3] = 1;
- ComputeStrides(&col_dims);
-
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
// Use dimensions M and N to construct dims for indexing directly into im2col
- Dims<4> im2col_dims;
- im2col_dims.sizes[0] = FlatSize(col_dims);
- im2col_dims.sizes[1] = FlatSize(row_dims);
- im2col_dims.sizes[2] = 1;
- im2col_dims.sizes[3] = 1;
- ComputeStrides(&im2col_dims);
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
// Build the im2col matrix by looping through all the input pixels,
// computing their influence on the output, rather than looping through all
// the output pixels. We therefore must initialize the im2col array to zero.
// This is potentially inefficient because we subsequently overwrite bytes
// set here. However, in practice memset is very fast and its cost is negligible.
- memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T));
+ memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
// Loop through the output batches
for (int batch = 0; batch < batches; ++batch) {
@@ -5903,11 +6067,11 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
if ((out_x >= 0) && (out_x < output_width)) {
// Copy the input elements of this pixel
T const* src =
- input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
T* dst = im2col_data +
- Offset(im2col_dims,
- Offset(col_dims, 0, filter_x, filter_y, 0),
- Offset(row_dims, out_x, out_y, batch, 0), 0, 0);
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
memcpy(dst, src, input_depth * sizeof(T));
}
}
@@ -5918,31 +6082,71 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
}
}
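
The nested-offset rewrite in this hunk is index-preserving. With im2col_shape = {1, 1, M, N} and N = col_shape.FlatSize(), both forms reduce to the same row-major index:

    // New: Offset(im2col_shape, 0, 0, row_offset, col_offset)
    //        == row_offset * N + col_offset
    // Old: Offset(im2col_dims, col_offset, row_offset, 0, 0)
    //        == col_offset * 1 + row_offset * N
    // (im2col_dims had stride 1 on the column axis and stride N on the row
    //  axis, so the two agree element for element.)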
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+inline void TransposeConv(
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeConv");
// Note we could use transposed weights with forward conv for unstrided
// cases. But we are already getting good performance with this code as-is.
TFLITE_DCHECK(im2col_data);
- TransposeIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, pad_width, pad_height, output_dims, 0,
- im2col_data);
+ TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+ output_shape, im2col_data);
const auto im2col_matrix_map =
- MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims);
+ MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
} // namespace optimized_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index c4c7cf3842..023707d466 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -26,8 +26,8 @@ enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };
struct PaddingValues {
- int8 width;
- int8 height;
+ int16 width;
+ int16 height;
};
// This enumeration allows for non-default formats for the weights array
@@ -734,10 +734,10 @@ struct ConvParams {
PaddingType padding_type;
PaddingValues padding_values;
// TODO(starka): This was just "stride", so check that width+height is OK.
- int8 stride_width;
- int8 stride_height;
- int8 dilation_width_factor;
- int8 dilation_height_factor;
+ int16 stride_width;
+ int16 stride_height;
+ int16 dilation_width_factor;
+ int16 dilation_height_factor;
// uint8 inference params.
// TODO(b/65838351): Use smaller types if appropriate.
int32 input_offset;
@@ -745,8 +745,12 @@ struct ConvParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc., activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
};
struct DepthToSpaceParams {
@@ -756,8 +760,8 @@ struct DepthToSpaceParams {
struct DepthwiseParams {
PaddingType padding_type;
PaddingValues padding_values;
- int8 stride;
- int8 depth_multiplier;
+ int16 stride;
+ int16 depth_multiplier;
// uint8 inference params.
// TODO(b/65838351): Use smaller types if appropriate.
int32 input_offset;
@@ -765,8 +769,12 @@ struct DepthwiseParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc., activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
};
struct DequantizationParams {
@@ -787,13 +795,17 @@ struct FullyConnectedParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc., activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
FullyConnectedWeightsFormat weights_format;
};
struct GatherParams {
- int8 input_rank;
+ int16 input_rank;
int16 axis;
};
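
With these types.h changes, the param structs carry both activation ranges and the narrow int8 fields are widened to int16. A sketch of populating ConvParams under the new layout (hypothetical values, assuming the TFLite internal headers are included):

    tflite::ConvParams op_params;
    // Padding, stride, and dilation fields are now int16, so values above
    // 127 no longer overflow:
    op_params.padding_values.width = 128;
    op_params.padding_values.height = 128;
    op_params.stride_width = 2;
    op_params.stride_height = 2;
    op_params.dilation_width_factor = 1;
    op_params.dilation_height_factor = 1;
    // Activation bounds are now type-specific: float kernels read the
    // float_ pair, quantized kernels the quantized_ pair.
    op_params.float_activation_min = 0.0f;
    op_params.float_activation_max = 6.0f;
    op_params.quantized_activation_min = 0;
    op_params.quantized_activation_max = 255;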