diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-04 18:49:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-05 08:45:32 -0700 |
commit | 5fb53fe69afe7f9106a8bcb5632cea23cf227d78 (patch) | |
tree | 2658e05fa2481666efbea50c56909abfec3f938f /tensorflow/contrib/lite/kernels/pad.cc | |
parent | dd5ef1b9fc22b37e5eec87d659a3af064ca54b8b (diff) |
add support for PadV2
PiperOrigin-RevId: 195503894
Diffstat (limited to 'tensorflow/contrib/lite/kernels/pad.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/pad.cc | 110 |
1 file changed, 81 insertions, 29 deletions
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 4f9449a225..9e1e4658e9 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -37,9 +37,15 @@ struct PadContext { PadContext(TfLiteContext* context, TfLiteNode* node) { input = GetInput(context, node, 0); paddings = GetInput(context, node, 1); + if (NumInputs(node) == 3) { + constant_values = GetOptionalInputTensor(context, node, 2); + } else { + constant_values = nullptr; + } output = GetOutput(context, node, 0); dims = NumDimensions(input); } + TfLiteTensor* constant_values; TfLiteTensor* input; TfLiteTensor* paddings; TfLiteTensor* output; @@ -76,11 +82,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); PadContext op_context(context, node); TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + if (op_context.constant_values != nullptr) { + TF_LITE_ENSURE_EQ(context, op_context.input->type, + op_context.constant_values->type); + } // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. TF_LITE_ENSURE_EQ(context, op_context.dims, 4); @@ -98,6 +108,11 @@ template <KernelType kernel_type> TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); + if (op_context.constant_values != nullptr) { + // Ensure that constant_values is a scalar. + TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1); + } + // Resize the output tensor if the output tensor is dynamic. 
if (IsDynamicTensor(op_context.output)) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); @@ -119,48 +134,70 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { after_padding.push_back(paddings_data[idx * 2 + 1]); } -#define TF_LITE_PAD(type, scalar, pad_value) \ - type::Pad(GetTensorData<scalar>(op_context.input), \ - GetTensorDims(op_context.input), before_padding, after_padding, \ - GetTensorData<scalar>(op_context.output), \ - GetTensorDims(op_context.output), pad_value) +#define TF_LITE_PAD(type, scalar, pad_value) \ + type::PadV2(GetTensorData<scalar>(op_context.input), \ + GetTensorDims(op_context.input), before_padding, after_padding, \ + GetTensorData<scalar>(op_context.output), \ + GetTensorDims(op_context.output), pad_value) switch (op_context.input->type) { - case kTfLiteFloat32: + case kTfLiteFloat32: { + float pad_value = op_context.constant_values == nullptr + ? 0.f + : *GetTensorData<float>(op_context.constant_values); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, float, 0); + TF_LITE_PAD(reference_ops, float, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, float, 0); + TF_LITE_PAD(optimized_ops, float, pad_value); + } + } break; + case kTfLiteUInt8: { + uint8_t pad_value; + if (op_context.constant_values == nullptr) { + // Quantized Pad requires that 0 is represented in the quantized + // range. + TF_LITE_ENSURE(context, op_context.output->params.zero_point >= + std::numeric_limits<uint8_t>::min()); + TF_LITE_ENSURE(context, op_context.output->params.zero_point <= + std::numeric_limits<uint8_t>::max()); + pad_value = static_cast<uint8_t>(op_context.output->params.zero_point); + } else { + // Quantized Pad requires that 'constant_values' is represented in the + // same quantized range as the input and output tensors. 
+ TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, + op_context.constant_values->params.zero_point); + TF_LITE_ENSURE_EQ(context, op_context.output->params.scale, + op_context.constant_values->params.scale); + pad_value = *GetTensorData<uint8_t>(op_context.constant_values); } - break; - case kTfLiteUInt8: - // Quantized Pad requires that 0 is represented in the quantized range. - TF_LITE_ENSURE(context, op_context.output->params.zero_point >= - std::numeric_limits<uint8_t>::min()); - TF_LITE_ENSURE(context, op_context.output->params.zero_point <= - std::numeric_limits<uint8_t>::max()); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, uint8_t, - op_context.output->params.zero_point); + TF_LITE_PAD(reference_ops, uint8_t, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, uint8_t, - op_context.output->params.zero_point); + TF_LITE_PAD(optimized_ops, uint8_t, pad_value); } - break; - case kTfLiteInt32: + } break; + case kTfLiteInt32: { + int32_t pad_value = + op_context.constant_values == nullptr + ? 0 + : *GetTensorData<int32_t>(op_context.constant_values); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, int32_t, 0); + TF_LITE_PAD(reference_ops, int32_t, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, int32_t, 0); + TF_LITE_PAD(optimized_ops, int32_t, pad_value); } - break; - case kTfLiteInt64: + } break; + case kTfLiteInt64: { + int64_t pad_value = + op_context.constant_values == nullptr + ? 
0L + : *GetTensorData<int64_t>(op_context.constant_values); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, int64_t, 0); + TF_LITE_PAD(reference_ops, int64_t, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, int64_t, 0); + TF_LITE_PAD(optimized_ops, int64_t, pad_value); } - break; + } break; default: context->ReportError(context, "Type is currently not supported by Pad."); return kTfLiteError; @@ -185,6 +222,21 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() { TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); } +// Also register Pad as PadV2. +TfLiteRegistration* Register_PADV2_REF() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval<pad::kReference>}; + return &r; +} + +TfLiteRegistration* Register_PADV2_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval<pad::kGenericOptimized>}; + return &r; +} + +TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); } + } // namespace builtin } // namespace ops } // namespace tflite |