aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/transpose_conv.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 11:40:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 11:43:11 -0700
commit8e4c4144817bea5ffd9255df48a78740fdb14f57 (patch)
tree91595cd3f71825b5f54210a8fb735df506bc48fa /tensorflow/contrib/lite/kernels/transpose_conv.cc
parent8f7afe01a583058726b03a0d849add35fcde41a3 (diff)
Optimized implementation of transpose conv. Uses an im2col array and GEMM, similar to conv.
PiperOrigin-RevId: 200592004
Diffstat (limited to 'tensorflow/contrib/lite/kernels/transpose_conv.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index e83b1ec987..8b9deeed20 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -119,10 +119,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
case kTfLiteFloat32:
- optimized_ops::TransposeConv(
+ reference_ops::TransposeConv(
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
stride_height, padding_size.width, padding_size.height,
+ GetTensorData<float>(output), GetTensorDims(output),
+ // Last two args specify im2col which reference_ops ignores.
+ // (Note this does not lead to a performance regression, as the
+ // previous optimized version was just a copy of the reference code.)
+ // TODO(b/110208176): Allocate im2col tensors and switch to
+ // optimized_ops.
GetTensorData<float>(output), GetTensorDims(output));
break;
default: