aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/comparisons.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-07 12:37:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 16:58:51 -0700
commit6f3a890d91e6dbeb811aed23d0eb59abaa8c469f (patch)
treef05bad7a6c979613b43c495a042a675cbc59a13a /tensorflow/contrib/lite/kernels/comparisons.cc
parentc3fef21c4ddf34fd68ab2cd44b0be497b5303b4e (diff)
Adding Greater/GreaterEqual/LessEqual ops to complement Less.
PiperOrigin-RevId: 195704492
Diffstat (limited to 'tensorflow/contrib/lite/kernels/comparisons.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc160
1 files changed, 119 insertions, 41 deletions
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 87c413cb98..2885ce032b 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -28,7 +28,7 @@ constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
-TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -56,61 +56,139 @@ TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
-TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
+ requires_broadcast \
+ ? reference_ops::Broadcast##opname( \
+ GetTensorData<type>(input1), GetTensorDims(input1), \
+ GetTensorData<type>(input2), GetTensorDims(input2), \
+ GetTensorData<bool>(output), GetTensorDims(output)) \
+ : reference_ops::opname( \
+ GetTensorData<type>(input1), GetTensorDims(input1), \
+ GetTensorData<type>(input2), GetTensorDims(input2), \
+ GetTensorData<bool>(output), GetTensorDims(output));
+
+TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Greater, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
-#define TF_LITE_LESS(type, opname) \
- reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output));
+TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Less, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Less, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Less, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
// TODO(renjieliu): Support quantized data.
- if (requires_broadcast) {
- switch (input1->type) {
- case kTfLiteFloat32:
- TF_LITE_LESS(float, BroadcastLess);
- break;
- case kTfLiteInt32:
- TF_LITE_LESS(int32_t, BroadcastLess);
- break;
- case kTfLiteInt64:
- TF_LITE_LESS(int64_t, BroadcastLess);
- break;
- default:
- context->ReportError(context,
- "Does not support type other than float|int");
- return kTfLiteError;
- }
- } else {
- switch (input1->type) {
- case kTfLiteFloat32:
- TF_LITE_LESS(float, Less);
- break;
- case kTfLiteInt32:
- TF_LITE_LESS(int32_t, Less);
- break;
- case kTfLiteInt64:
- TF_LITE_LESS(int64_t, Less);
- break;
- default:
- context->ReportError(context,
- "Does not support type other than float|int");
- return kTfLiteError;
- }
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, LessEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type other than float|int");
+ return kTfLiteError;
}
-#undef TF_LITE_LESS
return kTfLiteOk;
}
} // namespace comparisons
+TfLiteRegistration* Register_GREATER() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::GreaterEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_GREATER_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::GreaterEqualEval};
+ return &r;
+}
+
TfLiteRegistration* Register_LESS() {
- static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare,
- comparisons::LessEval};
+ static TfLiteRegistration r = {
+ nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_LESS_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::LessEqualEval};
return &r;
}