aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-11 14:39:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 14:41:53 -0700
commit21fb4eeb3e09fb0dea1dd12b0fff7a7bf0a33643 (patch)
tree8d369725fa9f3af9f9178b6ccca9a9bd45b966e7 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parente5201672aa664cf39725f4a52b9774d2bae43ba3 (diff)
Adding support for batch_to_space_nd op with crops.
PiperOrigin-RevId: 192511036
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc16
1 files changed, 7 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index a648b770f8..9191e69662 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1060,17 +1060,15 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
}
QCHECK(crops_array.data_type == ArrayDataType::kInt32);
const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
- // We don't support crops now.
- QCHECK_EQ(crops_data[0], 0);
- QCHECK_EQ(crops_data[1], 0);
- QCHECK_EQ(crops_data[2], 0);
- QCHECK_EQ(crops_data[3], 0);
-
+ const int crops_top = crops_data[0];
+ const int crops_bottom = crops_data[1];
+ const int crops_left = crops_data[2];
+ const int crops_right = crops_data[3];
+ const int output_height =
+ input_height * block_height - crops_top - crops_bottom;
+ const int output_width = input_width * block_width - crops_left - crops_right;
QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
- int output_height = input_height * block_height;
- int output_width = input_width * block_width;
-
model->GetArray(op->outputs[0])
.copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
output_height, output_width, input_shape.dims(3)}));