diff options
author | 2018-06-14 11:40:28 -0700 | |
---|---|---|
committer | 2018-06-14 11:43:11 -0700 | |
commit | 8e4c4144817bea5ffd9255df48a78740fdb14f57 (patch) | |
tree | 91595cd3f71825b5f54210a8fb735df506bc48fa /tensorflow/contrib/lite/kernels/transpose_conv.cc | |
parent | 8f7afe01a583058726b03a0d849add35fcde41a3 (diff) |
Optimized implementation of transpose conv. Uses an im2col array and GEMM, similar to conv.
PiperOrigin-RevId: 200592004
Diffstat (limited to 'tensorflow/contrib/lite/kernels/transpose_conv.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/transpose_conv.cc | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index e83b1ec987..8b9deeed20 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -119,10 +119,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently only support float32. switch (input->type) { case kTfLiteFloat32: - optimized_ops::TransposeConv( + reference_ops::TransposeConv( GetTensorData<float>(input), GetTensorDims(input), GetTensorData<float>(weights), GetTensorDims(weights), stride_width, stride_height, padding_size.width, padding_size.height, + GetTensorData<float>(output), GetTensorDims(output), + // Last two args specify im2col which reference_ops ignores. + // (Note this does not lead to a performance regression, as the + // previous optimized version was just a copy of the reference code.) + // TODO(b/110208176): Allocate im2col tensors and switch to + // optimized_ops. GetTensorData<float>(output), GetTensorDims(output)); break; default: |