diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/reshape.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/reshape.cc | 69 |
1 files changed, 54 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 3287040695..99ecc16093 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -25,16 +25,11 @@ namespace builtin { namespace reshape { constexpr int kInputTensor = 0; +constexpr int kShapeTensor = 1; constexpr int kOutputTensor = 0; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data); - - // TODO(ahentz): we are often given a tensor with the shape but we only pay - // attention to what the shape specified in 'params'. - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node, + TfLiteIntArray* output_shape) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -47,32 +42,76 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { num_input_elements *= SizeOfDimension(input, i); } - TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); int num_output_elements = 1; int stretch_dim = -1; - for (int i = 0; i < params->num_dimensions; ++i) { - int value = params->shape[i]; + for (int i = 0; i < output_shape->size; ++i) { + int value = output_shape->data[i]; if (value == -1) { TF_LITE_ENSURE_EQ(context, stretch_dim, -1); stretch_dim = i; } else { num_output_elements *= value; - output_size->data[i] = value; } } if (stretch_dim != -1) { - output_size->data[stretch_dim] = num_input_elements / num_output_elements; - num_output_elements *= output_size->data[stretch_dim]; + output_shape->data[stretch_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_shape->data[stretch_dim]; } TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); - return context->ResizeTensor(context, output, output_size); + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context, + TfLiteNode* node) { + const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]); + for (int i = 0; i < output_shape->size; ++i) { + output_shape->data[i] = shape->data.i32[i]; + } + return ResizeOutput(context, node, output_shape); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data); + + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Attempt to use shape tensor if it exists. + if (NumInputs(node) == 2) { + const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); + // Check if the shape tensor is valid. + if (shape->dims->size == 1 && shape->type == kTfLiteInt32) { + // Set the output tensor as dynamic if the shape isn't constnat. + if (!IsConstantTensor(shape)) { + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + SetTensorToDynamic(output); + return kTfLiteOk; + } + // Shape is constant. Resize now. + return ResizeOutputWithShapeTensor(context, node); + } + } + // The function is returned above this line if the shape tensor is usable. + // Now fallback to the shape parameter in `TfLiteReshapeParams`. + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions); + for (int i = 0; i < params->num_dimensions; ++i) { + output_shape->data[i] = params->shape[i]; + } + return ResizeOutput(context, node, output_shape); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputWithShapeTensor(context, node)); + } + memcpy(output->data.raw, input->data.raw, input->bytes); return kTfLiteOk; |