diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/transpose_conv.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/transpose_conv.cc | 112 |
1 files changed, 17 insertions, 95 deletions
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index 7182374a6f..a9baa5c698 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <unistd.h> #include <cassert> #include <cmath> #include <cstdio> @@ -22,7 +21,6 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" @@ -39,35 +37,9 @@ constexpr int kWeightsTensor = 1; constexpr int kDataInputTensor = 2; constexpr int kOutputTensor = 0; -const int kTensorNotAllocated = -1; - -struct OpData { - // IDs are the arbitrary identifiers used by TF Lite to identify and access - // memory buffers. - int im2col_id = kTensorNotAllocated; - - // im2col is the only temporary currently tracked, therefore always index 0. - // If more temporaries are added, they should be properly tracked. - int32_t im2col_index = 0; -}; - -void* Init(TfLiteContext* context, const char* buffer, size_t length) { - // This is a builtin op, so we don't use the contents in 'buffer', if any. - // Instead, we allocate a new object to use as scratch space for im2col, and - // to carry information from Prepare() to Eval(). - auto* data = new OpData; - eigen_support::IncrementUsageCounter(context); - return data; -} - -void Free(TfLiteContext* context, void* buffer) { - eigen_support::DecrementUsageCounter(context); - delete reinterpret_cast<OpData*>(buffer); -} - -TfLiteStatus ResizeOutputTensor(TfLiteContext* context, - const TfLiteTensor* output_shape, - TfLiteTensor* output) { +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { // Currently only support int32 for output shape. if (output_shape->type != kTfLiteInt32) { context->ReportError(context, "Output shape is %d, not int32.", @@ -83,60 +55,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, return context->ResizeTensor(context, output, output_shape_array); } -// Allocate temporary im2col tensor. -static TfLiteStatus AllocateIm2colTensor(TfLiteContext* context, - TfLiteNode* node) { - OpData* data = reinterpret_cast<OpData*>(node->user_data); - if (data->im2col_id == kTensorNotAllocated) { - context->AddTensors(context, 1, &data->im2col_id); - } - - TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[data->im2col_index] = data->im2col_id; - - return kTfLiteOk; -} - -TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context, - const TfLiteTensor* output_shape, - const TfLiteTensor* weights, - const TfLiteTensor* input, - TfLiteTensor* im2col) { - if (output_shape->type != kTfLiteInt32) { - context->ReportError(context, "im2col shape is %d, not int32.", - output_shape->type); - return kTfLiteError; - } - TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4); - TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4); - im2col_shape_array->data[0] = output_shape->data.i32[0]; - im2col_shape_array->data[1] = output_shape->data.i32[1]; - im2col_shape_array->data[2] = output_shape->data.i32[2]; - const int input_depth = SizeOfDimension(input, 3); - const int filter_width = SizeOfDimension(weights, 1); - const int filter_height = SizeOfDimension(weights, 2); - im2col_shape_array->data[3] = input_depth * filter_height * filter_width; - - im2col->type = input->type; - im2col->allocation_type = kTfLiteArenaRw; - return context->ResizeTensor(context, im2col, im2col_shape_array); -} - TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TF_LITE_ENSURE_STATUS(AllocateIm2colTensor(context, node)); - const TfLiteTensor* output_shape = GetInput(context, node, kOutputShapeTensor); const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - OpData* user_data = reinterpret_cast<OpData*>(node->user_data); - TfLiteTensor* im2col = - &context->tensors[node->temporaries->data[user_data->im2col_index]]; TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); @@ -153,15 +80,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), SizeOfDimension(weights, 3)); - if (IsConstantTensor(output_shape)) { - TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output)); - TF_LITE_ENSURE_STATUS( - ResizeIm2ColTensor(context, output_shape, weights, input, im2col)); - } else { - // Defer resizing until Eval(). + if (!IsConstantTensor(output_shape)) { SetTensorToDynamic(output); + return kTfLiteOk; } - return kTfLiteOk; + return ResizeOutputShape(context, output_shape, output); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -170,19 +93,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - OpData* user_data = reinterpret_cast<OpData*>(node->user_data); - TfLiteTensor* im2col = - &context->tensors[node->temporaries->data[user_data->im2col_index]]; + const auto* params = reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data); if (IsDynamicTensor(output)) { TF_LITE_ENSURE_OK(context, - ResizeOutputTensor(context, output_shape, output)); - } - if (IsDynamicTensor(im2col)) { - TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape, - weights, input, im2col)); + ResizeOutputShape(context, output_shape, output)); } // Get height and width of the output image. @@ -201,12 +118,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently only support float32. switch (input->type) { case kTfLiteFloat32: - optimized_ops::TransposeConv( + reference_ops::TransposeConv( GetTensorData<float>(input), GetTensorDims(input), GetTensorData<float>(weights), GetTensorDims(weights), stride_width, stride_height, padding_size.width, padding_size.height, GetTensorData<float>(output), GetTensorDims(output), - GetTensorData<float>(im2col), GetTensorDims(im2col)); + // Last two args specify im2col which reference_ops ignores. + // (Note this does not lead to a performance regression, as the + // previous optimized version was just a copy of the reference code.) + // TODO(b/110208176): Allocate im2col tensors and switch to + // optimized_ops. + GetTensorData<float>(output), GetTensorDims(output)); break; default: context->ReportError(context, "Type %d, not currently supported.", @@ -219,8 +141,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace transpose_conv TfLiteRegistration* Register_TRANSPOSE_CONV() { - static TfLiteRegistration r = {transpose_conv::Init, transpose_conv::Free, - transpose_conv::Prepare, transpose_conv::Eval}; + static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare, + transpose_conv::Eval}; return &r; } |