diff options
author | Jared Duke <jdduke@google.com> | 2018-09-12 08:42:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 08:49:55 -0700 |
commit | 6995d2b9be0e398f11a17348eb5b4745aee0af0d (patch) | |
tree | 174a2f4074a50ef8d5010687ffc4182124016258 /tensorflow/contrib/lite/kernels | |
parent | 9333978b4b08e4b3fdc7f63ec0873a7e00dcc4b7 (diff) |
Fix convolution bug when input and filter dimensions match
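TFLite has an optimized matmul path for cases where the input and
filter tensors have matching width and height. However, this path did
not account for multiple *batches*: it always treated the input and
output as single-row matrices, i.e. a batch size of 1. Use the actual
batch count for the matrix row dimension and add an appropriate test.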
Credit to zgxnet for reporting the bug and proposing the fix.
Fixes #21817
PiperOrigin-RevId: 212645329
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/conv_test.cc | 24 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h | 4 |
2 files changed, 26 insertions, 2 deletions
```diff
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 411615aa62..f7e6f083ed 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -177,6 +177,30 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
                              }));
 }
 
+TEST_P(ConvolutionOpTest, InputAndFilterSameWidthHeight) {
+  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+                       {TensorType_FLOAT32, {1, 2, 4, 1}},
+                       {TensorType_FLOAT32, {}});
+
+  m.SetInput({
+      // First batch
+      1, 1, 1, 1,  // row = 1
+      2, 2, 2, 2,  // row = 2
+      // Second batch
+      1, 2, 3, 4,  // row = 1
+      1, 2, 3, 4,  // row = 2
+  });
+  m.SetFilter({
+      1, 2, 3, 4,    // row = 1
+      -1, -1, 1, 1,  // row = 2
+  });
+  m.SetBias({0});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 34}));
+}
+
 TEST_P(ConvolutionOpTest, PointwiseFloat32) {
   ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
                        {TensorType_FLOAT32, {1, 1, 1, 2}},
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 5fb31889fe..59f0e3c927 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -113,8 +113,8 @@ class EigenTensorConvFunctor {
           filter_width * filter_height * input_depth;
       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
-      EigenMatrix output(output_data, 1, filter_count);
-      ConstEigenMatrix input(input_data, 1, k);
+      EigenMatrix output(output_data, input_batches, filter_count);
+      ConstEigenMatrix input(input_data, input_batches, k);
       ConstEigenMatrix filter(filter_data, k, filter_count);
       MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
                                                       filter, dim_pair);
```