author    Nupur Garg <nupurgarg@google.com>  2018-01-10 08:19:48 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-10 08:25:57 -0800
commit    7255b9819f72b681aa66876ef0bd5ddfe67099f4 (patch)
tree      eff850e4cf5a9ef0a8253a797ba54e00cba024a9 /tensorflow/contrib/lite/kernels/pad.cc
parent    f0ed7bc454e1f24b4c984416b2fbac3a13883cd0 (diff)
Add support for more types for Pad.
PiperOrigin-RevId: 181467627
Diffstat (limited to 'tensorflow/contrib/lite/kernels/pad.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/pad.cc  | 84
1 file changed, 52 insertions(+), 32 deletions(-)
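
At a glance: Eval previously handled only kTfLiteFloat32 and reported an error for everything else. The patch hoists the padding-array setup out of the float-only branch, parametrizes the TF_LITE_PAD macro by scalar type, and dispatches on the tensor's runtime type via a switch, adding uint8, int32, and int64 support. A matching input/output type check is also added to Prepare.

Below is a minimal, self-contained C++ sketch of that dispatch pattern. It is an illustration only: the enum is a trimmed stand-in, and PadImpl/PAD_DISPATCH are hypothetical names standing in for the real type::Pad call and the TF_LITE_PAD macro.

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for the TF Lite type enum; not the real header.
enum TfLiteType { kTfLiteFloat32, kTfLiteUInt8, kTfLiteInt32, kTfLiteInt64 };

template <typename T>
void PadImpl() {
  // Placeholder for type::Pad(GetTensorData<scalar>(...), ...).
  std::printf("padding with %zu-byte elements\n", sizeof(T));
}

// One macro parametrized by the scalar type, expanded once per switch case,
// mirroring TF_LITE_PAD(type, scalar) in the patch below.
#define PAD_DISPATCH(scalar) PadImpl<scalar>()

bool Pad(TfLiteType type) {
  switch (type) {
    case kTfLiteFloat32: PAD_DISPATCH(float);   break;
    case kTfLiteUInt8:   PAD_DISPATCH(uint8_t); break;
    case kTfLiteInt32:   PAD_DISPATCH(int32_t); break;
    case kTfLiteInt64:   PAD_DISPATCH(int64_t); break;
    default:
      return false;  // unsupported type, analogous to kTfLiteError
  }
  return true;
}
#undef PAD_DISPATCH

int main() { return Pad(kTfLiteInt64) ? 0 : 1; }

The macro keeps each case to one line while the switch pins down the scalar type at compile time for each runtime TfLiteType, which is why the patch threads the scalar through the macro rather than templating Eval itself.
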
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 5e90282a43..1a0d9d1505 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -54,6 +54,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
int dims = NumDimensions(op_context.input);
TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions);
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
// TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
TF_LITE_ENSURE_EQ(context, dims, 4);
@@ -77,41 +78,61 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
- // TODO(nupurgarg): Support different data types.
- if (op_context.output->type == kTfLiteFloat32) {
- std::vector<int> before_padding(
- op_context.params->before_padding,
- op_context.params->before_padding + op_context.params->num_dimensions);
- std::vector<int> after_padding(
- op_context.params->after_padding,
- op_context.params->after_padding + op_context.params->num_dimensions);
-
- // TODO(nupurgarg): Change TOCO's implementation to use padding arrays
- // in forward order (depth, width, height, batch).
- // Converts from int[] = {depth, width, height, batch} to int[] = {batch,
- // height, width, depth} to match TOCO's implementation of pad in
- // referenced_ops.h and optimized_ops.h.
- std::reverse(before_padding.begin(), before_padding.end());
- std::reverse(after_padding.begin(), after_padding.end());
-
-#define TF_LITE_PAD(type) \
- type::Pad(GetTensorData<float>(op_context.input), \
+ std::vector<int> before_padding(
+ op_context.params->before_padding,
+ op_context.params->before_padding + op_context.params->num_dimensions);
+ std::vector<int> after_padding(
+ op_context.params->after_padding,
+ op_context.params->after_padding + op_context.params->num_dimensions);
+
+ // TODO(nupurgarg): Change TOCO's implementation to use padding arrays
+ // in forward order (depth, width, height, batch).
+ // Converts from int[] = {depth, width, height, batch} to int[] = {batch,
+ // height, width, depth} to match TOCO's implementation of pad in
+ // referenced_ops.h and optimized_ops.h.
+ std::reverse(before_padding.begin(), before_padding.end());
+ std::reverse(after_padding.begin(), after_padding.end());
+
+#define TF_LITE_PAD(type, scalar) \
+ type::Pad(GetTensorData<scalar>(op_context.input), \
GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<float>(op_context.output), \
+ GetTensorData<scalar>(op_context.output), \
GetTensorDims(op_context.output))
- if (kernel_type == kReference) {
- TF_LITE_PAD(reference_ops);
- }
- if (kernel_type == kGenericOptimized) {
- TF_LITE_PAD(optimized_ops);
- }
-#undef TF_LITE_PAD
- } else {
- context->ReportError(context, "Inputs and outputs not all float types.");
- return kTfLiteError;
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_PAD(reference_ops, float);
+ } else if (kernel_type == kGenericOptimized) {
+ TF_LITE_PAD(optimized_ops, float);
+ }
+ break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_PAD(reference_ops, uint8_t);
+ } else if (kernel_type == kGenericOptimized) {
+ TF_LITE_PAD(optimized_ops, uint8_t);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_PAD(reference_ops, int32_t);
+ } else if (kernel_type == kGenericOptimized) {
+ TF_LITE_PAD(optimized_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_PAD(reference_ops, int64_t);
+ } else if (kernel_type == kGenericOptimized) {
+ TF_LITE_PAD(optimized_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context, "Type is currently not supported by Pad.");
+ return kTfLiteError;
}
-
+#undef TF_LITE_PAD
return kTfLiteOk;
}
@@ -131,7 +152,6 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() {
TfLiteRegistration* Register_PAD() {
return Register_PAD_GENERIC_OPT();
- // return Register_PAD_REF();
}
} // namespace builtin
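
A note on the std::reverse calls the patch keeps: TOCO supplies the padding counts in {depth, width, height, batch} order, while the pad implementations in reference_ops.h and optimized_ops.h expect {batch, height, width, depth}, so both arrays are reversed before dispatch (see the TODO in the diff). A small hypothetical example of that conversion:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Padding counts as TOCO supplies them: {depth, width, height, batch}.
  std::vector<int> before_padding = {0, 2, 1, 0};
  // Reverse into the {batch, height, width, depth} order that
  // reference_ops.h / optimized_ops.h expect.
  std::reverse(before_padding.begin(), before_padding.end());
  for (int p : before_padding) std::printf("%d ", p);  // prints: 0 1 2 0
  std::printf("\n");
  return 0;
}
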