path: root/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
author Nupur Garg <nupurgarg@google.com> 2018-01-29 15:06:30 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-01-29 15:10:40 -0800
commit aaaeef5bbc4698ea48c2476ad5c84a94712c8d2f (patch)
tree 6a40ea0692cae1e71b4fa8189fbde711a3569273 /tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
parent a807755db184c2d2c3b9bcca65457e9915508650 (diff)
Make TFLite SpaceToBatchND op have parity with TF SpaceToBatchND op.
PiperOrigin-RevId: 183734695
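
The change replaces the builtin block_shape/paddings params with the op's second and third input tensors, matching TF's SpaceToBatchND signature, and defers output resizing to Eval when those tensors are not constant. As a quick illustration of the shape arithmetic that ResizeOutputTensor performs for a 4D NHWC input, here is a minimal standalone sketch; ComputeSpaceToBatchShape is an illustrative name, not TFLite code.

// Standalone sketch of SpaceToBatchND shape inference for a 4D NHWC input.
// Mirrors the arithmetic in ResizeOutputTensor; not actual TFLite code.
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<int> ComputeSpaceToBatchShape(const std::vector<int>& input,     // {N, H, W, C}
                                          const std::vector<int>& block,     // {block_h, block_w}
                                          const std::vector<int>& paddings)  // {top, bottom, left, right}
{
  std::vector<int> output = input;
  int batch_multiplier = 1;
  for (int dim = 0; dim < 2; ++dim) {
    // Each padded spatial size must be a multiple of the block size.
    const int padded = input[dim + 1] + paddings[dim * 2] + paddings[dim * 2 + 1];
    assert(padded % block[dim] == 0);
    output[dim + 1] = padded / block[dim];
    batch_multiplier *= block[dim];
  }
  output[0] = input[0] * batch_multiplier;  // Batch grows by block_h * block_w.
  return output;
}

int main() {
  // Example: a 1x5x5x1 input, 2x2 block, one extra row/column of padding -> 4x3x3x1.
  std::vector<int> out = ComputeSpaceToBatchShape({1, 5, 5, 1}, {2, 2}, {0, 1, 0, 1});
  std::printf("%d x %d x %d x %d\n", out[0], out[1], out[2], out[3]);  // prints 4 x 3 x 3 x 1
  return 0;
}
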
Diffstat (limited to 'tensorflow/contrib/lite/kernels/space_to_batch_nd.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/space_to_batch_nd.cc  95
1 file changed, 50 insertions(+), 45 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 2e22d0db56..e2e1873f77 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -33,17 +33,16 @@ enum KernelType {
kGenericOptimized,
};
-// Inputs specified in the 2nd tensor (block_shape) and 3rd tensor (paddings)
-// are ignored. Only use the `block_shape` and `paddings` specified in params.
-// TODO(nupurgarg): Support inputs as tensors in SpaceToBatchND.
struct SpaceToBatchNDContext {
SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) {
- params = reinterpret_cast<TfLiteSpaceToBatchNDParams*>(node->builtin_data);
input = GetInput(context, node, 0);
+ block_shape = GetInput(context, node, 1);
+ paddings = GetInput(context, node, 2);
output = GetOutput(context, node, 0);
}
- TfLiteSpaceToBatchNDParams* params;
TfLiteTensor* input;
+ TfLiteTensor* block_shape;
+ TfLiteTensor* paddings;
TfLiteTensor* output;
};
@@ -51,32 +50,29 @@ struct SpaceToBatchNDContext {
// The 4D array need to have exactly 2 spatial dimensions.
// TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
const int kInputDimensionNum = 4;
-const int kOutputDimensionNum = 4;
+const int kBlockSizeDimensionNum = 1;
const int kSpatialDimensionNum = 2;
-const int kPaddingDimensionNum = 4;
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ SpaceToBatchNDContext* op_context) {
+ TfLiteIntArray* input_size = op_context->input->dims;
+ const int32* block_shape = GetTensorData<int32>(op_context->block_shape);
+ const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
- SpaceToBatchNDContext op_context(context, node);
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
- kInputDimensionNum);
- TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions,
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
+ kBlockSizeDimensionNum);
+ TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
+ kSpatialDimensionNum);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings),
kSpatialDimensionNum);
- TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
-
- const TfLiteIntArray* input_size = op_context.input->dims;
- const int* block_shape = op_context.params->block_shape;
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum);
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
// Ensures the input height and width (with padding) is a multiple of block
// shape height and width.
for (int dim = 0; dim < kSpatialDimensionNum; ++dim) {
- int final_dim_size =
- (input_size->data[dim + 1] + op_context.params->before_paddings[dim] +
- op_context.params->after_paddings[dim]);
+ int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] +
+ paddings_data[dim * 2 + 1]);
TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0);
output_size->data[dim + 1] = final_dim_size / block_shape[dim];
}
@@ -88,33 +84,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size->data[0] = output_batch_size;
output_size->data[3] = output_channel_size;
- return context->ResizeTensor(context, op_context.output, output_size);
+ return context->ResizeTensor(context, op_context->output, output_size);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ SpaceToBatchNDContext op_context(context, node);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
+ kInputDimensionNum);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+
+ if (!IsConstantTensor(op_context.block_shape) ||
+ !IsConstantTensor(op_context.paddings)) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputTensor(context, &op_context);
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
SpaceToBatchNDContext op_context(context, node);
- int block_shape_dims_array[1] = {kSpatialDimensionNum};
- Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1);
-
- // Initialize padding array in the format accepted by the kernel code.
- // TODO(nupurgarg): Make kernel code accept padding array format that is
- // consistent with Pad operation (i.e. before_paddings and after_paddings).
- TfLiteIntArray* padding_data = TfLiteIntArrayCreate(kPaddingDimensionNum);
- padding_data->data[0] = op_context.params->before_paddings[0];
- padding_data->data[1] = op_context.params->after_paddings[0];
- padding_data->data[2] = op_context.params->before_paddings[1];
- padding_data->data[3] = op_context.params->after_paddings[1];
- int padding_dims_array[1] = {kPaddingDimensionNum};
- Dims<4> padding_dims = GetTensorDims(padding_dims_array, 1);
-
-#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \
- type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
- op_context.params->block_shape, block_shape_dims, \
- padding_data->data, padding_dims, \
- GetTensorData<scalar>(op_context.output), \
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
+ }
+
+#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \
+ type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
+ GetTensorDims(op_context.input), \
+ GetTensorData<int32_t>(op_context.block_shape), \
+ GetTensorDims(op_context.block_shape), \
+ GetTensorData<int32_t>(op_context.paddings), \
+ GetTensorDims(op_context.paddings), \
+ GetTensorData<scalar>(op_context.output), \
GetTensorDims(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
@@ -151,8 +158,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
#undef TF_LITE_SPACE_TO_BATCH_ND
-
- TfLiteIntArrayFree(padding_data);
return kTfLiteOk;
}
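
For readers unfamiliar with the op's semantics, the kernels invoked through TF_LITE_SPACE_TO_BATCH_ND move each block_shape-sized spatial tile into the batch dimension, reading zero for the padded region. Below is a minimal standalone sketch of that data movement for a 2-D block over an NHWC float buffer; the name SpaceToBatchNDRef and its flat-pointer signature are hypothetical, not TFLite API.

// Standalone reference for the data movement SpaceToBatchND performs on an
// NHWC float buffer. Illustrative only; not part of TFLite.
#include <cstdio>
#include <vector>

void SpaceToBatchNDRef(const float* in, int N, int H, int W, int C,
                       int bh, int bw, int pad_top, int pad_left,
                       int Ho, int Wo, float* out) {
  for (int ob = 0; ob < N * bh * bw; ++ob) {
    const int ib = ob % N;                      // original batch index
    const int shift = ob / N;                   // which block cell this output batch holds
    const int sh = shift / bw, sw = shift % bw;
    for (int h = 0; h < Ho; ++h) {
      for (int w = 0; w < Wo; ++w) {
        const int ih = h * bh + sh - pad_top;   // coordinates in the unpadded input
        const int iw = w * bw + sw - pad_left;
        for (int c = 0; c < C; ++c) {
          float v = 0.0f;                       // padded regions read as zero
          if (ih >= 0 && ih < H && iw >= 0 && iw < W)
            v = in[((ib * H + ih) * W + iw) * C + c];
          out[((ob * Ho + h) * Wo + w) * C + c] = v;
        }
      }
    }
  }
}

int main() {
  // A 1x2x2x1 input with a 2x2 block and no padding becomes 4x1x1x1:
  // each output batch holds one element of the 2x2 spatial block.
  const std::vector<float> in = {1, 2, 3, 4};
  std::vector<float> out(4);
  SpaceToBatchNDRef(in.data(), 1, 2, 2, 1, 2, 2, 0, 0, 1, 1, out.data());
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // prints 1 2 3 4
  return 0;
}
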