commit cc83067469bc30bba55932c587f31ef68f15792f
tree fe201c147a72a751a3ba050b6687c8b41b14b42f
parent 2fb9377a5ec610b8eff853fd1d2d53eabf711eda
author A. Unique TensorFlower <gardener@tensorflow.org> 2018-09-27 14:04:06 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-27 14:18:07 -0700
Migrate a few conv kernels to use new kernel signatures.
PiperOrigin-RevId: 214831837
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
 tensorflow/contrib/lite/kernels/conv.cc                                 | 70
 tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h         | 54
 tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h | 60
 3 files changed, 100 insertions(+), 84 deletions(-)
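A note on the migration pattern before the hunks: the old Conv entry points took one Dims<4> per tensor plus a long tail of loose scalars (strides, padding, dilation factors, activation bounds), so every kernel variant repeated the same argument list. The new signatures bundle the scalars into a single ConvParams struct and describe tensors with NHWC-ordered RuntimeShape. A minimal sketch of the new-style parameter block, assuming the types from tensorflow/contrib/lite/kernels/internal/types.h; all concrete values below are made up for illustration:

#include "tensorflow/contrib/lite/kernels/internal/types.h"

// Sketch only: builds the shared parameter block the way conv.cc now does
// once, ahead of the per-kernel switch. Every value here is illustrative.
tflite::ConvParams MakeExampleConvParams() {
  tflite::ConvParams op_params;
  op_params.padding_type = tflite::PaddingType::kSame;
  op_params.padding_values.width = 1;     // as computed during Prepare()
  op_params.padding_values.height = 1;
  op_params.stride_width = 1;
  op_params.stride_height = 1;
  op_params.dilation_width_factor = 1;
  op_params.dilation_height_factor = 1;
  op_params.float_activation_min = 0.0f;  // e.g. ReLU lower bound
  op_params.float_activation_max = 6.0f;  // e.g. ReLU6 upper bound
  return op_params;
}

// Tensor extents travel as NHWC-ordered RuntimeShape objects, e.g.:
//   tflite::RuntimeShape input_shape({1, 224, 224, 3});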
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 101b4fc961..dbcadbee14 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -86,6 +86,18 @@ struct OpData {
bool run_multithreaded_kernel;
};
+inline PaddingType RuntimePaddingType(TfLitePadding padding) {
+ switch (padding) {
+ case TfLitePadding::kTfLitePaddingSame:
+ return PaddingType::kSame;
+ case TfLitePadding::kTfLitePaddingValid:
+ return PaddingType::kValid;
+ case TfLitePadding::kTfLitePaddingUnknown:
+ default:
+ return PaddingType::kNone;
+ }
+}
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to use as scratch space for im2col, and
@@ -487,18 +499,18 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
} else {
effective_kernel_type = kernel_type;
}
+ ConvParams op_params;
+ op_params.padding_type = RuntimePaddingType(params->padding);
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
switch (effective_kernel_type) {
case kReference: {
- ConvParams op_params;
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = data->padding.width;
- op_params.padding_values.height = data->padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
reference_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
@@ -508,16 +520,6 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
break;
}
case kGenericOptimized: {
- ConvParams op_params;
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = data->padding.width;
- op_params.padding_values.height = data->padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
optimized_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
@@ -534,25 +536,21 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
filter_data = GetTensorData<float>(filter);
}
multithreaded_ops::Conv(
- *eigen_support::GetThreadPoolDevice(context),
- GetTensorData<float>(input), GetTensorDims(input), filter_data,
- GetTensorDims(filter), GetTensorData<float>(bias),
- GetTensorDims(bias), params->stride_width, params->stride_height,
- data->padding.width, data->padding.height, params->padding,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ *eigen_support::GetThreadPoolDevice(context), op_params,
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), filter_data, GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kCblasOptimized: {
- cblas_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- data->padding.width, data->padding.height,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ cblas_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
}
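One behavioral fix hiding in the conv.cc hunks: the old per-kernel blocks hard-coded op_params.padding_type = PaddingType::kSame even when the node used VALID padding, whereas the hoisted block derives it from params->padding through the new RuntimePaddingType helper. A test-style spot-check of that mapping (my sketch, not part of the commit):

#include <cassert>

// Assumes RuntimePaddingType from the diff above is in scope, along with
// the TfLitePadding enum from builtin_op_data.h.
void CheckPaddingMapping() {
  assert(RuntimePaddingType(kTfLitePaddingSame) == PaddingType::kSame);
  assert(RuntimePaddingType(kTfLitePaddingValid) == PaddingType::kValid);
  // Unknown padding maps to kNone rather than failing.
  assert(RuntimePaddingType(kTfLitePaddingUnknown) == PaddingType::kNone);
}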
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
index 40d42bbae9..2d96da65c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
@@ -31,20 +31,29 @@ limitations under the License.
namespace tflite {
namespace cblas_ops {
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
gemmlowp::ScopedProfilingLabel label("Conv/cblas");
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_im2col) {
@@ -55,18 +64,17 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
op_params.padding_values.height = pad_height;
op_params.stride_width = stride_width;
op_params.stride_height = stride_height;
- op_params.dilation_width_factor = 1;
- op_params.dilation_height_factor = 1;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
optimized_ops::Im2col(op_params, filter_height, filter_width, 0,
- DimsToShape(input_dims), input_data,
- DimsToShape(im2col_dims), im2col_data);
+ input_shape, input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
// The following code computes matrix multiplication c = a * transpose(b)
@@ -78,10 +86,10 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
const float* a = gemm_input_data;
const float* b = filter_data;
float* c = output_data;
- int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] *
- gemm_input_dims->sizes[3];
- int n = output_dims.sizes[0];
- int k = gemm_input_dims->sizes[0];
+ const int gemm_input_dims = gemm_input_shape->DimensionsCount();
+ int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
+ int n = output_shape.Dims(3);
+ int k = gemm_input_shape->Dims(gemm_input_dims - 1);
// The stride of matrix a, b and c respectively.
int stride_a = k;
int stride_b = k;
@@ -91,8 +99,8 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
stride_a, b, stride_b, 0.0f, c, stride_c);
optimized_ops::AddBiasAndEvalActivationFunction(
- output_activation_min, output_activation_max, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace cblas_ops
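On the new GEMM dimension computation in cblas_conv.h: m is the flat size of every GEMM-input dimension except the last (the depth), n is the output depth, and k is the GEMM-input depth. A worked example with made-up shapes:

// Illustrative numbers only: a 3x3 conv over a 3-channel input, so im2col
// produces depth 3*3*3 = 27; suppose 8 output channels.
//   gemm_input_shape (im2col): {1, 112, 112, 27}   // NHWC
//   output_shape:              {1, 112, 112,  8}
//
//   m = FlatSizeSkipDim(*gemm_input_shape, 3) = 1 * 112 * 112 = 12544
//   n = output_shape.Dims(3)                  = 8
//   k = gemm_input_shape->Dims(3)             = 27
//
// cblas_sgemm then multiplies the 12544x27 patch matrix by the transposed
// 8x27 filter matrix into the 12544x8 output.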
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index b5d001cc9e..4139cf4eba 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -69,13 +69,13 @@ struct MatMulConvFunctor {
template <class T>
class EigenTensorConvFunctor {
private:
- Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) {
+ Eigen::PaddingType RuntimePadding2EigenPadding(PaddingType padding) {
switch (padding) {
- case kTfLitePaddingValid:
+ case PaddingType::kValid:
return Eigen::PADDING_VALID;
- case kTfLitePaddingSame:
+ case PaddingType::kSame:
return Eigen::PADDING_SAME;
- case kTfLitePaddingUnknown:
+ case PaddingType::kNone:
assert(false); // should never get here.
return Eigen::PADDING_VALID;
}
@@ -89,7 +89,7 @@ class EigenTensorConvFunctor {
int input_width, int input_depth, const T* filter_data,
int filter_height, int filter_width, int filter_count,
int stride_rows, int stride_cols, int pad_width,
- int pad_height, TfLitePadding padding, T* output_data,
+ int pad_height, PaddingType padding, T* output_data,
int output_height, int output_width) {
const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
stride_rows == 1 && stride_cols == 1);
@@ -127,28 +127,38 @@ class EigenTensorConvFunctor {
input_depth, filter_count);
output.device(device) =
Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows,
- TfLitePadding2EigenPadding(padding));
+ RuntimePadding2EigenPadding(padding));
}
}
};
-inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
- const Dims<4>& input_dims, const float* filter_data,
- const Dims<4>& filter_dims, const float* bias_data,
- const Dims<4>& bias_dims, int stride_width, int stride_height,
- int pad_width, int pad_height, TfLitePadding padding,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void Conv(const Eigen::ThreadPoolDevice& device,
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const PaddingType padding = params.padding_type;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
EigenTensorConvFunctor<float> conv_functor;
conv_functor(device, input_data, im2col_data, batches, input_height,
input_width, input_depth, filter_data, filter_height,
@@ -157,8 +167,8 @@ inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
output_width);
optimized_ops::AddBiasAndEvalActivationFunction(
- output_activation_min, output_activation_max, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace multithreaded_ops
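The index rewrites in this file all follow from the layout difference between the two shape types: the legacy Dims<4> stores extents innermost-first (index 0 = depth, index 3 = batch), while RuntimeShape stores NHWC order directly (index 0 = batch, index 3 = depth). A sketch of the correspondence with made-up sizes:

// For an NHWC tensor with batch=1, height=5, width=7, depth=3:
//   legacy:  Dims<4> sizes = {3, 7, 5, 1}      // {depth, width, height, batch}
//            ArraySize(input_dims, 2) == 5     // height
//            ArraySize(input_dims, 3) == 1     // batch
//   new:     RuntimeShape shape({1, 5, 7, 3}); // {batch, height, width, depth}
//            shape.Dims(1) == 5                // height
//            shape.Dims(0) == 1                // batch
// Hence MatchingArraySize(input_dims, 3, output_dims, 3) in the old code
// becomes MatchingDim(input_shape, 0, output_shape, 0) in the new.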