author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-30 13:58:05 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-01-30 14:08:45 -0800
commit    36f3a3b31ea9c32c64f1f4af543c75692d338876 (patch)
tree      b1a9882c5becea0fec90e07f9cc8d58ec457f317 /tensorflow/contrib/lite/kernels/add.cc
parent    548df15375488fc06ff663670f88734f3ece4814 (diff)
Add and Mul support broadcasting.
PiperOrigin-RevId: 183886920
Diffstat (limited to 'tensorflow/contrib/lite/kernels/add.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/add.cc  102
1 file changed, 67 insertions(+), 35 deletions(-)
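
Editor's note: the commit message states that Add (and Mul) now support broadcasting. In the diff below, Prepare calls CalculateShapeForBroadcast when the two inputs do not have the same shape, and otherwise copies input1->dims. As a rough illustration of the usual right-aligned, numpy-style broadcasting rule that such a helper follows, here is a standalone sketch; it is the editor's assumption of the rule, not the TFLite CalculateShapeForBroadcast implementation, which may differ in details and error reporting.

// Sketch only: right-aligned broadcasting of two shapes. A dimension of
// size 1 stretches to match the other input; missing leading dims count as 1.
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<int> BroadcastShape(const std::vector<int>& a,
                                const std::vector<int>& b) {
  const size_t out_dims = std::max(a.size(), b.size());
  std::vector<int> out(out_dims);
  for (size_t i = 0; i < out_dims; ++i) {
    // Walk both shapes from the trailing dimension outward.
    const int da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      throw std::runtime_error("shapes are not broadcast-compatible");
    }
    out[out_dims - 1 - i] = std::max(da, db);
  }
  return out;
}

int main() {
  // e.g. adding a [1, 2, 2, 1] tensor and a shape-[1] tensor.
  for (int d : BroadcastShape({1, 2, 2, 1}, {1})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 1 2 2 1
}

When the shapes already match, requires_broadcast stays false and Prepare simply copies input1->dims, as the else branch in the hunk below shows.
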
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index fb5764f280..63ea89df56 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
- for (int i = 0; i < NumDimensions(input1); ++i) {
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
- SizeOfDimension(input2, i));
- }
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+ output->type = input2->type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
- TF_LITE_ENSURE_EQ(context, input1->type, output->type);
- TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
- TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteAddParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
-#define TF_LITE_ADD(type) \
- type::Add(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_ADD(type, opname) \
+ type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd);
+ } else {
+ TF_LITE_ADD(reference_ops, Add);
+ }
} else {
- TF_LITE_ADD(optimized_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add);
+ }
}
#undef TF_LITE_ADD
}
template <KernelType kernel_type>
void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteAddParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
auto output_offset = output->params.zero_point;
@@ -112,19 +141,20 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRangeUint8(params->activation, output,
&output_activation_min, &output_activation_max);
-#define TF_LITE_ADD(type) \
- type::BroadcastAdd( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), input2_offset, \
- input2_multiplier, input2_shift, output_offset, output_multiplier, \
- output_shift, output_activation_min, output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
-
+#define TF_LITE_ADD(type, opname) \
+ type::opname(left_shift, GetTensorData<uint8_t>(input1), \
+ GetTensorDims(input1), input1_offset, input1_multiplier, \
+ input1_shift, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), input2_offset, input2_multiplier, \
+ input2_shift, output_offset, output_multiplier, output_shift, \
+ output_activation_min, output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+ // The quantized version of Add doesn't support activations, so we
+ // always use BroadcastAdd.
if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops);
+ TF_LITE_ADD(reference_ops, BroadcastAdd);
} else {
- TF_LITE_ADD(optimized_ops);
+ TF_LITE_ADD(optimized_ops, BroadcastAdd);
}
#undef TF_LITE_ADD
}
@@ -132,15 +162,17 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
- EvalAddFloat<kernel_type>(context, node, params, input1, input2, output);
+ EvalAddFloat<kernel_type>(context, node, params, data, input1, input2,
+ output);
} else if (output->type == kTfLiteUInt8) {
- EvalAddQuantized<kernel_type>(context, node, params, input1, input2,
+ EvalAddQuantized<kernel_type>(context, node, params, data, input1, input2,
output);
} else {
context->ReportError(context,
@@ -154,19 +186,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace add
TfLiteRegistration* Register_ADD_REF() {
- static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ static TfLiteRegistration r = {add::Init, add::Free, add::Prepare,
add::Eval<add::kReference>};
return &r;
}
TfLiteRegistration* Register_ADD_GENERIC_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ static TfLiteRegistration r = {add::Init, add::Free, add::Prepare,
add::Eval<add::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_ADD_NEON_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, add::Prepare,
+ static TfLiteRegistration r = {add::Init, add::Free, add::Prepare,
add::Eval<add::kNeonOptimized>};
return &r;
}
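
Editor's note on the registration change at the end of the diff: each TfLiteRegistration now supplies add::Init and add::Free in the slots that were previously nullptr, so the runtime owns one OpData per Add node; Prepare caches requires_broadcast there once, and Eval only reads the flag to choose between Add and BroadcastAdd. (The quantized path, per the comment in EvalAddQuantized, dispatches to BroadcastAdd unconditionally, which is harmless since broadcasting two identical shapes reduces to a plain elementwise add.) The standalone sketch below shows the same init/prepare/eval/free ownership pattern with made-up types (Registration, Node, MakeAddRegistration); these are the editor's illustration, not TFLite runtime APIs.

// Sketch of the per-node state pattern used in this diff; illustrative only.
#include <iostream>
#include <vector>

struct OpData { bool requires_broadcast = false; };

struct Node { void* user_data = nullptr; };

struct Registration {
  void* (*init)();
  void (*free)(void*);
  void (*prepare)(Node*, const std::vector<int>&, const std::vector<int>&);
  void (*eval)(Node*);
};

Registration MakeAddRegistration() {
  return {
      /*init=*/[]() -> void* { return new OpData; },
      /*free=*/[](void* d) { delete static_cast<OpData*>(d); },
      /*prepare=*/[](Node* node, const std::vector<int>& a,
                     const std::vector<int>& b) {
        // Compare shapes once and cache the result, like add::Prepare does.
        static_cast<OpData*>(node->user_data)->requires_broadcast = (a != b);
      },
      /*eval=*/[](Node* node) {
        // Eval only consults the cached flag to pick the kernel variant.
        auto* data = static_cast<OpData*>(node->user_data);
        std::cout << (data->requires_broadcast ? "BroadcastAdd\n" : "Add\n");
      }};
}

int main() {
  Registration reg = MakeAddRegistration();
  Node node;
  node.user_data = reg.init();            // runtime calls init once per node
  reg.prepare(&node, {1, 2, 2, 1}, {1});  // shapes differ -> broadcast needed
  reg.eval(&node);                        // prints "BroadcastAdd"
  reg.free(node.user_data);               // runtime calls free on teardown
}
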