diff options
author | 2018-01-10 08:19:48 -0800 | |
---|---|---|
committer | 2018-01-10 08:25:57 -0800 | |
commit | 7255b9819f72b681aa66876ef0bd5ddfe67099f4 (patch) | |
tree | eff850e4cf5a9ef0a8253a797ba54e00cba024a9 /tensorflow/contrib/lite/kernels/pad.cc | |
parent | f0ed7bc454e1f24b4c984416b2fbac3a13883cd0 (diff) |
Add support for more types for Pad.
PiperOrigin-RevId: 181467627
Diffstat (limited to 'tensorflow/contrib/lite/kernels/pad.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/pad.cc | 84 |
1 files changed, 52 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 5e90282a43..1a0d9d1505 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -54,6 +54,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); int dims = NumDimensions(op_context.input); TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. TF_LITE_ENSURE_EQ(context, dims, 4); @@ -77,41 +78,61 @@ template <KernelType kernel_type> TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); - // TODO(nupurgarg): Support different data types. - if (op_context.output->type == kTfLiteFloat32) { - std::vector<int> before_padding( - op_context.params->before_padding, - op_context.params->before_padding + op_context.params->num_dimensions); - std::vector<int> after_padding( - op_context.params->after_padding, - op_context.params->after_padding + op_context.params->num_dimensions); - - // TODO(nupurgarg): Change TOCO's implementation to use padding arrays - // in forward order (depth, width, height, batch). - // Converts from int[] = {depth, width, height, batch} to int[] = {batch, - // height, width, depth} to match TOCO's implementation of pad in - // referenced_ops.h and optimized_ops.h. - std::reverse(before_padding.begin(), before_padding.end()); - std::reverse(after_padding.begin(), after_padding.end()); - -#define TF_LITE_PAD(type) \ - type::Pad(GetTensorData<float>(op_context.input), \ + std::vector<int> before_padding( + op_context.params->before_padding, + op_context.params->before_padding + op_context.params->num_dimensions); + std::vector<int> after_padding( + op_context.params->after_padding, + op_context.params->after_padding + op_context.params->num_dimensions); + + // TODO(nupurgarg): Change TOCO's implementation to use padding arrays + // in forward order (depth, width, height, batch). + // Converts from int[] = {depth, width, height, batch} to int[] = {batch, + // height, width, depth} to match TOCO's implementation of pad in + // referenced_ops.h and optimized_ops.h. + std::reverse(before_padding.begin(), before_padding.end()); + std::reverse(after_padding.begin(), after_padding.end()); + +#define TF_LITE_PAD(type, scalar) \ + type::Pad(GetTensorData<scalar>(op_context.input), \ GetTensorDims(op_context.input), before_padding, after_padding, \ - GetTensorData<float>(op_context.output), \ + GetTensorData<scalar>(op_context.output), \ GetTensorDims(op_context.output)) - if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops); - } - if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops); - } -#undef TF_LITE_PAD - } else { - context->ReportError(context, "Inputs and outputs not all float types."); - return kTfLiteError; + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, float); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, uint8_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, int32_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, int64_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, "Type is currently not supported by Pad."); + return kTfLiteError; } - +#undef TF_LITE_PAD return kTfLiteOk; } @@ -131,7 +152,6 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() { TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); - // return Register_PAD_REF(); } } // namespace builtin |