aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/pad.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-04 18:49:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-05 08:45:32 -0700
commit5fb53fe69afe7f9106a8bcb5632cea23cf227d78 (patch)
tree2658e05fa2481666efbea50c56909abfec3f938f /tensorflow/contrib/lite/kernels/pad.cc
parentdd5ef1b9fc22b37e5eec87d659a3af064ca54b8b (diff)
add support for PadV2
PiperOrigin-RevId: 195503894
Diffstat (limited to 'tensorflow/contrib/lite/kernels/pad.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc110
1 files changed, 81 insertions, 29 deletions
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 4f9449a225..9e1e4658e9 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -37,9 +37,15 @@ struct PadContext {
PadContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, 0);
paddings = GetInput(context, node, 1);
+ if (NumInputs(node) == 3) {
+ constant_values = GetOptionalInputTensor(context, node, 2);
+ } else {
+ constant_values = nullptr;
+ }
output = GetOutput(context, node, 0);
dims = NumDimensions(input);
}
+ TfLiteTensor* constant_values;
TfLiteTensor* input;
TfLiteTensor* paddings;
TfLiteTensor* output;
@@ -76,11 +82,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
PadContext op_context(context, node);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ if (op_context.constant_values != nullptr) {
+ TF_LITE_ENSURE_EQ(context, op_context.input->type,
+ op_context.constant_values->type);
+ }
// TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
@@ -98,6 +108,11 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
+ if (op_context.constant_values != nullptr) {
+ // Ensure that constant_values is a scalar.
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
+ }
+
// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
@@ -119,48 +134,70 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- type::Pad(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ type::PadV2(GetTensorData<scalar>(op_context.input), \
+ GetTensorDims(op_context.input), before_padding, after_padding, \
+ GetTensorData<scalar>(op_context.output), \
+ GetTensorDims(op_context.output), pad_value)
switch (op_context.input->type) {
- case kTfLiteFloat32:
+ case kTfLiteFloat32: {
+ float pad_value = op_context.constant_values == nullptr
+ ? 0.f
+ : *GetTensorData<float>(op_context.constant_values);
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, float, 0);
+ TF_LITE_PAD(reference_ops, float, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, float, 0);
+ TF_LITE_PAD(optimized_ops, float, pad_value);
+ }
+ } break;
+ case kTfLiteUInt8: {
+ uint8_t pad_value;
+ if (op_context.constant_values == nullptr) {
+ // Quantized Pad requires that 0 is represented in the quantized
+ // range.
+ TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
+ std::numeric_limits<uint8_t>::min());
+ TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
+ std::numeric_limits<uint8_t>::max());
+ pad_value = static_cast<uint8_t>(op_context.output->params.zero_point);
+ } else {
+ // Quantized Pad requires that 'constant_values' is represented in the
+ // same quantized range as the input and output tensors.
+ TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
+ op_context.constant_values->params.zero_point);
+ TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
+ op_context.constant_values->params.scale);
+ pad_value = *GetTensorData<uint8_t>(op_context.constant_values);
}
- break;
- case kTfLiteUInt8:
- // Quantized Pad requires that 0 is represented in the quantized range.
- TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
- std::numeric_limits<uint8_t>::min());
- TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
- std::numeric_limits<uint8_t>::max());
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, uint8_t,
- op_context.output->params.zero_point);
+ TF_LITE_PAD(reference_ops, uint8_t, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, uint8_t,
- op_context.output->params.zero_point);
+ TF_LITE_PAD(optimized_ops, uint8_t, pad_value);
}
- break;
- case kTfLiteInt32:
+ } break;
+ case kTfLiteInt32: {
+ int32_t pad_value =
+ op_context.constant_values == nullptr
+ ? 0
+ : *GetTensorData<int32_t>(op_context.constant_values);
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, int32_t, 0);
+ TF_LITE_PAD(reference_ops, int32_t, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, int32_t, 0);
+ TF_LITE_PAD(optimized_ops, int32_t, pad_value);
}
- break;
- case kTfLiteInt64:
+ } break;
+ case kTfLiteInt64: {
+ int64_t pad_value =
+ op_context.constant_values == nullptr
+ ? 0L
+ : *GetTensorData<int64_t>(op_context.constant_values);
if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops, int64_t, 0);
+ TF_LITE_PAD(reference_ops, int64_t, pad_value);
} else if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops, int64_t, 0);
+ TF_LITE_PAD(optimized_ops, int64_t, pad_value);
}
- break;
+ } break;
default:
context->ReportError(context, "Type is currently not supported by Pad.");
return kTfLiteError;
@@ -185,6 +222,21 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() {
TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); }
+// Also register Pad as PadV2.
+TfLiteRegistration* Register_PADV2_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
+ pad::Eval<pad::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_PADV2_GENERIC_OPT() {
+ static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
+ pad::Eval<pad::kGenericOptimized>};
+ return &r;
+}
+
+TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); }
+
} // namespace builtin
} // namespace ops
} // namespace tflite