aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-09-13 15:01:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 15:09:10 -0700
commitfb50c8e9a3cb2ccfac9cf4a847d5841cba80b524 (patch)
tree969d44684c674fad7ef775323f6e150e616890d9
parente8af4e1bb9496c111530e88263fb1b8dac8bdde9 (diff)
Dilated Depthwise Conv reference implementations.
PiperOrigin-RevId: 212884951
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h7
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc3
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc61
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc116
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h20
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h28
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs4
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h38
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py2
-rw-r--r--tensorflow/contrib/lite/toco/model.h5
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc14
13 files changed, 314 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index fa43e6a024..be9d551ee4 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -25,6 +25,9 @@ extern "C" {
// TODO(aselle): Consider using "if this then that" for testing.
+// IMPORTANT: All new members of structs must be added at the end to ensure
+// backwards compatibility.
+
// Possible padding types (for convolutions)
typedef enum {
kTfLitePaddingUnknown = 0,
@@ -71,11 +74,15 @@ typedef struct {
} TfLitePoolParams;
typedef struct {
+ // Parameters for DepthwiseConv version 1 or above.
TfLitePadding padding;
int stride_width;
int stride_height;
int depth_multiplier;
TfLiteFusedActivation activation;
+ // Parameters for DepthwiseConv version 2 or above.
+ int dilation_width_factor;
+ int dilation_height_factor;
} TfLiteDepthwiseConvParams;
typedef struct {
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index eef4b6d831..f4d2839b1b 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -216,6 +216,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->depth_multiplier = conv_params->depth_multiplier();
params->activation =
parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
}
*builtin_data = reinterpret_cast<void*>(params);
break;
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 347515f289..3e1ce60113 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -126,23 +126,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto compute_out_size = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
int out_height =
- compute_out_size(height, filter_height, params->stride_height);
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
- data->padding.height = ComputePadding(params->stride_height, 1, height,
- filter_height, out_height);
+ data->padding.height =
+ ComputePadding(params->stride_height, params->dilation_height_factor,
+ height, filter_height, out_height);
data->padding.width =
- ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+ ComputePadding(params->stride_width, params->dilation_width_factor, width,
+ filter_width, out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
@@ -177,8 +182,19 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
const Dims<4>&, const float*, const Dims<4>&, int, int,
- int, int, int, float, float, float*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, float, float, float*,
+ const Dims<4>&);
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -188,7 +204,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(filter), GetTensorDims(filter),
GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_activation_min, output_activation_max,
GetTensorData<float>(output), GetTensorDims(output));
}
@@ -204,9 +221,20 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
const Dims<4>&, int32, const int32*, const Dims<4>&,
- int, int, int, int, int, int32, int32, int, int32,
- int32, uint8*, const Dims<4>&);
- if (kernel_type == kReference) {
+ int, int, int, int, int, int, int, int32, int32, int,
+ int32, int32, uint8*, const Dims<4>&);
+
+ KernelType effective_kernel_type;
+ // TODO(suharshs): Currently only the reference implementation supports
+ // dilations.
+ if ((params->dilation_width_factor != 1) ||
+ (params->dilation_height_factor != 1)) {
+ effective_kernel_type = kReference;
+ } else {
+ effective_kernel_type = kernel_type;
+ }
+
+ if (effective_kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -216,7 +244,8 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
+ params->stride_height, params->dilation_width_factor,
+ params->dilation_height_factor, data->padding.width, data->padding.height,
params->depth_multiplier, output_offset, data->output_multiplier,
data->output_shift, data->output_activation_min,
data->output_activation_max, GetTensorData<uint8_t>(output),
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index c00cafb9fb..2af26ab80a 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -30,7 +30,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
// stride values.
BaseDepthwiseConvolutionOpModel(const TensorData& input,
const TensorData& filter,
- const TensorData& output) {
+ const TensorData& output,
+ int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -56,7 +57,8 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
- ActivationFunctionType_NONE)
+ ActivationFunctionType_NONE,
+ dilation_factor, dilation_factor)
.Union());
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
@@ -110,6 +112,58 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
+TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ DepthwiseConvolutionOpModel m(
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Since the dilation rate is 3 this will reduce the size of the output from
+ // 10x10 to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@@ -207,6 +261,64 @@ TEST(QuantizedDepthwiseConvolutionOpTest,
ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}
+TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ QuantizedDepthwiseConvolutionOpModel m(
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Since the dilation rate is 3 this will reduce the size of the output from
+ // 10x10 to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 7f6eea2d5d..70810ca784 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -1067,6 +1067,26 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ // TODO(suharshs): Optimized implementation of dilation depthwise conv need to
+ // be implemented.
+ TFLITE_DCHECK(dilation_width_factor == 1);
+ TFLITE_DCHECK(dilation_height_factor == 1);
+
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 3fd00c8930..f707279600 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1964,6 +1964,30 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ // TODO(suharshs): Optimized implementation of dilation depthwise is not
+ // supported yet.
+ TFLITE_DCHECK(dilation_width_factor == 1);
+ TFLITE_DCHECK(dilation_height_factor == 1);
+
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index 9aabee5000..bb5d590775 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -25,8 +25,9 @@ namespace reference_ops {
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
@@ -52,8 +53,9 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
float total = 0.f;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -81,6 +83,20 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index d57739279f..5e3e8997fc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -30,8 +30,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
int32 output_offset, int32 output_multiplier,
int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
@@ -58,8 +59,9 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -90,6 +92,24 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index d5da4fcccf..f0db22d581 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -276,11 +276,15 @@ table Pool2DOptions {
}
table DepthwiseConv2DOptions {
+ // Parameters for DepthwiseConv version 1 or above.
padding:Padding;
stride_w:int;
stride_h:int;
depth_multiplier:int;
fused_activation_function:ActivationFunctionType;
+ // Parameters for DepthwiseConv version 2 or above.
+ dilation_w_factor:int = 1;
+ dilation_h_factor:int = 1;
}
table ConcatEmbeddingsOptions {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 0b9c57480e..8c086a5e67 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -2339,12 +2339,16 @@ struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable {
int32_t stride_h;
int32_t depth_multiplier;
ActivationFunctionType fused_activation_function;
+ int32_t dilation_w_factor;
+ int32_t dilation_h_factor;
DepthwiseConv2DOptionsT()
: padding(Padding_SAME),
stride_w(0),
stride_h(0),
depth_multiplier(0),
- fused_activation_function(ActivationFunctionType_NONE) {
+ fused_activation_function(ActivationFunctionType_NONE),
+ dilation_w_factor(1),
+ dilation_h_factor(1) {
}
};
@@ -2355,7 +2359,9 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VT_STRIDE_W = 6,
VT_STRIDE_H = 8,
VT_DEPTH_MULTIPLIER = 10,
- VT_FUSED_ACTIVATION_FUNCTION = 12
+ VT_FUSED_ACTIVATION_FUNCTION = 12,
+ VT_DILATION_W_FACTOR = 14,
+ VT_DILATION_H_FACTOR = 16
};
Padding padding() const {
return static_cast<Padding>(GetField<int8_t>(VT_PADDING, 0));
@@ -2372,6 +2378,12 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
+ int32_t dilation_w_factor() const {
+ return GetField<int32_t>(VT_DILATION_W_FACTOR, 1);
+ }
+ int32_t dilation_h_factor() const {
+ return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_PADDING) &&
@@ -2379,6 +2391,8 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VerifyField<int32_t>(verifier, VT_STRIDE_H) &&
VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR) &&
verifier.EndTable();
}
DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2404,6 +2418,12 @@ struct DepthwiseConv2DOptionsBuilder {
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
+ void add_dilation_w_factor(int32_t dilation_w_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+ }
+ void add_dilation_h_factor(int32_t dilation_h_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+ }
explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2422,8 +2442,12 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
int32_t stride_w = 0,
int32_t stride_h = 0,
int32_t depth_multiplier = 0,
- ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ int32_t dilation_w_factor = 1,
+ int32_t dilation_h_factor = 1) {
DepthwiseConv2DOptionsBuilder builder_(_fbb);
+ builder_.add_dilation_h_factor(dilation_h_factor);
+ builder_.add_dilation_w_factor(dilation_w_factor);
builder_.add_depth_multiplier(depth_multiplier);
builder_.add_stride_h(stride_h);
builder_.add_stride_w(stride_w);
@@ -7064,6 +7088,8 @@ inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const
{ auto _e = stride_h(); _o->stride_h = _e; };
{ auto _e = depth_multiplier(); _o->depth_multiplier = _e; };
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; };
+ { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; };
}
inline flatbuffers::Offset<DepthwiseConv2DOptions> DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7079,13 +7105,17 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
auto _stride_h = _o->stride_h;
auto _depth_multiplier = _o->depth_multiplier;
auto _fused_activation_function = _o->fused_activation_function;
+ auto _dilation_w_factor = _o->dilation_w_factor;
+ auto _dilation_h_factor = _o->dilation_h_factor;
return tflite::CreateDepthwiseConv2DOptions(
_fbb,
_padding,
_stride_w,
_stride_h,
_depth_multiplier,
- _fused_activation_function);
+ _fused_activation_function,
+ _dilation_w_factor,
+ _dilation_h_factor);
}
inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 5d0895c72f..3754b58b23 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -1434,6 +1434,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
"filter_size": [[1, 1], [1, 2], [3, 3]],
"strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
"channel_multiplier": [1, 2],
"rate": [[1, 1]],
"padding": ["SAME", "VALID"],
@@ -1444,6 +1445,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3]],
"filter_size": [[1, 1]],
"strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "dilations": [[1, 1, 1, 1], [1, 2, 2, 1]],
"channel_multiplier": [2],
"rate": [[2, 2]], # Only [1, 1] is supported
"padding": ["SAME"],
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2e100e37f6..164b70f2df 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -477,6 +477,11 @@ struct DepthwiseConvOperator : Operator {
int stride_height = 0;
int stride_width = 0;
int depth_multiplier = 0;
+ // A dilation_rate of 0 is invalid and this field is an optional attribute.
+ // Thus initializing it to 1 to allow default conv behavior when the
+ // attribute is not present.
+ int dilation_width_factor = 1;
+ int dilation_height_factor = 1;
};
// Depth-to-space transform operator.
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 5486012176..1061e7c7c4 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -107,7 +107,8 @@ class DepthwiseConvolution
ActivationFunction::Serialize(op.fused_activation_function);
return ::tflite::CreateDepthwiseConv2DOptions(
*builder, padding, op.stride_width, op.stride_height,
- op.depth_multiplier, activation_function);
+ op.depth_multiplier, activation_function, op.dilation_width_factor,
+ op.dilation_height_factor);
}
void ReadOptions(const TfLiteOptions& options,
@@ -118,9 +119,18 @@ class DepthwiseConvolution
op->depth_multiplier = options.depth_multiplier();
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
+ op->dilation_width_factor = options.dilation_w_factor();
+ op->dilation_height_factor = options.dilation_h_factor();
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+ if (conv_op.dilation_width_factor != 1 ||
+ conv_op.dilation_height_factor != 1) {
+ return 2;
+ }
+ return 1;
+ }
};
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,