diff options
| author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 14:04:06 -0700 |
| --- | --- | --- |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 14:18:07 -0700 |
| commit | cc83067469bc30bba55932c587f31ef68f15792f (patch) | |
| tree | fe201c147a72a751a3ba050b6687c8b41b14b42f /tensorflow/contrib/lite/kernels | |
| parent | 2fb9377a5ec610b8eff853fd1d2d53eabf711eda (diff) | |
Migrate a few conv kernels to use new kernel signatures.
PiperOrigin-RevId: 214831837
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
3 files changed, 100 insertions, 84 deletions
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 101b4fc961..dbcadbee14 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -86,6 +86,18 @@ struct OpData { bool run_multithreaded_kernel; }; +inline PaddingType RuntimePaddingType(TfLitePadding padding) { + switch (padding) { + case TfLitePadding::kTfLitePaddingSame: + return PaddingType::kSame; + case TfLitePadding::kTfLitePaddingValid: + return PaddingType::kValid; + case TfLitePadding::kTfLitePaddingUnknown: + default: + return PaddingType::kNone; + } +} + void* Init(TfLiteContext* context, const char* buffer, size_t length) { // This is a builtin op, so we don't use the contents in 'buffer', if any. // Instead, we allocate a new object to use as scratch space for im2col, and @@ -487,18 +499,18 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, } else { effective_kernel_type = kernel_type; } + ConvParams op_params; + op_params.padding_type = RuntimePaddingType(params->padding); + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; switch (effective_kernel_type) { case kReference: { - ConvParams op_params; - op_params.padding_type = PaddingType::kSame; - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - 
op_params.float_activation_min = output_activation_min; - op_params.float_activation_max = output_activation_max; reference_ops::Conv(op_params, GetTensorShape(input), GetTensorData<float>(input), GetTensorShape(filter), GetTensorData<float>(filter), GetTensorShape(bias), @@ -508,16 +520,6 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, break; } case kGenericOptimized: { - ConvParams op_params; - op_params.padding_type = PaddingType::kSame; - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.float_activation_min = output_activation_min; - op_params.float_activation_max = output_activation_max; optimized_ops::Conv(op_params, GetTensorShape(input), GetTensorData<float>(input), GetTensorShape(filter), GetTensorData<float>(filter), GetTensorShape(bias), @@ -534,25 +536,21 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, filter_data = GetTensorData<float>(filter); } multithreaded_ops::Conv( - *eigen_support::GetThreadPoolDevice(context), - GetTensorData<float>(input), GetTensorDims(input), filter_data, - GetTensorDims(filter), GetTensorData<float>(bias), - GetTensorDims(bias), params->stride_width, params->stride_height, - data->padding.width, data->padding.height, params->padding, - output_activation_min, output_activation_max, - GetTensorData<float>(output), GetTensorDims(output), - GetTensorData<float>(im2col), GetTensorDims(im2col)); + *eigen_support::GetThreadPoolDevice(context), op_params, + GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData<float>(bias), GetTensorShape(output), + GetTensorData<float>(output), GetTensorShape(im2col), + 
GetTensorData<float>(im2col)); break; } case kCblasOptimized: { - cblas_ops::Conv(GetTensorData<float>(input), GetTensorDims(input), - GetTensorData<float>(filter), GetTensorDims(filter), - GetTensorData<float>(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, - data->padding.width, data->padding.height, - output_activation_min, output_activation_max, - GetTensorData<float>(output), GetTensorDims(output), - GetTensorData<float>(im2col), GetTensorDims(im2col)); + cblas_ops::Conv(op_params, GetTensorShape(input), + GetTensorData<float>(input), GetTensorShape(filter), + GetTensorData<float>(filter), GetTensorShape(bias), + GetTensorData<float>(bias), GetTensorShape(output), + GetTensorData<float>(output), GetTensorShape(im2col), + GetTensorData<float>(im2col)); break; } } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h index 40d42bbae9..2d96da65c3 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h @@ -31,20 +31,29 @@ limitations under the License. 
namespace tflite { namespace cblas_ops { -inline void Conv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { +inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& filter_shape, + const float* filter_data, const RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& im2col_shape, + float* im2col_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const int dilation_width_factor = params.dilation_width_factor; + const int dilation_height_factor = params.dilation_height_factor; + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); gemmlowp::ScopedProfilingLabel label("Conv/cblas"); const float* gemm_input_data = nullptr; - const Dims<4>* gemm_input_dims = nullptr; - const int filter_width = ArraySize(filter_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); + const RuntimeShape* gemm_input_shape = nullptr; + const int filter_width = filter_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; if (need_im2col) { @@ -55,18 +64,17 @@ inline void Conv(const 
float* input_data, const Dims<4>& input_dims, op_params.padding_values.height = pad_height; op_params.stride_width = stride_width; op_params.stride_height = stride_height; - op_params.dilation_width_factor = 1; - op_params.dilation_height_factor = 1; + op_params.dilation_width_factor = dilation_width_factor; + op_params.dilation_height_factor = dilation_height_factor; optimized_ops::Im2col(op_params, filter_height, filter_width, 0, - DimsToShape(input_dims), input_data, - DimsToShape(im2col_dims), im2col_data); + input_shape, input_data, im2col_shape, im2col_data); gemm_input_data = im2col_data; - gemm_input_dims = &im2col_dims; + gemm_input_shape = &im2col_shape; } else { TFLITE_DCHECK(!im2col_data); gemm_input_data = input_data; - gemm_input_dims = &input_dims; + gemm_input_shape = &input_shape; } // The following code computes matrix multiplication c = a * transponse(b) @@ -78,10 +86,10 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* a = gemm_input_data; const float* b = filter_data; float* c = output_data; - int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] * - gemm_input_dims->sizes[3]; - int n = output_dims.sizes[0]; - int k = gemm_input_dims->sizes[0]; + const int gemm_input_dims = gemm_input_shape->DimensionsCount(); + int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1); + int n = output_shape.Dims(3); + int k = gemm_input_shape->Dims(gemm_input_dims - 1); // The stride of matrix a, b and c respectively. 
int stride_a = k; int stride_b = k; @@ -91,8 +99,8 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, stride_a, b, stride_b, 0.0f, c, stride_c); optimized_ops::AddBiasAndEvalActivationFunction( - output_activation_min, output_activation_max, DimsToShape(bias_dims), - bias_data, DimsToShape(output_dims), output_data); + output_activation_min, output_activation_max, bias_shape, bias_data, + output_shape, output_data); } } // namespace cblas_ops diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index b5d001cc9e..4139cf4eba 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -69,13 +69,13 @@ struct MatMulConvFunctor { template <class T> class EigenTensorConvFunctor { private: - Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) { + Eigen::PaddingType RuntimePadding2EigenPadding(PaddingType padding) { switch (padding) { - case kTfLitePaddingValid: + case PaddingType::kValid: return Eigen::PADDING_VALID; - case kTfLitePaddingSame: + case PaddingType::kSame: return Eigen::PADDING_SAME; - case kTfLitePaddingUnknown: + case PaddingType::kNone: assert(false); // should never get here. 
return Eigen::PADDING_VALID; } @@ -89,7 +89,7 @@ class EigenTensorConvFunctor { int input_width, int input_depth, const T* filter_data, int filter_height, int filter_width, int filter_count, int stride_rows, int stride_cols, int pad_width, - int pad_height, TfLitePadding padding, T* output_data, + int pad_height, PaddingType padding, T* output_data, int output_height, int output_width) { const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && stride_rows == 1 && stride_cols == 1); @@ -127,28 +127,38 @@ class EigenTensorConvFunctor { input_depth, filter_count); output.device(device) = Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows, - TfLitePadding2EigenPadding(padding)); + RuntimePadding2EigenPadding(padding)); } } }; -inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data, - const Dims<4>& input_dims, const float* filter_data, - const Dims<4>& filter_dims, const float* bias_data, - const Dims<4>& bias_dims, int stride_width, int stride_height, - int pad_width, int pad_height, TfLitePadding padding, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims, - float* im2col_data, const Dims<4>& im2col_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); - const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); +inline void Conv(const Eigen::ThreadPoolDevice& device, + const ConvParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& filter_shape, + const float* filter_data, const 
RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape, + float* output_data, const RuntimeShape& im2col_shape, + float* im2col_data) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const PaddingType padding = params.padding_type; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); EigenTensorConvFunctor<float> conv_functor; conv_functor(device, input_data, im2col_data, batches, input_height, input_width, input_depth, filter_data, filter_height, @@ -157,8 +167,8 @@ inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data, output_width); optimized_ops::AddBiasAndEvalActivationFunction( - output_activation_min, output_activation_max, DimsToShape(bias_dims), - bias_data, DimsToShape(output_dims), output_data); + output_activation_min, output_activation_max, bias_shape, bias_data, + output_shape, output_data); } } // namespace multithreaded_ops |