aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/transpose_conv.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-24 20:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 20:43:58 -0700
commit626fef2af7d4bc49aeeef7ffd195dc30235bcd1e (patch)
treef81c1a5b95696897957619b5635537c73942b8fe /tensorflow/contrib/lite/kernels/transpose_conv.cc
parent6ba60e051409a5346c2aab21160c9c311de1cb03 (diff)
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214377809
Diffstat (limited to 'tensorflow/contrib/lite/kernels/transpose_conv.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc21
1 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 6f2d98ede8..1c4a5ee91d 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -69,7 +69,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
- // Currenlty only supports float32.
+ // Currently only supports float32.
const TfLiteType data_type = input->type;
TF_LITE_ENSURE(context, data_type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, data_type);
@@ -117,19 +117,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
- case kTfLiteFloat32:
+ case kTfLiteFloat32: {
+ tflite::ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = padding_size.width;
+ op_params.padding_values.height = padding_size.height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
reference_ops::TransposeConv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
- stride_height, padding_size.width, padding_size.height,
- GetTensorData<float>(output), GetTensorDims(output),
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(output), GetTensorData<float>(output),
// Last two args specify im2col which reference_ops ignores.
// (Note this does not lead to a performance regression, as the
// previous optimized version was just a copy of the reference code.)
// TODO(b/110208176): Allocate im2col tensors and switch to
// optimized_ops.
- GetTensorData<float>(output), GetTensorDims(output));
+ GetTensorShape(output), GetTensorData<float>(output));
break;
+ }
default:
context->ReportError(context, "Type %d, not currently supported.",
input->type);