diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-01-29 10:31:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-29 10:37:43 -0800 |
commit | 730071d0dca35a9e08f3bdc49661ae34d109da74 (patch) | |
tree | a8e25ab416d0501ed5cf3236832388b11287a267 /tensorflow/contrib/lite/kernels/batch_to_space_nd.cc | |
parent | 1f26c65254268730b7409f517d1ed1b554d01e50 (diff) |
Make TFLite BatchToSpaceND op have parity with TF BatchToSpaceND op.
PiperOrigin-RevId: 183687487
Diffstat (limited to 'tensorflow/contrib/lite/kernels/batch_to_space_nd.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/batch_to_space_nd.cc | 69 |
1 files changed, 45 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 0eed680fdc..d84a77039b 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -35,12 +35,14 @@ enum KernelType { struct BatchToSpaceNDContext { BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast<TfLiteBatchToSpaceNDParams*>(node->builtin_data); input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + crops = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteBatchToSpaceNDParams* params; TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* crops; TfLiteTensor* output; }; @@ -48,24 +50,22 @@ struct BatchToSpaceNDContext { // The 4D array need to have exactly 2 spatial dimensions. // TODO(ycling): Support arbitrary dimension in BatchToSpaceND. const int kInputDimensionNum = 4; -const int kOutputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; const int kSpatialDimensionNum = 2; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now. - TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + BatchToSpaceNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int* block_shape = GetTensorData<int32>(op_context->block_shape); - BatchToSpaceNDContext op_context(context, node); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), - kInputDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), kSpatialDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - - const TfLiteIntArray* input_size = op_context.input->dims; - const int* block_shape = op_context.params->block_shape; + // TODO(ycling): Add crops as part of calculation. // Number of batch must be multiple of (block_shape[0] * block_shape[1]). TF_LITE_ENSURE_EQ(context, input_size->data[0] % (block_shape[0] * block_shape[1]), 0); @@ -76,27 +76,48 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int output_width = input_size->data[2] * block_shape[1]; const int output_channel_size = input_size->data[3]; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); output_size->data[0] = output_batch_size; output_size->data[1] = output_height; output_size->data[2] = output_width; output_size->data[3] = output_channel_size; - return context->ResizeTensor(context, op_context.output, output_size); + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.crops)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); } template <KernelType kernel_type> TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { BatchToSpaceNDContext op_context(context, node); - int block_shape_dims_array[1] = {kSpatialDimensionNum}; - Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } -#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ - type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \ - GetTensorDims(op_context.input), \ - op_context.params->block_shape, block_shape_dims, \ - GetTensorData<scalar>(op_context.output), \ +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData<int32_t>(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData<scalar>(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: |