aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-09-17 15:32:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 15:41:28 -0700
commit3365cd1cc7bf3dcb781c76652132119bf82133e6 (patch)
tree6ab3cde6c9032c5f71e907f559c8c0d4287e92fa
parentaec9a7077001e8eacb278839f2e56c228afdc4a4 (diff)
Add generic fallback optimized implementations for dilated DepthwiseConv.
PiperOrigin-RevId: 213350122
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD1
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc162
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc75
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc15
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h52
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h68
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.h3
10 files changed, 281 insertions, 145 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index f52d29ea76..daaf6714cc 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -509,6 +509,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 3e1ce60113..798ee849ec 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -184,17 +184,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
const Dims<4>&, const float*, const Dims<4>&, int, int,
int, int, int, int, int, float, float, float*,
const Dims<4>&);
- KernelType effective_kernel_type;
- // TODO(suharshs): Currently only the reference implementation supports
- // dilations.
- if ((params->dilation_width_factor != 1) ||
- (params->dilation_height_factor != 1)) {
- effective_kernel_type = kReference;
- } else {
- effective_kernel_type = kernel_type;
- }
-
- if (effective_kernel_type == kReference) {
+ if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -224,17 +214,7 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int, int, int, int, int, int, int, int32, int32, int,
int32, int32, uint8*, const Dims<4>&);
- KernelType effective_kernel_type;
- // TODO(suharshs): Currently only the reference implementation supports
- // dilations.
- if ((params->dilation_width_factor != 1) ||
- (params->dilation_height_factor != 1)) {
- effective_kernel_type = kReference;
- } else {
- effective_kernel_type = kernel_type;
- }
-
- if (effective_kernel_type == kReference) {
+ if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index 2af26ab80a..4a33a0319d 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -14,12 +14,24 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
+
+} // namespace builtin
+} // namespace ops
+
namespace {
using ::testing::ElementsAreArray;
@@ -28,9 +40,11 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
public:
// TODO(ahentz): Also test different activation types, bias, padding types,
// stride values.
- BaseDepthwiseConvolutionOpModel(const TensorData& input,
+ BaseDepthwiseConvolutionOpModel(TfLiteRegistration* registration,
+ const TensorData& input,
const TensorData& filter,
const TensorData& output,
+ Padding padding_type,
int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -56,11 +70,14 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
SetBuiltinOp(
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
- CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
+ CreateDepthwiseConv2DOptions(builder_, padding_type, 1, 1, depth_mul,
ActivationFunctionType_NONE,
dilation_factor, dilation_factor)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_DEPTHWISE_CONV_2D, registration);
+
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
@@ -86,10 +103,25 @@ class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-TEST(DepthwiseConvolutionOpTest, SimpleTest) {
- DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_DEPTHWISE_CONVOLUTION_REF()},
+ {"GenericOptimized",
+ ops::builtin::Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT()},
+ {"NeonOptimized", ops::builtin::Register_DEPTHWISE_CONVOLUTION_NEON_OPT()},
+});
+
+class DepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
+TEST_P(DepthwiseConvolutionOpTest, SimpleTest) {
+ DepthwiseConvolutionOpModel m(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
- {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -112,7 +144,7 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
-TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
const int depth = 1;
const int image_width = 9;
const int image_height = 9;
@@ -121,10 +153,11 @@ TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
const int filter_count = 1;
const int dilation_factor = 3;
DepthwiseConvolutionOpModel m(
+ GetRegistration(),
{TensorType_FLOAT32,
{image_batch_count, image_height, image_width, depth}},
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
- {TensorType_FLOAT32, {}}, dilation_factor);
+ {TensorType_FLOAT32, {}}, Padding_VALID, dilation_factor);
// The image matrix is:
// | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@@ -164,6 +197,41 @@ TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ DepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@@ -188,13 +256,20 @@ class QuantizedDepthwiseConvolutionOpModel
}
};
+class QuantizedDepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// In this test we set the input and output scales so that the results match
// exactly the 'non-quantized' version.
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
QuantizedDepthwiseConvolutionOpModel m(
- {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
- {TensorType_UINT8, {}, -127, 128});
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -224,15 +299,16 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
}));
}
-TEST(QuantizedDepthwiseConvolutionOpTest,
- SimpleTestQuantizedFilterMultiplierGreaterThan1) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest,
+ SimpleTestQuantizedFilterMultiplierGreaterThan1) {
QuantizedDepthwiseConvolutionOpModel quant_op(
- {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
- {TensorType_UINT8, {}, -127, 128});
- DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
+ DepthwiseConvolutionOpModel float_op(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
- {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}}, Padding_VALID);
std::initializer_list<float> input = {
1, 2, 7, 8, // column 1
@@ -261,7 +337,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest,
ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
const int depth = 1;
const int image_width = 9;
const int image_height = 9;
@@ -270,6 +346,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
const int filter_count = 1;
const int dilation_factor = 3;
QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
{TensorType_UINT8,
{image_batch_count, image_height, image_width, depth},
0,
@@ -278,7 +355,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
{depth, filter_size, filter_size, filter_count},
0,
255},
- {TensorType_UINT8, {}, 0, 255}, dilation_factor);
+ {TensorType_UINT8, {}, 0, 255}, Padding_VALID, dilation_factor);
// The image matrix is:
// | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@@ -319,6 +396,55 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
+INSTANTIATE_TEST_CASE_P(
+ QuantizedDepthwiseConvolutionOpTest, QuantizedDepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
index 844ee6a53d..7600b26f5c 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -28,23 +29,29 @@ namespace tflite {
namespace {
// Runs the DepthwiseConv and compares against the reference implementation.
-template <FusedActivationFunctionType Ac>
void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
- int stride, int pad_width, int pad_height,
- int depth_multiplier, const Dims<4>& output_dims) {
+ int stride, int dilation_width_factor,
+ int dilation_height_factor, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max,
+ const Dims<4>& output_dims) {
const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
std::vector<float> output_data(output_buffer_size);
std::vector<float> reference_output_data(output_buffer_size);
- reference_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
- filter_dims, bias_data, bias_dims, stride,
- pad_width, pad_height, depth_multiplier,
- reference_output_data.data(), output_dims);
- optimized_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
- filter_dims, bias_data, bias_dims, stride,
- pad_width, pad_height, depth_multiplier,
- output_data.data(), output_dims);
+ reference_ops::DepthwiseConv(
+ input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride, stride, dilation_width_factor, dilation_height_factor, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, reference_output_data.data(), output_dims);
+ optimized_ops::DepthwiseConv(
+ input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride, stride, dilation_width_factor, dilation_height_factor, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data.data(), output_dims);
+
double sum_abs_diff = 0;
float max_abs_val = 0;
for (int i = 0; i < output_buffer_size; i++) {
@@ -59,27 +66,6 @@ void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
-void TestOneDepthwiseConv(FusedActivationFunctionType Ac,
- const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride, int pad_width, int pad_height,
- int depth_multiplier, const Dims<4>& output_dims) {
-#define TOCO_HANDLE_CASE(AC_TYPE) \
- if (AC_TYPE == Ac) { \
- TestOneDepthwiseConv<AC_TYPE>(input_data, input_dims, filter_data, \
- filter_dims, bias_data, bias_dims, stride, \
- pad_width, pad_height, depth_multiplier, \
- output_dims); \
- return; \
- }
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
-#undef TOCO_HANDLE_CASE
-}
-
// This function picks some random DepthwiseConv params, which may or may not
// be legal. If they're not legal, it returns false. If they're legal,
// it runs the DepthwiseConv test and returns true. This allows the caller
@@ -99,6 +85,16 @@ bool TryTestOneDepthwiseConv() {
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
const int output_depth = input_depth * depth_multiplier;
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ float output_activation_min, output_activation_max;
+ FusedActivationFunctionType ac =
+ RandomElement(std::vector<FusedActivationFunctionType>(
+ {FusedActivationFunctionType::kNone,
+ FusedActivationFunctionType::kRelu,
+ FusedActivationFunctionType::kRelu1,
+ FusedActivationFunctionType::kRelu6}));
+ GetActivationMinMax(ac, &output_activation_min, &output_activation_max);
// The optimized DepthwiseConv implementation currently uses a fixed-size
// accumulator buffer on the stack, with that size. This currently means
// that it does not support larger output depths. It CHECK's for it,
@@ -109,10 +105,6 @@ bool TryTestOneDepthwiseConv() {
if (output_depth > kMaxSupportedOutputDepth) {
return false;
}
- const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
- {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
- FusedActivationFunctionType::kRelu6,
- FusedActivationFunctionType::kRelu1}));
Dims<4> input_dims_inference =
MakeDimsForInference(input_depth, input_width, input_height, batch);
Dims<4> output_dims_inference;
@@ -120,7 +112,8 @@ bool TryTestOneDepthwiseConv() {
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
- filter_height, stride, padding_type,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
&output_dims_inference, &pad_width, &pad_height)) {
return false;
}
@@ -140,10 +133,12 @@ bool TryTestOneDepthwiseConv() {
FillRandom(&input_data, -input_amplitude, input_amplitude);
FillRandom(&filter_data, -filter_amplitude, filter_amplitude);
FillRandom(&bias_data, -bias_amplitude, bias_amplitude);
- TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
+ TestOneDepthwiseConv(input_data.data(), input_dims_inference,
filter_data.data(), filter_dims_inference,
- bias_data.data(), bias_dims_inference, stride, pad_width,
- pad_height, depth_multiplier, output_dims_inference);
+ bias_data.data(), bias_dims_inference, stride,
+ dilation_width_factor, dilation_height_factor, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_dims_inference);
return true;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
index 2c0fc8433e..312d048b2d 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
@@ -199,6 +199,7 @@ void TestOneDepthwiseConv(
bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
int input_height, int filter_width, int filter_height,
int depth_multiplier, int stride,
+ int dilation_width_factor, int dilation_height_factor,
PaddingType padding_type) {
const int output_depth = input_depth * depth_multiplier;
// The optimized DepthwiseConv implementation currently uses a fixed-size
@@ -231,7 +232,8 @@ bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
Dims<4> output_dims_inference;
int pad_width, pad_height;
if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
- filter_height, stride, padding_type,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
&output_dims_inference, &pad_width, &pad_height)) {
return false;
}
@@ -274,12 +276,15 @@ bool TryTestOneDepthwiseConv() {
const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
- stride, padding_type);
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
}
// Tests parameters for the 3x3 filter kernel.
@@ -292,6 +297,9 @@ bool TryTestOneDepthwiseConv3x3Filter() {
const int filter_height = 3;
const int depth_multiplier = 1;
const int stride = UniformRandomInt(1, 2);
+ // We don't support dilations in the 3x3 filter.
+ const int dilation_width_factor = 1;
+ const int dilation_height_factor = 1;
// Although the kernel supports only kValid padding, we test that kSame
// is using the correct code path.
const auto padding_type =
@@ -299,7 +307,8 @@ bool TryTestOneDepthwiseConv3x3Filter() {
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
- stride, padding_type);
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
}
void TestOneDepthwiseConv() {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index f2d1319801..f0bea7fa1d 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -761,7 +761,8 @@ struct FloatDepthwiseConvKernel<true, 4, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
+void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
const float* input_data, int pad_width,
int depth_multiplier, int filter_width,
const float* filter_data,
@@ -835,10 +836,10 @@ void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
inline void FloatDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const float* input_data,
- int pad_width, int depth_multiplier, int filter_width,
- const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
- int output_depth, float* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const float* input_data, int pad_width, int depth_multiplier,
+ int filter_width, const float* filter_data, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, float* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -860,6 +861,7 @@ inline void FloatDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -869,14 +871,17 @@ inline void FloatDepthwiseConvAccumRowGeneric(
const float* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
float* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const float* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -921,14 +926,14 @@ inline void DepthwiseConv(
const int depth_multiplier = params.depth_multiplier;
const float output_activation_min = params.float_activation_min;
const float output_activation_max = params.float_activation_max;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- // TODO(suharshs): Optimized implementation of dilation depthwise conv need to
- // be implemented.
- TFLITE_DCHECK_EQ(params.dilation_width_factor, 1);
- TFLITE_DCHECK_EQ(params.dilation_height_factor, 1);
+ const bool has_dilation = (params.dilation_width_factor != 1) ||
+ (params.dilation_height_factor != 1);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
@@ -961,7 +966,7 @@ inline void DepthwiseConv(
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && !has_dilation) { \
row_accum_func = \
FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -1014,9 +1019,13 @@ inline void DepthwiseConv(
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1032,9 +1041,9 @@ inline void DepthwiseConv(
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
row_accum_func(
- stride_width, input_depth, input_width,
+ stride_width, dilation_width_factor, input_depth, input_width,
input_data + in_y * input_height_stride + b * input_batch_stride,
pad_width, depth_multiplier, filter_width,
filter_data + filter_y * filter_height_stride, out_x_buffer_start,
@@ -1096,11 +1105,6 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
- // TODO(suharshs): Optimized implementation of dilation depthwise conv need to
- // be implemented.
- TFLITE_DCHECK_EQ(dilation_width_factor, 1);
- TFLITE_DCHECK_EQ(dilation_height_factor, 1);
-
tflite::DepthwiseParams op_params;
// Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame;
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index ccb9d1654f..494cf70504 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1466,11 +1466,14 @@ struct QuantizedDepthwiseConvKernel<false, 12, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void QuantizedDepthwiseConvAccumRow(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset,
+ int pad_width, int depth_multiplier,
+ int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth,
+ int32* acc_buffer) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
#endif
@@ -1537,10 +1540,11 @@ void QuantizedDepthwiseConvAccumRow(
// generic fallback of DepthwiseConvAccumRow, portable, non-templatized.
inline void QuantizedDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset, int pad_width,
+ int depth_multiplier, int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, int32* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -1562,6 +1566,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -1571,14 +1576,17 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
const uint8* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
int32* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const uint8* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -1688,15 +1696,11 @@ inline void DepthwiseConv(
const int32 output_offset = params.output_offset;
const int32 output_multiplier = params.output_multiplier;
const int output_shift = params.output_shift;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-
- // TODO(suharshs): Optimized implementation of dilation depthwise conv need to
- // be implemented.
- TFLITE_DCHECK_EQ(params.dilation_width_factor, 1);
- TFLITE_DCHECK_EQ(params.dilation_height_factor, 1);
-
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
@@ -1714,14 +1718,18 @@ inline void DepthwiseConv(
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ const bool has_dilation =
+ (dilation_width_factor != 1) || (dilation_height_factor != 1);
+
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
// Call kernel optimized for depthwise convolutions using 3x3 filters if
// parameters are supported.
- if (Fast3x3FilterKernelSupported(
- input_shape, filter_shape, stride_width, stride_height, pad_width,
- pad_height, depth_multiplier, output_shape, output_shift)) {
+ if (Fast3x3FilterKernelSupported(input_shape, filter_shape, stride_width,
+ stride_height, has_dilation, pad_width,
+ pad_height, depth_multiplier, output_shape,
+ output_shift)) {
DepthwiseConv3x3Filter(params, input_shape, input_data, filter_shape,
filter_data, bias_shape, bias_data, output_shape,
output_data);
@@ -1748,7 +1756,7 @@ inline void DepthwiseConv(
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && !has_dilation) { \
row_accum_func = \
QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -1808,9 +1816,13 @@ inline void DepthwiseConv(
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1826,9 +1838,9 @@ inline void DepthwiseConv(
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
row_accum_func(
- stride_width, input_depth, input_width,
+ stride_width, dilation_width_factor, input_depth, input_width,
input_data + in_y * input_height_stride + b * input_batch_stride,
input_offset, pad_width, depth_multiplier, filter_width,
filter_data + filter_y * filter_height_stride, filter_offset,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 9fed53cafb..5087227182 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -3176,8 +3176,8 @@ inline void DepthwiseConvHandlePadding(const uint8* input_data,
inline bool Fast3x3FilterKernelSupported(
const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
- int32 stride_width, int32 stride_height, int32 pad_width, int32 pad_height,
- int32 depth_multiplier, const RuntimeShape& output_shape,
+ int32 stride_width, int32 stride_height, bool has_dilation, int32 pad_width,
+ int32 pad_height, int32 depth_multiplier, const RuntimeShape& output_shape,
int32 output_shift) {
const int32 input_height = input_shape.Dims(1);
const int32 input_width = input_shape.Dims(2);
@@ -3193,7 +3193,7 @@ inline bool Fast3x3FilterKernelSupported(
(stride_height == 1 || stride_height == 2) &&
(stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
(pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
- (input_depth % 8) == 0 && (output_shift > 0);
+ (input_depth % 8) == 0 && (output_shift > 0) && !has_dilation;
if (!supported) {
return false;
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc
index 9b1fd9b344..5ae4b193d0 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc
@@ -43,17 +43,21 @@ Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) {
// this is a copied from an internal function in propagate_fixed_sizes.cc
bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
- int filter_height, int stride, PaddingType padding_type,
+ int filter_height, int stride, int dilation_width_factor,
+ int dilation_height_factor, PaddingType padding_type,
Dims<4>* output_dims, int* pad_width, int* pad_height) {
const int input_width = ArraySize(input_dims, 1);
const int input_height = ArraySize(input_dims, 2);
const int batch = ArraySize(input_dims, 3);
+ int dilated_filter_width = dilation_width_factor * (filter_width - 1) + 1;
+ int dilated_filter_height = dilation_height_factor * (filter_height - 1) + 1;
+
int output_height = 0;
int output_width = 0;
if (padding_type == PaddingType::kValid) {
- output_height = (input_height + stride - filter_height) / stride;
- output_width = (input_width + stride - filter_width) / stride;
+ output_height = (input_height + stride - dilated_filter_height) / stride;
+ output_width = (input_width + stride - dilated_filter_width) / stride;
} else if (padding_type == PaddingType::kSame) {
output_height = (input_height + stride - 1) / stride;
output_width = (input_width + stride - 1) / stride;
@@ -65,9 +69,13 @@ bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
return false;
}
- *pad_height =
- ((output_height - 1) * stride + filter_height - input_height) / 2;
- *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2;
+ *pad_height = std::max(
+ 0, ((output_height - 1) * stride + dilated_filter_height - input_height) /
+ 2);
+ *pad_width = std::max(
+ 0,
+ ((output_width - 1) * stride + dilated_filter_width - input_width) / 2);
+
*output_dims =
MakeDimsForInference(output_depth, output_width, output_height, batch);
return true;
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h
index 26078cef49..cb6d8b147c 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.h
@@ -31,7 +31,8 @@ Dims<4> MakeDimsForInference(int depth, int width, int height, int batch);
// Computes output and padding dimensions.
bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
- int filter_height, int stride, PaddingType padding_type,
+ int filter_height, int stride, int dilation_width_factor,
+ int dilation_height_factor, PaddingType padding_type,
Dims<4>* output_dims, int* pad_width, int* pad_height);
// Returns a mt19937 random engine.