aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-02 11:24:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-02 11:32:17 -0800
commit224874002f93fec471e401488e23d97d4f36c4fc (patch)
treef2c50bfc0b376bd162d6d84dc1e674bab0a4fc70
parent51792c887abd425693a3c36f16ea221b949f7277 (diff)
Allow ResizeBilinear to resize the output tensor in Prepare(), if the size tensor is const.
PiperOrigin-RevId: 184309687
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc33
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear_test.cc81
2 files changed, 92 insertions, 22 deletions
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 9a419af023..4a2101f246 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -36,6 +36,17 @@ constexpr int kInputTensor = 0;
constexpr int kSizeTensor = 1;
constexpr int kOutputTensor = 0;
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input,
+ TfLiteTensor* size, TfLiteTensor* output) {
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = input->dims->data[0];
+ const int32* size_data = GetTensorData<int32>(size);
+ output_size->data[1] = size_data[0];
+ output_size->data[2] = size_data[1];
+ output_size->data[3] = input->dims->data[3];
+ return context->ResizeTensor(context, output, output_size);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -55,9 +66,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// integers.
output->type = kTfLiteFloat32;
- // TODO(ahentz): if the input is constant, we can allocate here.
- output->allocation_type = kTfLiteDynamic;
- return kTfLiteOk;
+ if (!IsConstantTensor(size)) {
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputTensor(context, input, size, output);
}
template <KernelType kernel_type>
@@ -66,15 +79,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* size = GetInput(context, node, kSizeTensor);
- // TODO(ahentz): we only need to do this here if it wasn't done in Eval().
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
- output_size->data[0] = input->dims->data[0];
- const int32* size_data = GetTensorData<int32>(size);
- output_size->data[1] = size_data[0];
- output_size->data[2] = size_data[1];
- output_size->data[3] = input->dims->data[3];
- context->ResizeTensor(context, output, output_size);
- TfLiteTensorRealloc(output->bytes, output);
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeOutputTensor(context, input, size, output));
+ TfLiteTensorRealloc(output->bytes, output);
+ }
if (output->type == kTfLiteFloat32) {
#define TF_LITE_RESIZE_BILINEAR(type) \
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
index 2b1aaf654f..4e03f3820a 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
@@ -25,14 +25,24 @@ using ::testing::ElementsAreArray;
class ResizeBilinearOpModel : public SingleOpModel {
public:
- ResizeBilinearOpModel(std::initializer_list<int> input_shape) {
- input_ = AddInput(TensorType_FLOAT32);
- size_ = AddInput(TensorType_INT32);
- output_ = AddOutput(TensorType_FLOAT32);
+ ResizeBilinearOpModel(const TensorData& input,
+ std::initializer_list<int> size_data = {}) {
+ bool const_size = size_data.size() != 0;
+ input_ = AddInput(input);
+ if (const_size) {
+ size_ = AddConstInput(TensorType_INT32, size_data, {2});
+ } else {
+ size_ = AddInput({TensorType_INT32, {2}});
+ }
+ output_ = AddOutput(TensorType_FLOAT32); // Always float.
SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
BuiltinOptions_ResizeBilinearOptions,
CreateResizeBilinearOptions(builder_).Union());
- BuildInterpreter({input_shape, {2}});
+ if (const_size) {
+ BuildInterpreter({GetShape(input_)});
+ } else {
+ BuildInterpreter({GetShape(input_), GetShape(size_)});
+ }
}
void SetInput(std::initializer_list<float> data) {
@@ -49,23 +59,33 @@ class ResizeBilinearOpModel : public SingleOpModel {
};
TEST(ResizeBilinearOpTest, HorizontalResize) {
- ResizeBilinearOpModel m({1, 1, 2, 1});
+ ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
m.SetInput({3, 6});
m.SetSize({1, 3});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+
+ ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
+ const_m.SetInput({3, 6});
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}
TEST(ResizeBilinearOpTest, VerticalResize) {
- ResizeBilinearOpModel m({1, 2, 1, 1});
+ ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
m.SetInput({3, 9});
m.SetSize({3, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+
+ ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
+ const_m.SetInput({3, 9});
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
- ResizeBilinearOpModel m({1, 2, 2, 1});
+ ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
m.SetInput({
3, 6, //
9, 12 //
@@ -77,10 +97,22 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
7, 9, 10, //
9, 11, 12, //
})));
+
+ ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
+ const_m.SetInput({
+ 3, 6, //
+ 9, 12 //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
- ResizeBilinearOpModel m({2, 2, 2, 1});
+ ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}});
m.SetInput({
3, 6, //
9, 12, //
@@ -97,10 +129,27 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
8, 12, 14, //
10, 14, 16, //
})));
+
+ ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
+ const_m.SetInput({
+ 3, 6, //
+ 9, 12, //
+ 4, 10, //
+ 10, 16 //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 14, 16, //
+ })));
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
- ResizeBilinearOpModel m({1, 2, 2, 2});
+ ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}});
m.SetInput({
3, 4, 6, 10, //
9, 10, 12, 16, //
@@ -112,6 +161,18 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
7, 8, 9, 12, 10, 14, //
9, 10, 11, 14, 12, 16, //
})));
+
+ ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
+ const_m.SetInput({
+ 3, 4, 6, 10, //
+ 9, 10, 12, 16, //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 14, 12, 16, //
+ })));
}
} // namespace