diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-11 14:39:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-11 14:41:53 -0700 |
commit | 21fb4eeb3e09fb0dea1dd12b0fff7a7bf0a33643 (patch) | |
tree | 8d369725fa9f3af9f9178b6ccca9a9bd45b966e7 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | e5201672aa664cf39725f4a52b9774d2bae43ba3 (diff) |
Adding support for batch_to_space_nd op with crops.
PiperOrigin-RevId: 192511036
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 16 |
1 files changed, 7 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index a648b770f8..9191e69662 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1060,17 +1060,15 @@ void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { } QCHECK(crops_array.data_type == ArrayDataType::kInt32); const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data; - // We don't support crops now. - QCHECK_EQ(crops_data[0], 0); - QCHECK_EQ(crops_data[1], 0); - QCHECK_EQ(crops_data[2], 0); - QCHECK_EQ(crops_data[3], 0); - + const int crops_top = crops_data[0]; + const int crops_bottom = crops_data[1]; + const int crops_left = crops_data[2]; + const int crops_right = crops_data[3]; + const int output_height = + input_height * block_height - crops_top - crops_bottom; + const int output_width = input_width * block_width - crops_left - crops_right; QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0); - int output_height = input_height * block_height; - int output_width = input_width * block_width; - model->GetArray(op->outputs[0]) .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width), output_height, output_width, input_shape.dims(3)})); |