diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/test_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/test_util.cc | 40 |
1 files changed, 9 insertions, 31 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc index 5ae4b193d0..75d568ae3a 100644 --- a/tensorflow/contrib/lite/kernels/internal/test_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc @@ -19,36 +19,15 @@ limitations under the License. namespace tflite { -Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) { - Dims<4> result; - int cum_prod = 1; - - result.sizes[0] = depth; - result.strides[0] = cum_prod; - cum_prod *= result.sizes[0]; - - result.sizes[1] = width; - result.strides[1] = cum_prod; - cum_prod *= result.sizes[1]; - - result.sizes[2] = height; - result.strides[2] = cum_prod; - cum_prod *= result.sizes[2]; - - result.sizes[3] = batch; - result.strides[3] = cum_prod; - - return result; -} - // this is a copied from an internal function in propagate_fixed_sizes.cc -bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, - int filter_height, int stride, int dilation_width_factor, - int dilation_height_factor, PaddingType padding_type, - Dims<4>* output_dims, int* pad_width, int* pad_height) { - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int batch = ArraySize(input_dims, 3); +bool ComputeConvSizes(const RuntimeShape& input_shape, int output_depth, + int filter_width, int filter_height, int stride, + int dilation_width_factor, int dilation_height_factor, + PaddingType padding_type, RuntimeShape* output_shape, + int* pad_width, int* pad_height) { + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int batch = input_shape.Dims(0); int dilated_filter_width = dilation_width_factor * (filter_width - 1) + 1; int dilated_filter_height = dilation_height_factor * (filter_height - 1) + 1; @@ -76,8 +55,7 @@ bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, 0, ((output_width - 1) * stride + dilated_filter_width - input_width) / 2); - *output_dims = - MakeDimsForInference(output_depth, output_width, output_height, batch); + output_shape->BuildFrom({batch, output_height, output_width, output_depth}); return true; } |