diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-07 12:37:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 16:58:51 -0700 |
commit | 6f3a890d91e6dbeb811aed23d0eb59abaa8c469f (patch) | |
tree | f05bad7a6c979613b43c495a042a675cbc59a13a /tensorflow/contrib/lite/kernels/comparisons.cc | |
parent | c3fef21c4ddf34fd68ab2cd44b0be497b5303b4e (diff) |
Adding Greater/GreaterEqual/LessEqual ops to complement Less.
PiperOrigin-RevId: 195704492
Diffstat (limited to 'tensorflow/contrib/lite/kernels/comparisons.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/comparisons.cc | 160 |
1 files changed, 119 insertions, 41 deletions
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 87c413cb98..2885ce032b 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -28,7 +28,7 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; -TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -56,61 +56,139 @@ TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } -TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { +#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \ + requires_broadcast \ + ? reference_ops::Broadcast##opname( \ + GetTensorData<type>(input1), GetTensorDims(input1), \ + GetTensorData<type>(input2), GetTensorDims(input2), \ + GetTensorData<bool>(output), GetTensorDims(output)) \ + : reference_ops::opname( \ + GetTensorData<type>(input1), GetTensorDims(input1), \ + GetTensorData<type>(input2), GetTensorDims(input2), \ + GetTensorData<bool>(output), GetTensorDims(output)); + +TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Greater, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} +TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} -#define TF_LITE_LESS(type, opname) \ - reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \ - GetTensorData<type>(input2), GetTensorDims(input2), \ - GetTensorData<bool>(output), GetTensorDims(output)); +TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Less, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Less, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} +TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); // TODO(renjieliu): Support quantized data. - if (requires_broadcast) { - switch (input1->type) { - case kTfLiteFloat32: - TF_LITE_LESS(float, BroadcastLess); - break; - case kTfLiteInt32: - TF_LITE_LESS(int32_t, BroadcastLess); - break; - case kTfLiteInt64: - TF_LITE_LESS(int64_t, BroadcastLess); - break; - default: - context->ReportError(context, - "Does not support type other than float|int"); - return kTfLiteError; - } - } else { - switch (input1->type) { - case kTfLiteFloat32: - TF_LITE_LESS(float, Less); - break; - case kTfLiteInt32: - TF_LITE_LESS(int32_t, Less); - break; - case kTfLiteInt64: - TF_LITE_LESS(int64_t, Less); - break; - default: - context->ReportError(context, - "Does not support type other than float|int"); - return kTfLiteError; - } + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, LessEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; } -#undef TF_LITE_LESS return kTfLiteOk; } } // namespace comparisons +TfLiteRegistration* Register_GREATER() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::GreaterEval}; + return &r; +} + +TfLiteRegistration* Register_GREATER_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::GreaterEqualEval}; + return &r; +} + TfLiteRegistration* Register_LESS() { - static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare, - comparisons::LessEval}; + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval}; + return &r; +} + +TfLiteRegistration* Register_LESS_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::LessEqualEval}; return &r; } |