diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc | 74 |
1 files changed, 37 insertions, 37 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc index 7600b26f5c..41862a21a6 100644 --- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc @@ -29,28 +29,20 @@ namespace tflite { namespace { // Runs the DepthwiseConv and compares against the reference implementation. -void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride, int dilation_width_factor, - int dilation_height_factor, int pad_width, - int pad_height, int depth_multiplier, - float output_activation_min, - float output_activation_max, - const Dims<4>& output_dims) { - const int output_buffer_size = RequiredBufferSizeForDims(output_dims); +void TestOneDepthwiseConv( + const DepthwiseParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& filter_shape, + const float* filter_data, const RuntimeShape& bias_shape, + const float* bias_data, const RuntimeShape& output_shape) { + const int output_buffer_size = output_shape.FlatSize(); std::vector<float> output_data(output_buffer_size); std::vector<float> reference_output_data(output_buffer_size); - reference_ops::DepthwiseConv( - input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, - stride, stride, dilation_width_factor, dilation_height_factor, pad_width, - pad_height, depth_multiplier, output_activation_min, - output_activation_max, reference_output_data.data(), output_dims); - optimized_ops::DepthwiseConv( - input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, - stride, stride, dilation_width_factor, dilation_height_factor, pad_width, - pad_height, depth_multiplier, output_activation_min, - output_activation_max, output_data.data(), output_dims); + reference_ops::DepthwiseConv(params, input_shape, input_data, filter_shape, + filter_data, bias_shape, bias_data, output_shape, + reference_output_data.data()); + optimized_ops::DepthwiseConv(params, input_shape, input_data, filter_shape, + filter_data, bias_shape, bias_data, output_shape, + output_data.data()); double sum_abs_diff = 0; float max_abs_val = 0; @@ -105,24 +97,23 @@ bool TryTestOneDepthwiseConv() { if (output_depth > kMaxSupportedOutputDepth) { return false; } - Dims<4> input_dims_inference = - MakeDimsForInference(input_depth, input_width, input_height, batch); - Dims<4> output_dims_inference; + RuntimeShape input_shape_inference( + {batch, input_height, input_width, input_depth}); + RuntimeShape output_shape_inference; int pad_width, pad_height; const auto padding_type = UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; - if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width, + if (!ComputeConvSizes(input_shape_inference, output_depth, filter_width, filter_height, stride, dilation_width_factor, dilation_height_factor, padding_type, - &output_dims_inference, &pad_width, &pad_height)) { + &output_shape_inference, &pad_width, &pad_height)) { return false; } - Dims<4> filter_dims_inference = - MakeDimsForInference(output_depth, filter_width, filter_height, 1); - Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1); - const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); - const int filter_buffer_size = - RequiredBufferSizeForDims(filter_dims_inference); + RuntimeShape filter_shape_inference( + {1, filter_height, filter_width, output_depth}); + RuntimeShape bias_shape_inference({1, 1, 1, output_depth}); + const int input_buffer_size = input_shape_inference.FlatSize(); + const int filter_buffer_size = filter_shape_inference.FlatSize(); std::vector<float> input_data(input_buffer_size); std::vector<float> filter_data(filter_buffer_size); std::vector<float> bias_data(output_depth); @@ -133,12 +124,21 @@ bool TryTestOneDepthwiseConv() { FillRandom(&input_data, -input_amplitude, input_amplitude); FillRandom(&filter_data, -filter_amplitude, filter_amplitude); FillRandom(&bias_data, -bias_amplitude, bias_amplitude); - TestOneDepthwiseConv(input_data.data(), input_dims_inference, - filter_data.data(), filter_dims_inference, - bias_data.data(), bias_dims_inference, stride, - dilation_width_factor, dilation_height_factor, pad_width, - pad_height, depth_multiplier, output_activation_min, - output_activation_max, output_dims_inference); + DepthwiseParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = pad_width; + op_params.padding_values.height = pad_height; + op_params.stride_width = stride; + op_params.stride_height = stride; + op_params.dilation_width_factor = dilation_width_factor; + op_params.dilation_height_factor = dilation_height_factor; + op_params.depth_multiplier = depth_multiplier; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + TestOneDepthwiseConv(op_params, input_shape_inference, input_data.data(), + filter_shape_inference, filter_data.data(), + bias_shape_inference, bias_data.data(), + output_shape_inference); return true; } |