author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-30 13:58:05 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-30 14:08:45 -0800
commit 36f3a3b31ea9c32c64f1f4af543c75692d338876 (patch)
tree   b1a9882c5becea0fec90e07f9cc8d58ec457f317 /tensorflow/contrib/lite/kernels/mul.cc
parent 548df15375488fc06ff663670f88734f3ece4814 (diff)
Add and Mul support broadcasting.
PiperOrigin-RevId: 183886920
Diffstat (limited to 'tensorflow/contrib/lite/kernels/mul.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/mul.cc  99
1 file changed, 66 insertions(+), 33 deletions(-)
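
In the Prepare() hunk below, the strict same-shape check is replaced: when the two input shapes differ, the kernel records requires_broadcast and asks CalculateShapeForBroadcast for the output shape; otherwise output_size stays a copy of input1->dims, as before. As a rough illustration of what numpy-style broadcast shape inference does (a standalone sketch, not the TF Lite implementation; BroadcastShape is a name chosen here for illustration):

// Illustrative sketch only: numpy-style broadcast shape inference, roughly
// what CalculateShapeForBroadcast computes for Prepare() below.
// BroadcastShape is a hypothetical helper, not part of TF Lite.
#include <algorithm>
#include <cstdio>
#include <stdexcept>
#include <vector>

std::vector<int> BroadcastShape(const std::vector<int>& a,
                                const std::vector<int>& b) {
  const size_t out_dims = std::max(a.size(), b.size());
  std::vector<int> out(out_dims);
  for (size_t i = 0; i < out_dims; ++i) {
    // Align shapes from the trailing dimension; a missing dimension acts as 1.
    const int da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      throw std::runtime_error("shapes are not broadcast-compatible");
    }
    out[out_dims - 1 - i] = std::max(da, db);
  }
  return out;
}

int main() {
  // {1, 2, 2, 1} * {3} broadcasts to {1, 2, 2, 3}.
  for (int d : BroadcastShape({1, 2, 2, 1}, {3})) std::printf("%d ", d);
  std::printf("\n");
  return 0;
}
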
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 81c73f2523..54575019de 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2));
- for (int i = 0; i < NumDimensions(input1); ++i) {
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i),
- SizeOfDimension(input2, i));
- }
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+ output->type = input2->type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
- TF_LITE_ENSURE_EQ(context, input1->type, output->type);
- TF_LITE_ENSURE_EQ(context, input2->type, output->type);
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
- TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims);
return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteMulParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRangeFloat(params->activation, &output_activation_min,
&output_activation_max);
-#define TF_LITE_MUL(type) \
- type::Mul(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
+ GetTensorData<float>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<float>(output), GetTensorDims(output))
if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(reference_ops, BroadcastMul);
+ } else {
+ TF_LITE_MUL(reference_ops, Mul);
+ }
} else {
- TF_LITE_MUL(optimized_ops);
+ if (data->requires_broadcast) {
+ TF_LITE_MUL(optimized_ops, BroadcastMul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
}
#undef TF_LITE_MUL
}
template <KernelType kernel_type>
void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, TfLiteTensor* input1,
- TfLiteTensor* input2, TfLiteTensor* output) {
+ TfLiteMulParams* params, const OpData* data,
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
auto input1_offset = -input1->params.zero_point;
auto input2_offset = -input2->params.zero_point;
auto output_offset = output->params.zero_point;
@@ -98,17 +127,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRangeUint8(params->activation, output,
&output_activation_min, &output_activation_max);
-#define TF_LITE_MUL(type) \
- type::BroadcastMul(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, output_offset, \
- output_multiplier, output_shift, output_activation_min, \
- output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ input1_offset, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), input2_offset, output_offset, \
+ output_multiplier, output_shift, output_activation_min, \
+ output_activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output));
+ // The quantized version of Mul doesn't support activations, so we
+ // always use BroadcastMul.
if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops);
+ TF_LITE_MUL(reference_ops, BroadcastMul);
} else {
- TF_LITE_MUL(optimized_ops);
+ TF_LITE_MUL(optimized_ops, BroadcastMul);
}
#undef TF_LITE_MUL
}
@@ -116,15 +147,17 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, input1, input2, output);
+ EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8) {
- EvalQuantized<kernel_type>(context, node, params, input1, input2, output);
+ EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
+ output);
} else {
context->ReportError(context,
"Mul only supports FLOAT32 and quantized UINT8 now.");
@@ -137,19 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace mul
TfLiteRegistration* Register_MUL_REF() {
- static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
mul::Eval<mul::kReference>};
return &r;
}
TfLiteRegistration* Register_MUL_GENERIC_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
mul::Eval<mul::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_MUL_NEON_OPT() {
- static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare,
+ static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
mul::Eval<mul::kNeonOptimized>};
return &r;
}
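
The registration changes above wire mul::Init and mul::Free into the first two TfLiteRegistration slots, so the runtime allocates the per-node OpData before Prepare() runs and releases it when the node is torn down; node->user_data is how Prepare() and Eval() get the same OpData back. A minimal standalone analogue of that lifecycle (struct and function names here are illustrative, not TF Lite types):

// Illustrative analogue of the init/free/user_data lifecycle; NodeState and
// Hooks are hypothetical names, not TF Lite types.
#include <cstddef>
#include <cstdio>

struct NodeState {
  bool requires_broadcast;
};

struct Hooks {
  void* (*init)(const char* buffer, size_t length);
  void (*free_state)(void* user_data);
};

void* InitState(const char* /*buffer*/, size_t /*length*/) {
  auto* state = new NodeState;
  state->requires_broadcast = false;  // Prepare() decides the real value later.
  return state;
}

void FreeState(void* user_data) { delete static_cast<NodeState*>(user_data); }

int main() {
  Hooks hooks = {InitState, FreeState};
  void* user_data = hooks.init(nullptr, 0);           // runtime: once per node
  auto* state = static_cast<NodeState*>(user_data);   // Prepare()/Eval() view
  state->requires_broadcast = true;                    // e.g. input shapes differ
  std::printf("requires_broadcast = %d\n", state->requires_broadcast);
  hooks.free_state(user_data);                         // runtime: node teardown
  return 0;
}
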