aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/build_def.bzl2
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc5
3 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 30bb604d17..612813caee 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -201,7 +201,7 @@ def generated_test_models():
"concat",
"constant",
"control_dep",
- # "conv",
+ "conv",
"depthwiseconv",
"div",
"equal",
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index ee42e5cdc8..747c8a62c0 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
// optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
data->need_im2col =
(params->stride_width != 1 || params->stride_height != 1 ||
- filter_width != 1 || filter_height != 1);
+ params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1 || filter_width != 1 ||
+ filter_height != 1);
// If we're using the optimized multithreaded EigenTensor implementation of
// convolution, it expects the filter weights to be transposed compared to
// the normal TF Lite buffer format. Typical TF Lite weights are
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
index 076415ece8..8ca2cd66ac 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -46,8 +46,9 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 &&
- conv_op->stride_height == 1) {
- // 1x1 unstrided conv does not need an im2col array.
+ conv_op->stride_height == 1 && conv_op->dilation_width_factor == 1 &&
+ conv_op->dilation_height_factor == 1) {
+ // 1x1 unstrided undilated conv does not need an im2col array.
return false;
}