aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-01-29 10:31:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 10:37:43 -0800
commit730071d0dca35a9e08f3bdc49661ae34d109da74 (patch)
treea8e25ab416d0501ed5cf3236832388b11287a267 /tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
parent1f26c65254268730b7409f517d1ed1b554d01e50 (diff)
Make TFLite BatchToSpaceND op have parity with TF BatchToSpaceND op.
PiperOrigin-RevId: 183687487
Diffstat (limited to 'tensorflow/contrib/lite/kernels/batch_to_space_nd.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc69
1 files changed, 45 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index 0eed680fdc..d84a77039b 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -35,12 +35,14 @@ enum KernelType {
struct BatchToSpaceNDContext {
BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
- params = reinterpret_cast<TfLiteBatchToSpaceNDParams*>(node->builtin_data);
input = GetInput(context, node, 0);
+ block_shape = GetInput(context, node, 1);
+ crops = GetInput(context, node, 2);
output = GetOutput(context, node, 0);
}
- TfLiteBatchToSpaceNDParams* params;
TfLiteTensor* input;
+ TfLiteTensor* block_shape;
+ TfLiteTensor* crops;
TfLiteTensor* output;
};
@@ -48,24 +50,22 @@ struct BatchToSpaceNDContext {
// The 4D array need to have exactly 2 spatial dimensions.
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
const int kInputDimensionNum = 4;
-const int kOutputDimensionNum = 4;
+const int kBlockSizeDimensionNum = 1;
const int kSpatialDimensionNum = 2;
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now.
- TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ BatchToSpaceNDContext* op_context) {
+ TfLiteIntArray* input_size = op_context->input->dims;
+ const int* block_shape = GetTensorData<int32>(op_context->block_shape);
- BatchToSpaceNDContext op_context(context, node);
- TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
- kInputDimensionNum);
- TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions,
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
+ kBlockSizeDimensionNum);
+ TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
+ kSpatialDimensionNum);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
kSpatialDimensionNum);
- TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
-
- const TfLiteIntArray* input_size = op_context.input->dims;
- const int* block_shape = op_context.params->block_shape;
+ // TODO(ycling): Add crops as part of calculation.
// Number of batch must be multiple of (block_shape[0] * block_shape[1]).
TF_LITE_ENSURE_EQ(context,
input_size->data[0] % (block_shape[0] * block_shape[1]), 0);
@@ -76,27 +76,48 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int output_width = input_size->data[2] * block_shape[1];
const int output_channel_size = input_size->data[3];
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum);
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
output_size->data[0] = output_batch_size;
output_size->data[1] = output_height;
output_size->data[2] = output_width;
output_size->data[3] = output_channel_size;
- return context->ResizeTensor(context, op_context.output, output_size);
+ return context->ResizeTensor(context, op_context->output, output_size);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ BatchToSpaceNDContext op_context(context, node);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
+ kInputDimensionNum);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+
+ if (!IsConstantTensor(op_context.block_shape) ||
+ !IsConstantTensor(op_context.crops)) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputTensor(context, &op_context);
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
BatchToSpaceNDContext op_context(context, node);
- int block_shape_dims_array[1] = {kSpatialDimensionNum};
- Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
+ }
-#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
- type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
- op_context.params->block_shape, block_shape_dims, \
- GetTensorData<scalar>(op_context.output), \
+#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
+ type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
+ GetTensorDims(op_context.input), \
+ GetTensorData<int32_t>(op_context.block_shape), \
+ GetTensorDims(op_context.block_shape), \
+ GetTensorData<scalar>(op_context.output), \
GetTensorDims(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32: