diff options
author | 2018-09-24 20:39:41 -0700 | |
---|---|---|
committer | 2018-09-24 20:43:58 -0700 | |
commit | 626fef2af7d4bc49aeeef7ffd195dc30235bcd1e (patch) | |
tree | f81c1a5b95696897957619b5635537c73942b8fe /tensorflow/contrib/lite/kernels/transpose_conv.cc | |
parent | 6ba60e051409a5346c2aab21160c9c311de1cb03 (diff) |
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214377809
Diffstat (limited to 'tensorflow/contrib/lite/kernels/transpose_conv.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/transpose_conv.cc | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index 6f2d98ede8..1c4a5ee91d 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -69,7 +69,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4); - // Currenlty only supports float32. + // Currently only supports float32. const TfLiteType data_type = input->type; TF_LITE_ENSURE(context, data_type == kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, output->type, data_type); @@ -117,19 +117,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently only support float32. switch (input->type) { - case kTfLiteFloat32: + case kTfLiteFloat32: { + tflite::ConvParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = padding_size.width; + op_params.padding_values.height = padding_size.height; + op_params.stride_width = stride_width; + op_params.stride_height = stride_height; + reference_ops::TransposeConv( - GetTensorData<float>(input), GetTensorDims(input), - GetTensorData<float>(weights), GetTensorDims(weights), stride_width, - stride_height, padding_size.width, padding_size.height, - GetTensorData<float>(output), GetTensorDims(output), + op_params, GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(weights), GetTensorData<float>(weights), + GetTensorShape(output), GetTensorData<float>(output), // Last two args specify im2col which reference_ops ignores. // (Note this does not lead to a performance regression, as the // previous optimized version was just a copy of the reference code.) // TODO(b/110208176): Allocate im2col tensors and switch to // optimized_ops. - GetTensorData<float>(output), GetTensorDims(output)); + GetTensorShape(output), GetTensorData<float>(output)); break; + } default: context->ReportError(context, "Type %d, not currently supported.", input->type); |