aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/reshape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/reshape.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc69
1 files changed, 54 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 3287040695..99ecc16093 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -25,16 +25,11 @@ namespace builtin {
namespace reshape {
constexpr int kInputTensor = 0;
+constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
-
- // TODO(ahentz): we are often given a tensor with the shape but we only pay
- // attention to what the shape specified in 'params'.
- TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
+TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
+ TfLiteIntArray* output_shape) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -47,32 +42,76 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
num_input_elements *= SizeOfDimension(input, i);
}
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions);
int num_output_elements = 1;
int stretch_dim = -1;
- for (int i = 0; i < params->num_dimensions; ++i) {
- int value = params->shape[i];
+ for (int i = 0; i < output_shape->size; ++i) {
+ int value = output_shape->data[i];
if (value == -1) {
TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
stretch_dim = i;
} else {
num_output_elements *= value;
- output_size->data[i] = value;
}
}
if (stretch_dim != -1) {
- output_size->data[stretch_dim] = num_input_elements / num_output_elements;
- num_output_elements *= output_size->data[stretch_dim];
+ output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
+ num_output_elements *= output_shape->data[stretch_dim];
}
TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
- return context->ResizeTensor(context, output, output_size);
+ return context->ResizeTensor(context, output, output_shape);
+}
+
+TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context,
+ TfLiteNode* node) {
+ const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
+
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
+ for (int i = 0; i < output_shape->size; ++i) {
+ output_shape->data[i] = shape->data.i32[i];
+ }
+ return ResizeOutput(context, node, output_shape);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
+
+ TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Attempt to use shape tensor if it exists.
+ if (NumInputs(node) == 2) {
+ const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
+ // Check if the shape tensor is valid.
+ if (shape->dims->size == 1 && shape->type == kTfLiteInt32) {
+ // Set the output tensor as dynamic if the shape isn't constnat.
+ if (!IsConstantTensor(shape)) {
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+ // Shape is constant. Resize now.
+ return ResizeOutputWithShapeTensor(context, node);
+ }
+ }
+ // The function is returned above this line if the shape tensor is usable.
+ // Now fallback to the shape parameter in `TfLiteReshapeParams`.
+
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
+ for (int i = 0; i < params->num_dimensions; ++i) {
+ output_shape->data[i] = params->shape[i];
+ }
+ return ResizeOutput(context, node, output_shape);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputWithShapeTensor(context, node));
+ }
+
memcpy(output->data.raw, input->data.raw, input->bytes);
return kTfLiteOk;