diff options
-rw-r--r-- | tensorflow/contrib/lite/build_def.bzl | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/conv.cc | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc | 5 |
3 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 30bb604d17..612813caee 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -201,7 +201,7 @@ def generated_test_models(): "concat", "constant", "control_dep", - # "conv", + "conv", "depthwiseconv", "div", "equal", diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index ee42e5cdc8..747c8a62c0 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). data->need_im2col = (params->stride_width != 1 || params->stride_height != 1 || - filter_width != 1 || filter_height != 1); + params->dilation_width_factor != 1 || + params->dilation_height_factor != 1 || filter_width != 1 || + filter_height != 1); // If we're using the optimized multithreaded EigenTensor implementation of // convolution, it expects the filter weights to be transposed compared to // the normal TF Lite buffer format. Typical TF Lite weights are diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 076415ece8..8ca2cd66ac 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -46,8 +46,9 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && - conv_op->stride_height == 1) { - // 1x1 unstrided conv does not need an im2col array. + conv_op->stride_height == 1 && conv_op->dilation_width_factor == 1 && + conv_op->dilation_height_factor == 1) { + // 1x1 unstrided undilated conv does not need an im2col array. return false; } |