diff options
author | 2018-05-07 12:37:36 -0700 | |
---|---|---|
committer | 2018-05-07 16:58:51 -0700 | |
commit | 6f3a890d91e6dbeb811aed23d0eb59abaa8c469f (patch) | |
tree | f05bad7a6c979613b43c495a042a675cbc59a13a /tensorflow | |
parent | c3fef21c4ddf34fd68ab2cd44b0be497b5303b4e (diff) |
Adding Greater/GreaterEqual/LessEqual ops to complement Less.
PiperOrigin-RevId: 195704492
Diffstat (limited to 'tensorflow')
20 files changed, 1051 insertions, 103 deletions
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index d66b72843a..778933f569 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -86,6 +86,9 @@ typedef enum { kTfLiteBuiltinLess = 58, kTfLiteBuiltinNeg = 59, kTfLiteBuiltinPadv2 = 60, + kTfLiteBuiltinGreater = 61, + kTfLiteBuiltinGreaterEqual = 62, + kTfLiteBuiltinLessEqual = 63, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 0051ee84ec..fc57b8f28b 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -281,6 +281,32 @@ Options { } ``` +**GREATER** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than the corresponding element of the second tensor. +} +``` + +**GREATER_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than or equal to the corresponding element of the second tensor. +} +``` + **L2_NORMALIZATION** ``` @@ -325,6 +351,19 @@ Outputs { } ``` +**LESS_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is less + than or equal to the corresponding element of the second tensor. +} +``` + **LOCAL_RESPONSE_NORMALIZATION** ``` diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 87c413cb98..2885ce032b 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -28,7 +28,7 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; -TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -56,61 +56,139 @@ TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } -TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { +#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \ + requires_broadcast \ + ? reference_ops::Broadcast##opname( \ + GetTensorData<type>(input1), GetTensorDims(input1), \ + GetTensorData<type>(input2), GetTensorDims(input2), \ + GetTensorData<bool>(output), GetTensorDims(output)) \ + : reference_ops::opname( \ + GetTensorData<type>(input1), GetTensorDims(input1), \ + GetTensorData<type>(input2), GetTensorDims(input2), \ + GetTensorData<bool>(output), GetTensorDims(output)); + +TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Greater, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} +TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} -#define TF_LITE_LESS(type, opname) \ - reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \ - GetTensorData<type>(input2), GetTensorDims(input2), \ - GetTensorData<bool>(output), GetTensorDims(output)); +TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Less, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Less, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} +TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); // TODO(renjieliu): Support quantized data. - if (requires_broadcast) { - switch (input1->type) { - case kTfLiteFloat32: - TF_LITE_LESS(float, BroadcastLess); - break; - case kTfLiteInt32: - TF_LITE_LESS(int32_t, BroadcastLess); - break; - case kTfLiteInt64: - TF_LITE_LESS(int64_t, BroadcastLess); - break; - default: - context->ReportError(context, - "Does not support type other than float|int"); - return kTfLiteError; - } - } else { - switch (input1->type) { - case kTfLiteFloat32: - TF_LITE_LESS(float, Less); - break; - case kTfLiteInt32: - TF_LITE_LESS(int32_t, Less); - break; - case kTfLiteInt64: - TF_LITE_LESS(int64_t, Less); - break; - default: - context->ReportError(context, - "Does not support type other than float|int"); - return kTfLiteError; - } + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, LessEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; } -#undef TF_LITE_LESS return kTfLiteOk; } } // namespace comparisons +TfLiteRegistration* Register_GREATER() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::GreaterEval}; + return &r; +} + +TfLiteRegistration* Register_GREATER_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::GreaterEqualEval}; + return &r; +} + TfLiteRegistration* Register_LESS() { - static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare, - comparisons::LessEval}; + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval}; + return &r; +} + +TfLiteRegistration* Register_LESS_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::LessEqualEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index da2d7f8589..835d238d36 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -23,6 +23,139 @@ namespace { using ::testing::ElementsAreArray; +class GreaterOpModel : public SingleOpModel { + public: + GreaterOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type) { + input1_ = AddInput(input_type); + input2_ = AddInput(input_type); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions, + CreateGreaterOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(ComparisonsTest, GreaterFloat) { + GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterInt) { + GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterBroadcast) { + GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterBroadcastTwoD) { + GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, + false, true, false, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); +} + +class GreaterEqualOpModel : public SingleOpModel { + public: + GreaterEqualOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type) { + input1_ = AddInput(input_type); + input2_ = AddInput(input_type); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(BuiltinOperator_GREATER_EQUAL, + BuiltinOptions_GreaterEqualOptions, + CreateGreaterEqualOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(ComparisonsTest, GreaterEqualFloat) { + GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterEqualInt) { + GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterEqualBroadcast) { + GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) { + GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, + false, true, true, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); +} + class LessOpModel : public SingleOpModel { public: LessOpModel(std::initializer_list<int> input1_shape, @@ -47,7 +180,7 @@ class LessOpModel : public SingleOpModel { int output_; }; -TEST(ArgMaxOpTest, LessFloat) { +TEST(ComparisonsTest, LessFloat) { LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); @@ -57,7 +190,7 @@ TEST(ArgMaxOpTest, LessFloat) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } -TEST(ArgMaxOpTest, LessInt) { +TEST(ComparisonsTest, LessInt) { LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5}); @@ -67,7 +200,7 @@ TEST(ArgMaxOpTest, LessInt) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } -TEST(ArgMaxOpTest, LessBroadcast) { +TEST(ComparisonsTest, LessBroadcast) { LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {7}); @@ -77,7 +210,7 @@ TEST(ArgMaxOpTest, LessBroadcast) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } -TEST(ArgMaxOpTest, LessBroadcastTwoD) { +TEST(ComparisonsTest, LessBroadcastTwoD) { LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8}); model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); @@ -88,6 +221,72 @@ TEST(ArgMaxOpTest, LessBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); } +class LessEqualOpModel : public SingleOpModel { + public: + LessEqualOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type) { + input1_ = AddInput(input_type); + input2_ = AddInput(input_type); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions, + CreateLessEqualOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(ComparisonsTest, LessEqualFloat) { + LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, LessEqualInt) { + LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, LessEqualBroadcast) { + LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, LessEqualBroadcastTwoD) { + LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, + true, false, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index df29172f83..7ec4782f96 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -157,6 +157,7 @@ cc_library( ":quantization_util", ":strided_slice_logic", ":types", + ":reference_base", ":round", "//third_party/eigen3", "@gemmlowp", diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 18601df22c..ede95dfee0 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -113,6 +113,20 @@ inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier, right_shift); } +template <typename T> +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned<T>::value, + "Only unsigned integer types handled."); + const T one_in_leading_positive = static_cast<T>(1) + << (std::numeric_limits<T>::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index e2a1a6996d..c506c5636c 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -31,6 +31,7 @@ limitations under the License. #include "public/gemmlowp.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" @@ -38,6 +39,16 @@ limitations under the License. namespace tflite { namespace optimized_ops { +// Unoptimized reference ops: +using reference_ops::BroadcastGreater; +using reference_ops::BroadcastGreaterEqual; +using reference_ops::BroadcastLess; +using reference_ops::BroadcastLessEqual; +using reference_ops::Greater; +using reference_ops::GreaterEqual; +using reference_ops::Less; +using reference_ops::LessEqual; + // Make a local VectorMap typedef allowing to map a float array // as a Eigen vector expression. The std::conditional here is to // construct the suitable Eigen type for the constness of the diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 05e6ca8e7e..93dba1cc8e 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -35,35 +35,6 @@ limitations under the License. namespace tflite { namespace reference_ops { -inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( - int32 x, int32 quantized_multiplier, int right_shift) { - using gemmlowp::RoundingDivideByPOT; - using gemmlowp::SaturatingRoundingDoublingHighMul; - return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); -} - -inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( - int32 x, int32 quantized_multiplier, int left_shift) { - using gemmlowp::SaturatingRoundingDoublingHighMul; - return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), - quantized_multiplier); -} - -template <typename T> -int CountLeadingZeros(T integer_input) { - static_assert(std::is_unsigned<T>::value, - "Only unsigned integer types handled."); - const T one_in_leading_positive = static_cast<T>(1) - << (std::numeric_limits<T>::digits - 1); - int leading_zeros = 0; - while (integer_input < one_in_leading_positive) { - integer_input <<= 1; - ++leading_zeros; - } - return leading_zeros; -} - // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING. // @@ -3614,17 +3585,29 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } template <typename T> -inline void Less(int64_t num_elements, const T* input1, const T* input2, - bool* output) { - for (int64_t i = 0; i < num_elements; ++i) { - output[i] = input1[i] < input2[i]; - } +inline bool GreaterFn(T lhs, T rhs) { + return lhs > rhs; +} +template <typename T> +inline bool GreaterEqualFn(T lhs, T rhs) { + return lhs >= rhs; +} +template <typename T> +inline bool LessFn(T lhs, T rhs) { + return lhs < rhs; +} +template <typename T> +inline bool LessEqualFn(T lhs, T rhs) { + return lhs <= rhs; } template <typename T> -inline void Less(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - bool* output_data, const Dims<4>& output_dims) { +using ComparisonFn = bool (*)(T, T); + +template <typename T, ComparisonFn<T> F> +inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + bool* output_data, const Dims<4>& output_dims) { const int64_t batches = MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); const int64_t height = @@ -3633,31 +3616,149 @@ inline void Less(const T* input1_data, const Dims<4>& input1_dims, MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); const int64_t depth = MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - Less(batches * height * width * depth, input1_data, input2_data, output_data); + for (int64_t i = 0; i < batches * height * width * depth; ++i) { + output_data[i] = F(input1_data[i], input2_data[i]); + } } -template <typename T1, typename T2> -inline void BroadcastLess(T1* input1_data, const Dims<4>& input1_dims, - T2* input2_data, const Dims<4>& input2_dims, - bool* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastLess"); +template <typename T, ComparisonFn<T> F> +inline void Comparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, bool* output_data, + const Dims<4>& output_dims) { + const int64_t batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int64_t height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int64_t width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int64_t depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int64_t i = 0; i < batches * height * width * depth; ++i) { + const int32 input1_val = input1_offset + input1_data[i]; + const int32 input2_val = input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + output_data[i] = F(scaled_input1_val, scaled_input2_val); + } +} + +template <typename T, ComparisonFn<T> F> +inline void BroadcastComparison(const T* input1_data, + const Dims<4>& input1_dims, + const T* input2_data, + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + F(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} +template <typename T, ComparisonFn<T> F> +inline void BroadcastComparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 input2_multiplier, int input2_shift, + bool* output_data, const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); for (int b = 0; b < ArraySize(output_dims, 3); ++b) { for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); output_data[Offset(output_dims, c, x, y, b)] = - input1_data[SubscriptToIndex(desc1, c, x, y, b)] < - input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + F(scaled_input1_val, scaled_input2_val); } } } } } +#define TFLITE_COMPARISON_OP(name) \ + template <typename T> \ + inline void name(const T* input1_data, const Dims<4>& input1_dims, \ + const T* input2_data, const Dims<4>& input2_dims, \ + bool* output_data, const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name); \ + Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template <typename T> \ + inline void name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ + BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, \ + input1_shift, input2_data, input2_dims, \ + input2_offset, input2_multiplier, \ + input2_shift, output_data, output_dims); \ + } \ + template <typename T> \ + inline void Broadcast##name( \ + const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \ + const Dims<4>& input2_dims, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ + BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template <typename T> \ + inline void Broadcast##name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ + BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, \ + input1_shift, input2_data, input2_dims, \ + input2_offset, input2_multiplier, \ + input2_shift, output_data, output_dims); \ + } +TFLITE_COMPARISON_OP(Greater); +TFLITE_COMPARISON_OP(GreaterEqual); +TFLITE_COMPARISON_OP(Less); +TFLITE_COMPARISON_OP(LessEqual); +#undef TFLITE_COMPARISON_OP + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index a6ea874546..40855891a6 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -80,7 +80,10 @@ TfLiteRegistration* Register_PRELU(); TfLiteRegistration* Register_MAXIMUM(); TfLiteRegistration* Register_MINIMUM(); TfLiteRegistration* Register_ARG_MAX(); +TfLiteRegistration* Register_GREATER(); +TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); +TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); TfLiteRegistration* Register_NEG(); @@ -144,7 +147,10 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); + AddBuiltin(BuiltinOperator_GREATER, Register_GREATER()); + AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL()); AddBuiltin(BuiltinOperator_LESS, Register_LESS()); + AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL()); AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); AddBuiltin(BuiltinOperator_NEG, Register_NEG()); diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 6253570fa2..21c2181377 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -672,7 +672,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_LESS: { + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: { break; } case BuiltinOperator_DELEGATE: { diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index b4c46917bf..e903af87b7 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -372,7 +372,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_MAXIMUM: case tflite::BuiltinOperator_MINIMUM: case tflite::BuiltinOperator_ARG_MAX: + case tflite::BuiltinOperator_GREATER: + case tflite::BuiltinOperator_GREATER_EQUAL: case tflite::BuiltinOperator_LESS: + case tflite::BuiltinOperator_LESS_EQUAL: case tflite::BuiltinOperator_NEG: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 84ff3b16bd..9409e76233 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -138,6 +138,9 @@ enum BuiltinOperator : byte { LESS = 58, NEG = 59, PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, } // Options for the builtin operators. @@ -183,7 +186,10 @@ union BuiltinOptions { DequantizeOptions, MaximumMinimumOptions, ArgMaxOptions, + GreaterOptions, + GreaterEqualOptions, LessOptions, + LessEqualOptions, NegOptions, } @@ -410,9 +416,18 @@ table ArgMaxOptions { output_type : TensorType; } +table GreaterOptions { +} + +table GreaterEqualOptions { +} + table LessOptions { } +table LessEqualOptions { +} + table NegOptions { } diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 8855e4ad58..ae3b33063e 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -154,9 +154,18 @@ struct MaximumMinimumOptionsT; struct ArgMaxOptions; struct ArgMaxOptionsT; +struct GreaterOptions; +struct GreaterOptionsT; + +struct GreaterEqualOptions; +struct GreaterEqualOptionsT; + struct LessOptions; struct LessOptionsT; +struct LessEqualOptions; +struct LessEqualOptionsT; + struct NegOptions; struct NegOptionsT; @@ -280,11 +289,14 @@ enum BuiltinOperator { BuiltinOperator_LESS = 58, BuiltinOperator_NEG = 59, BuiltinOperator_PADV2 = 60, + BuiltinOperator_GREATER = 61, + BuiltinOperator_GREATER_EQUAL = 62, + BuiltinOperator_LESS_EQUAL = 63, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_PADV2 + BuiltinOperator_MAX = BuiltinOperator_LESS_EQUAL }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[60] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[63] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -345,7 +357,10 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[60] { BuiltinOperator_MINIMUM, BuiltinOperator_LESS, BuiltinOperator_NEG, - BuiltinOperator_PADV2 + BuiltinOperator_PADV2, + BuiltinOperator_GREATER, + BuiltinOperator_GREATER_EQUAL, + BuiltinOperator_LESS_EQUAL }; return values; } @@ -413,6 +428,9 @@ inline const char **EnumNamesBuiltinOperator() { "LESS", "NEG", "PADV2", + "GREATER", + "GREATER_EQUAL", + "LESS_EQUAL", nullptr }; return names; @@ -466,13 +484,16 @@ enum BuiltinOptions { BuiltinOptions_DequantizeOptions = 39, BuiltinOptions_MaximumMinimumOptions = 40, BuiltinOptions_ArgMaxOptions = 41, - BuiltinOptions_LessOptions = 42, - BuiltinOptions_NegOptions = 43, + BuiltinOptions_GreaterOptions = 42, + BuiltinOptions_GreaterEqualOptions = 43, + BuiltinOptions_LessOptions = 44, + BuiltinOptions_LessEqualOptions = 45, + BuiltinOptions_NegOptions = 46, BuiltinOptions_MIN = BuiltinOptions_NONE, BuiltinOptions_MAX = BuiltinOptions_NegOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[44] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[47] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -516,7 +537,10 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[44] { BuiltinOptions_DequantizeOptions, BuiltinOptions_MaximumMinimumOptions, BuiltinOptions_ArgMaxOptions, + BuiltinOptions_GreaterOptions, + BuiltinOptions_GreaterEqualOptions, BuiltinOptions_LessOptions, + BuiltinOptions_LessEqualOptions, BuiltinOptions_NegOptions }; return values; @@ -566,7 +590,10 @@ inline const char **EnumNamesBuiltinOptions() { "DequantizeOptions", "MaximumMinimumOptions", "ArgMaxOptions", + "GreaterOptions", + "GreaterEqualOptions", "LessOptions", + "LessEqualOptions", "NegOptions", nullptr }; @@ -746,10 +773,22 @@ template<> struct BuiltinOptionsTraits<ArgMaxOptions> { static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions; }; +template<> struct BuiltinOptionsTraits<GreaterOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions; +}; + +template<> struct BuiltinOptionsTraits<GreaterEqualOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions; +}; + template<> struct BuiltinOptionsTraits<LessOptions> { static const BuiltinOptions enum_value = BuiltinOptions_LessOptions; }; +template<> struct BuiltinOptionsTraits<LessEqualOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions; +}; + template<> struct BuiltinOptionsTraits<NegOptions> { static const BuiltinOptions enum_value = BuiltinOptions_NegOptions; }; @@ -1113,6 +1152,22 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_ArgMaxOptions ? reinterpret_cast<const ArgMaxOptionsT *>(value) : nullptr; } + GreaterOptionsT *AsGreaterOptions() { + return type == BuiltinOptions_GreaterOptions ? + reinterpret_cast<GreaterOptionsT *>(value) : nullptr; + } + const GreaterOptionsT *AsGreaterOptions() const { + return type == BuiltinOptions_GreaterOptions ? + reinterpret_cast<const GreaterOptionsT *>(value) : nullptr; + } + GreaterEqualOptionsT *AsGreaterEqualOptions() { + return type == BuiltinOptions_GreaterEqualOptions ? + reinterpret_cast<GreaterEqualOptionsT *>(value) : nullptr; + } + const GreaterEqualOptionsT *AsGreaterEqualOptions() const { + return type == BuiltinOptions_GreaterEqualOptions ? + reinterpret_cast<const GreaterEqualOptionsT *>(value) : nullptr; + } LessOptionsT *AsLessOptions() { return type == BuiltinOptions_LessOptions ? reinterpret_cast<LessOptionsT *>(value) : nullptr; @@ -1121,6 +1176,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_LessOptions ? reinterpret_cast<const LessOptionsT *>(value) : nullptr; } + LessEqualOptionsT *AsLessEqualOptions() { + return type == BuiltinOptions_LessEqualOptions ? + reinterpret_cast<LessEqualOptionsT *>(value) : nullptr; + } + const LessEqualOptionsT *AsLessEqualOptions() const { + return type == BuiltinOptions_LessEqualOptions ? + reinterpret_cast<const LessEqualOptionsT *>(value) : nullptr; + } NegOptionsT *AsNegOptions() { return type == BuiltinOptions_NegOptions ? reinterpret_cast<NegOptionsT *>(value) : nullptr; @@ -4056,6 +4119,86 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions( flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct GreaterOptionsT : public flatbuffers::NativeTable { + typedef GreaterOptions TableType; + GreaterOptionsT() { + } +}; + +struct GreaterOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GreaterOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GreaterOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<GreaterOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GreaterOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit GreaterOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GreaterOptionsBuilder &operator=(const GreaterOptionsBuilder &); + flatbuffers::Offset<GreaterOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GreaterOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + GreaterOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GreaterEqualOptionsT : public flatbuffers::NativeTable { + typedef GreaterEqualOptions TableType; + GreaterEqualOptionsT() { + } +}; + +struct GreaterEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GreaterEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GreaterEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<GreaterEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GreaterEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit GreaterEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GreaterEqualOptionsBuilder &operator=(const GreaterEqualOptionsBuilder &); + flatbuffers::Offset<GreaterEqualOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GreaterEqualOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + GreaterEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct LessOptionsT : public flatbuffers::NativeTable { typedef LessOptions TableType; LessOptionsT() { @@ -4096,6 +4239,46 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions( flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct LessEqualOptionsT : public flatbuffers::NativeTable { + typedef LessEqualOptions TableType; + LessEqualOptionsT() { + } +}; + +struct LessEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LessEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LessEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<LessEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LessEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LessEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LessEqualOptionsBuilder &operator=(const LessEqualOptionsBuilder &); + flatbuffers::Offset<LessEqualOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<LessEqualOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LessEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct NegOptionsT : public flatbuffers::NativeTable { typedef NegOptions TableType; NegOptionsT() { @@ -4376,9 +4559,18 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const ArgMaxOptions *builtin_options_as_ArgMaxOptions() const { return builtin_options_type() == BuiltinOptions_ArgMaxOptions ? static_cast<const ArgMaxOptions *>(builtin_options()) : nullptr; } + const GreaterOptions *builtin_options_as_GreaterOptions() const { + return builtin_options_type() == BuiltinOptions_GreaterOptions ? static_cast<const GreaterOptions *>(builtin_options()) : nullptr; + } + const GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const { + return builtin_options_type() == BuiltinOptions_GreaterEqualOptions ? static_cast<const GreaterEqualOptions *>(builtin_options()) : nullptr; + } const LessOptions *builtin_options_as_LessOptions() const { return builtin_options_type() == BuiltinOptions_LessOptions ? static_cast<const LessOptions *>(builtin_options()) : nullptr; } + const LessEqualOptions *builtin_options_as_LessEqualOptions() const { + return builtin_options_type() == BuiltinOptions_LessEqualOptions ? static_cast<const LessEqualOptions *>(builtin_options()) : nullptr; + } const NegOptions *builtin_options_as_NegOptions() const { return builtin_options_type() == BuiltinOptions_NegOptions ? static_cast<const NegOptions *>(builtin_options()) : nullptr; } @@ -4572,10 +4764,22 @@ template<> inline const ArgMaxOptions *Operator::builtin_options_as<ArgMaxOption return builtin_options_as_ArgMaxOptions(); } +template<> inline const GreaterOptions *Operator::builtin_options_as<GreaterOptions>() const { + return builtin_options_as_GreaterOptions(); +} + +template<> inline const GreaterEqualOptions *Operator::builtin_options_as<GreaterEqualOptions>() const { + return builtin_options_as_GreaterEqualOptions(); +} + template<> inline const LessOptions *Operator::builtin_options_as<LessOptions>() const { return builtin_options_as_LessOptions(); } +template<> inline const LessEqualOptions *Operator::builtin_options_as<LessEqualOptions>() const { + return builtin_options_as_LessEqualOptions(); +} + template<> inline const NegOptions *Operator::builtin_options_as<NegOptions>() const { return builtin_options_as_NegOptions(); } @@ -6206,6 +6410,52 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB _output_type); } +inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new GreaterOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GreaterOptions::UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<GreaterOptions> GreaterOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGreaterOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGreaterOptions( + _fbb); +} + +inline GreaterEqualOptionsT *GreaterEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new GreaterEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GreaterEqualOptions::UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<GreaterEqualOptions> GreaterEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGreaterEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGreaterEqualOptions( + _fbb); +} + inline LessOptionsT *LessOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LessOptionsT(); UnPackTo(_o, _resolver); @@ -6229,6 +6479,29 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBuffe _fbb); } +inline LessEqualOptionsT *LessEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LessEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LessEqualOptions::UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<LessEqualOptions> LessEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLessEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LessEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLessEqualOptions( + _fbb); +} + inline NegOptionsT *NegOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new NegOptionsT(); UnPackTo(_o, _resolver); @@ -6599,10 +6872,22 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<const GreaterOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } case BuiltinOptions_LessOptions: { auto ptr = reinterpret_cast<const LessOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<const LessEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } case BuiltinOptions_NegOptions: { auto ptr = reinterpret_cast<const NegOptions *>(obj); return verifier.VerifyTable(ptr); @@ -6789,10 +7074,22 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<const GreaterOptions *>(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj); + return ptr->UnPack(resolver); + } case BuiltinOptions_LessOptions: { auto ptr = reinterpret_cast<const LessOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<const LessEqualOptions *>(obj); + return ptr->UnPack(resolver); + } case BuiltinOptions_NegOptions: { auto ptr = reinterpret_cast<const NegOptions *>(obj); return ptr->UnPack(resolver); @@ -6967,10 +7264,22 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const ArgMaxOptionsT *>(value); return CreateArgMaxOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<const GreaterOptionsT *>(value); + return CreateGreaterOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<const GreaterEqualOptionsT *>(value); + return CreateGreaterEqualOptions(_fbb, ptr, _rehasher).Union(); + } case BuiltinOptions_LessOptions: { auto ptr = reinterpret_cast<const LessOptionsT *>(value); return CreateLessOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<const LessEqualOptionsT *>(value); + return CreateLessEqualOptions(_fbb, ptr, _rehasher).Union(); + } case BuiltinOptions_NegOptions: { auto ptr = reinterpret_cast<const NegOptionsT *>(value); return CreateNegOptions(_fbb, ptr, _rehasher).Union(); @@ -7145,10 +7454,22 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new ArgMaxOptionsT(*reinterpret_cast<ArgMaxOptionsT *>(u.value)); break; } + case BuiltinOptions_GreaterOptions: { + value = new GreaterOptionsT(*reinterpret_cast<GreaterOptionsT *>(u.value)); + break; + } + case BuiltinOptions_GreaterEqualOptions: { + value = new GreaterEqualOptionsT(*reinterpret_cast<GreaterEqualOptionsT *>(u.value)); + break; + } case BuiltinOptions_LessOptions: { value = new LessOptionsT(*reinterpret_cast<LessOptionsT *>(u.value)); break; } + case BuiltinOptions_LessEqualOptions: { + value = new LessEqualOptionsT(*reinterpret_cast<LessEqualOptionsT *>(u.value)); + break; + } case BuiltinOptions_NegOptions: { value = new NegOptionsT(*reinterpret_cast<NegOptionsT *>(u.value)); break; @@ -7365,11 +7686,26 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<GreaterOptionsT *>(value); + delete ptr; + break; + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<GreaterEqualOptionsT *>(value); + delete ptr; + break; + } case BuiltinOptions_LessOptions: { auto ptr = reinterpret_cast<LessOptionsT *>(value); delete ptr; break; } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<LessEqualOptionsT *>(value); + delete ptr; + break; + } case BuiltinOptions_NegOptions: { auto ptr = reinterpret_cast<NegOptionsT *>(value); delete ptr; diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index ca1390fdeb..6749e63552 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -33,9 +33,12 @@ gen_zipped_test_files( "fused_batch_norm.zip", "gather.zip", "global_batch_norm.zip", + "greater.zip", + "greater_equal.zip", "l2_pool.zip", "l2norm.zip", "less.zip", + "less_equal.zip", "local_response_norm.zip", "log_softmax.zip", "max_pool.zip", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 6fe0f491d0..7a658d43d3 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2055,6 +2055,74 @@ def make_arg_max_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_greater_tests(zip_path): + """Make a set of tests to do greater.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the greater op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.greater(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_greater_equal_tests(zip_path): + """Make a set of tests to do greater_equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the greater_equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.greater_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_less_tests(zip_path): """Make a set of tests to do less.""" @@ -2089,6 +2157,40 @@ def make_less_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_less_equal_tests(zip_path): + """Make a set of tests to do less_equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the less_equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.less_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_floor_tests(zip_path): """Make a set of tests to do floor.""" diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 96681952c9..2ce14f3b38 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -258,9 +258,12 @@ INSTANTIATE_TESTS(fully_connected) INSTANTIATE_TESTS(fused_batch_norm) INSTANTIATE_TESTS(gather) INSTANTIATE_TESTS(global_batch_norm) +INSTANTIATE_TESTS(greater) +INSTANTIATE_TESTS(greater_equal) INSTANTIATE_TESTS(l2_pool) INSTANTIATE_TESTS(l2norm) INSTANTIATE_TESTS(less) +INSTANTIATE_TESTS(less_equal) INSTANTIATE_TESTS(local_response_norm) INSTANTIATE_TESTS(log_softmax) INSTANTIATE_TESTS(max_pool) diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 9e899cf977..53df1987b3 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1702,6 +1702,19 @@ void ConvertRandomUniformOperator(const Model& model, (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2); } +void ConvertComparisonOperator(const Model& model, const Operator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + auto* comparison_op = tensorflow_graph->add_node(); + comparison_op->set_op(op_name); + comparison_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *comparison_op->add_input() = src_op.inputs[0]; + *comparison_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*comparison_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1893,6 +1906,14 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast<const RandomUniformOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowGreater) { + ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { + ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowLess) { + ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowLessEqual) { + ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 9b0e232132..a081abea55 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -499,8 +499,8 @@ void ProcessTensorFlowReshapeOperator(Model* model, << op->outputs[0] << "\". Are your input shapes correct?"; } -void ProcessSimpleOperator(Model* model, Operator* op) { - const auto& input_array = model->GetArray(op->inputs[0]); +void ProcessSimpleOperator(Model* model, Operator* op, int input_index) { + const auto& input_array = model->GetArray(op->inputs[input_index]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; @@ -1499,7 +1499,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kExp: - ProcessSimpleOperator(model, op); + ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: ProcessGatherOperator(model, static_cast<GatherOperator*>(op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 58e214b76b..a1ca7371c8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -55,7 +55,11 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell || type == OperatorType::kGather || - type == OperatorType::kTranspose || type == OperatorType::kMean; + type == OperatorType::kTranspose || type == OperatorType::kMean || + type == OperatorType::kTensorFlowGreater || + type == OperatorType::kTensorFlowGreaterEqual || + type == OperatorType::kTensorFlowLess || + type == OperatorType::kTensorFlowLessEqual; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { @@ -257,8 +261,7 @@ bool ChooseHardcodedQuantizationForOperatorOutput( IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); return true; } - if ((op.type == OperatorType::kLogistic) || - (op.type == OperatorType::kSoftmax)) { + if (op.type == OperatorType::kLogistic || op.type == OperatorType::kSoftmax) { // Logistic and Softmax have range: [0, 1]. // // For Logistic, 0.5 should be exactly representable, as implementations diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index df784a2a76..a008e63351 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -915,9 +915,16 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { "MAXIMUM", OperatorType::kTensorFlowMaximum)); ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>( "MINIMUM", OperatorType::kTensorFlowMinimum)); + ops.emplace_back(new SimpleOperator<TensorFlowGreaterOperator>( + "GREATER", OperatorType::kTensorFlowGreater)); + ops.emplace_back(new SimpleOperator<TensorFlowGreaterEqualOperator>( + "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual)); ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>( "LESS", OperatorType::kTensorFlowLess)); + ops.emplace_back(new SimpleOperator<TensorFlowLessEqualOperator>( + "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg)); + return ops; } } // namespace |