 tensorflow/contrib/lite/kernels/conv.cc                            |  13
 tensorflow/contrib/lite/kernels/conv_test.cc                       | 139
 tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h |   2
 3 files changed, 149 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 51989f541f..3ed0cdb131 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -249,6 +249,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
+ int channels_in = filter->dims->data[3];
int channels_out = filter->dims->data[0];
int width = input->dims->data[2];
int height = input->dims->data[1];
@@ -372,12 +373,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->scaling_factors_id;
TfLiteTensor* scaling_factors =
GetTemporary(context, node, data->scaling_factors_index);
- scaling_factors->type = kTfLiteInt32;
+ scaling_factors->type = kTfLiteFloat32;
scaling_factors->allocation_type = kTfLiteArenaRw;
TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
// Only one scale factor per batch is typically necessary. See optimized
- // implementation for why we need to allocate for height elements here.
- scaling_factors_size->data[0] = height;
+ // implementation for why we need to allocate for the height of the inputs
+ // flattened to 2D.
+ scaling_factors_size->data[0] = NumElements(input) / channels_in;
if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
scaling_factors_size));
@@ -549,7 +551,10 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
scaling_factors_ptr[b] *= filter->params.scale;
}
- int8_t* im2col_ptr = reinterpret_cast<int8_t*>(im2col->data.uint8);
+ int8_t* im2col_ptr = nullptr;
+ if (im2col != nullptr) {
+ im2col_ptr = reinterpret_cast<int8_t*>(im2col->data.uint8);
+ }
int8_t* filter_ptr = reinterpret_cast<int8_t*>(filter->data.uint8);
switch (kernel_type) {
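
The resized scaling factor buffer above is sized for the row count of the input when it is viewed as a 2D matrix with one column per input channel, which is what NumElements(input) / channels_in computes. A minimal standalone sketch of that count, assuming the NHWC input layout these kernels use (the helper below is illustrative, not TFLite API):

  #include <cstdio>

  // Rows of an NHWC input [batches, height, width, channels_in] flattened to
  // a 2D matrix with channels_in columns: NumElements(input) / channels_in.
  int ScalingFactorCount(int batches, int height, int width, int channels_in) {
    const int num_elements = batches * height * width * channels_in;
    return num_elements / channels_in;  // == batches * height * width
  }

  int main() {
    // Input shape {2, 2, 4, 2} from the pointwise tests below: 16 rows,
    // whereas the previous allocation used only height = 2.
    std::printf("%d\n", ScalingFactorCount(2, 2, 4, 2));
    return 0;
  }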
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index a4b9fb1a0b..411615aa62 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -177,6 +177,69 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
}));
}
+TEST_P(ConvolutionOpTest, PointwiseFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {1, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ // First batch
+ 1.5, 1.5, 1.5, 1.5, // row = 1
+ 3., 3., 3., 3., // row = 2
+ // Second batch
+ 1.5, 3., 4.5, 6., // row = 1
+ 1.5, 3., 4.5, 6., // row = 2
+ }));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {2, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ }));
+}
+
TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -769,6 +832,82 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannels) {
0.16)));
}
+TEST_P(ConvolutionOpTest, PointwiseHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ // Example: we get 3.03156 instead of 3.
+ //
+ // Second batch:
+ // 0.5 0.5 1 1 1.5 1.5 2 2 -> 32 32 64 64 95 95 127 127 with scale factor
+ // 127/2. We care about the two 64's.
+ //
+ // Filter:
+ // 64 127 with scale factor of 127/2.
+ //
+ // (64 * 64 + 64 * 127) * (2/127)^2 gives us the expected result.
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 1.5, 1.5, 1.5, // first batch, row = 1
+ 3., 3., 3., 3., // first batch, row = 2
+ 1.5, 3., 4.5, 6., // second batch, row = 1
+ 1.5, 3., 4.5, 6., // second batch, row = 2
+ },
+ 0.0316)));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {2, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ },
+ 0.0474)));
+}
+
INSTANTIATE_TEST_CASE_P(
ConvolutionOpTest, ConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
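
The 0.0316 tolerance used by PointwiseHybrid follows from the quantization arithmetic spelled out in the comment inside that test; a small standalone check of those numbers (a sketch of the error estimate, not of the kernel itself):

  #include <cstdio>

  // From the comment in PointwiseHybrid: the input values 1.0, 1.0 quantize to
  // 64, 64 and the filter 1, 2 quantizes to 64, 127, both with scale 2/127.
  // Dequantizing the integer accumulator gives roughly 3.0316 instead of 3.0.
  int main() {
    const float scale = 2.0f / 127.0f;
    const int acc = 64 * 64 + 64 * 127;           // = 12224
    const float dequantized = acc * scale * scale;
    std::printf("%f\n", dequantized);             // ~3.0316, hence the bound
    return 0;
  }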
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 85e631b852..70adffda3b 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -1948,7 +1948,7 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
const int filter_width = ArraySize(filter_dims, 1);
const int filter_height = ArraySize(filter_dims, 2);
- const int8* gemm_input_data = nullptr;
+ const int8_t* gemm_input_data = nullptr;
int num_input;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
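
For context on the nullptr guard added to EvalHybrid in conv.cc: HybridConv only routes the data through im2col when a stride or filter dimension differs from 1, so for a pointwise convolution the im2col tensor is never allocated and its pointer stays null. A compilable sketch of that selection, with names paraphrased from the surrounding code rather than the full TFLite signatures:

  #include <cstdint>

  // Mirrors the need_im2col condition above: only strided or non-1x1 filters
  // require an im2col buffer.
  bool NeedIm2col(int stride_width, int stride_height, int filter_width,
                  int filter_height) {
    return stride_width != 1 || stride_height != 1 || filter_width != 1 ||
           filter_height != 1;
  }

  // The GEMM input is the im2col buffer only when it exists; otherwise the
  // quantized input is used directly, so im2col_ptr may safely be nullptr.
  const int8_t* SelectGemmInput(bool need_im2col, const int8_t* im2col_ptr,
                                const int8_t* input_ptr) {
    return need_im2col ? im2col_ptr : input_ptr;
  }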