diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-10 16:09:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-10 16:12:07 -0700 |
commit | 874cf8e1d332175c8a90d7512f8385e98e2a7377 (patch) | |
tree | 129149c6c73e828246b7202dccdfb4d9ce583de3 /tensorflow/contrib/lite/kernels/batch_to_space_nd.cc | |
parent | 66b6dda1b77cbf075e94009718446511fa13dd41 (diff) |
Enable support for crops in BatchToSpaceNd
PiperOrigin-RevId: 196186750
Diffstat (limited to 'tensorflow/contrib/lite/kernels/batch_to_space_nd.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/batch_to_space_nd.cc | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 90edf4f9e3..bd4057556c 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -66,12 +66,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), kSpatialDimensionNum); - // TODO(ycling): Add crops as part of calculation. Remove check for a crops - // containing all zeroes. - TF_LITE_ENSURE_EQ(context, crops[0], 0); - TF_LITE_ENSURE_EQ(context, crops[1], 0); - TF_LITE_ENSURE_EQ(context, crops[2], 0); - TF_LITE_ENSURE_EQ(context, crops[3], 0); + TF_LITE_ENSURE(context, crops[0] >= 0); + TF_LITE_ENSURE(context, crops[1] >= 0); + TF_LITE_ENSURE(context, crops[2] >= 0); + TF_LITE_ENSURE(context, crops[3] >= 0); // Number of batch must be multiple of (block_shape[0] * block_shape[1]). TF_LITE_ENSURE_EQ(context, @@ -79,8 +77,16 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, const int output_batch_size = input_size->data[0] / (block_shape[0] * block_shape[1]); - const int output_height = input_size->data[1] * block_shape[0]; - const int output_width = input_size->data[2] * block_shape[1]; + + const int crops_top = crops[0]; + const int crops_bottom = crops[1]; + const int crops_left = crops[2]; + const int crops_right = crops[3]; + const int output_height = + input_size->data[1] * block_shape[0] - crops_top - crops_bottom; + const int output_width = + input_size->data[2] * block_shape[1] - crops_left - crops_right; + const int output_channel_size = input_size->data[3]; TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); |