aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-10 16:09:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 16:12:07 -0700
commit874cf8e1d332175c8a90d7512f8385e98e2a7377 (patch)
tree129149c6c73e828246b7202dccdfb4d9ce583de3 /tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
parent66b6dda1b77cbf075e94009718446511fa13dd41 (diff)
Enable support for crops in BatchToSpaceNd
PiperOrigin-RevId: 196186750
Diffstat (limited to 'tensorflow/contrib/lite/kernels/batch_to_space_nd.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc22
1 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index 90edf4f9e3..bd4057556c 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -66,12 +66,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
kSpatialDimensionNum);
- // TODO(ycling): Add crops as part of calculation. Remove check for a crops
- // containing all zeroes.
- TF_LITE_ENSURE_EQ(context, crops[0], 0);
- TF_LITE_ENSURE_EQ(context, crops[1], 0);
- TF_LITE_ENSURE_EQ(context, crops[2], 0);
- TF_LITE_ENSURE_EQ(context, crops[3], 0);
+ TF_LITE_ENSURE(context, crops[0] >= 0);
+ TF_LITE_ENSURE(context, crops[1] >= 0);
+ TF_LITE_ENSURE(context, crops[2] >= 0);
+ TF_LITE_ENSURE(context, crops[3] >= 0);
// Number of batch must be multiple of (block_shape[0] * block_shape[1]).
TF_LITE_ENSURE_EQ(context,
@@ -79,8 +77,16 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const int output_batch_size =
input_size->data[0] / (block_shape[0] * block_shape[1]);
- const int output_height = input_size->data[1] * block_shape[0];
- const int output_width = input_size->data[2] * block_shape[1];
+
+ const int crops_top = crops[0];
+ const int crops_bottom = crops[1];
+ const int crops_left = crops[2];
+ const int crops_right = crops[3];
+ const int output_height =
+ input_size->data[1] * block_shape[0] - crops_top - crops_bottom;
+ const int output_width =
+ input_size->data[2] * block_shape[1] - crops_left - crops_right;
+
const int output_channel_size = input_size->data[3];
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);