diff options
author | 2018-06-07 02:05:06 -0700 | |
---|---|---|
committer | 2018-06-07 02:07:45 -0700 | |
commit | c70b7128bfb9f0283c60bbec8fd7b0c12f741d95 (patch) | |
tree | 49a75161cb036b87817436d2bad9b79bfbb61425 /tensorflow | |
parent | c2368f875b53e9144a1803a3e67c5a61aa9c5862 (diff) |
Implementation of TensorFlowEqual and TensorFlowNotEqual.
PiperOrigin-RevId: 199602232
Diffstat (limited to 'tensorflow')
20 files changed, 666 insertions, 143 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 66d9a0dd44..13d9a463fb 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -204,6 +204,7 @@ def generated_test_models(): # "conv", "depthwiseconv", "div", + "equal", "exp", "expand_dims", "floor", @@ -226,6 +227,7 @@ def generated_test_models(): "minimum", "mul", "neg", + "not_equal", "pad", "padv2", # "prelu", diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index fc6fdd6eef..7b10b69f43 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -96,6 +96,8 @@ typedef enum { kTfLiteBuiltinSparseToDense = 68, kTfLiteBuiltinTile = 69, kTfLiteBuiltinExpandDims = 70, + kTfLiteBuiltinEqual = 71, + kTfLiteBuiltinNotEqual = 72, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 27e7d25bf1..19145281fa 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -95,11 +95,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph: * [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide) * [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args) * [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars) -* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater) -* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal) * [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity) -* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less) -* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal) * [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum) * [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum) * [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply) @@ -258,6 +254,19 @@ Options { } ``` +**EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + equal to the corresponding element of the second tensor. +} +``` + **EXP** ``` @@ -491,6 +500,19 @@ Options { } ``` +**NOT_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is not + equal to the corresponding element of the second tensor. +} +``` + **RELU** ``` diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 3b81062cd4..f678f48fa5 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -23,6 +23,7 @@ namespace tflite { namespace ops { namespace builtin { namespace comparisons { +namespace { constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; @@ -67,6 +68,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { GetTensorData<type>(input2), GetTensorDims(input2), \ GetTensorData<bool>(output), GetTensorDims(output)); +TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Equal, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +// TODO(renjieliu): Refactor the logic to avoid duplications. +TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); @@ -167,8 +219,22 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +} // namespace } // namespace comparisons +TfLiteRegistration* Register_EQUAL() { + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval}; + return &r; +} + +TfLiteRegistration* Register_NOT_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::NotEqualEval}; + return &r; +} + TfLiteRegistration* Register_GREATER() { static TfLiteRegistration r = {nullptr, nullptr, comparisons::ComparisonPrepare, diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index 835d238d36..bb02e1c812 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -21,18 +21,17 @@ limitations under the License. namespace tflite { namespace { -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; -class GreaterOpModel : public SingleOpModel { +class ComparisonOpModel : public SingleOpModel { public: - GreaterOpModel(std::initializer_list<int> input1_shape, - std::initializer_list<int> input2_shape, - TensorType input_type) { + ComparisonOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type, BuiltinOperator op) { input1_ = AddInput(input_type); input2_ = AddInput(input_type); output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions, - CreateGreaterOptions(builder_).Union()); + ConfigureBuiltinOp(op); BuildInterpreter({input1_shape, input2_shape}); } @@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel { int input1_; int input2_; int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_EqualOptions, + CreateEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_NOT_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_NotEqualOptions, + CreateNotEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER: { + SetBuiltinOp(op, BuiltinOptions_GreaterOptions, + CreateGreaterOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions, + CreateGreaterEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS: { + SetBuiltinOp(op, BuiltinOptions_LessOptions, + CreateLessOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_LessEqualOptions, + CreateLessEqualOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } }; -TEST(ComparisonsTest, GreaterFloat) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); +TEST(ComparisonsTest, EqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_EQUAL); model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterInt) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcast) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcastTwoD) { - GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false, + false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class GreaterEqualOpModel : public SingleOpModel { - public: - GreaterEqualOpModel(std::initializer_list<int> input1_shape, - std::initializer_list<int> input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER_EQUAL, - BuiltinOptions_GreaterEqualOptions, - CreateGreaterEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } +TEST(ComparisonsTest, NotEqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); - int input1() { return input1_; } - int input2() { return input2_; } + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} - std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } - std::vector<int> GetOutputShape() { return GetTensorShape(output_); } +TEST(ComparisonsTest, NotEqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); - private: - int input1_; - int input2_; - int output_; -}; + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, true, true, true, true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} + +TEST(ComparisonsTest, GreaterFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} TEST(ComparisonsTest, GreaterEqualFloat) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualInt) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcast) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) { - GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessOpModel : public SingleOpModel { - public: - LessOpModel(std::initializer_list<int> input1_shape, - std::initializer_list<int> input2_shape, TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions, - CreateLessOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } - std::vector<int> GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; TEST(ComparisonsTest, LessFloat) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS); model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessInt) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcast) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcastTwoD) { - LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8}); model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessEqualOpModel : public SingleOpModel { - public: - LessEqualOpModel(std::initializer_list<int> input1_shape, - std::initializer_list<int> input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions, - CreateLessEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } - std::vector<int> GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; - TEST(ComparisonsTest, LessEqualFloat) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualInt) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcast) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcastTwoD) { - LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index ca5a20ad4f..0b644a1fa6 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3866,6 +3866,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } template <typename T> +inline bool EqualFn(T lhs, T rhs) { + return lhs == rhs; +} + +template <typename T> +inline bool NotEqualFn(T lhs, T rhs) { + return lhs != rhs; +} + +template <typename T> inline bool GreaterFn(T lhs, T rhs) { return lhs > rhs; } @@ -4028,6 +4038,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, input2_offset, input2_multiplier, \ input2_shift, output_data, output_dims); \ } +TFLITE_COMPARISON_OP(Equal); +TFLITE_COMPARISON_OP(NotEqual); TFLITE_COMPARISON_OP(Greater); TFLITE_COMPARISON_OP(GreaterEqual); TFLITE_COMPARISON_OP(Less); diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 184b02dcec..6c68bb2f31 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -93,6 +93,8 @@ TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_TRANSPOSE_CONV(); TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_SPARSE_TO_DENSE(); +TfLiteRegistration* Register_EQUAL(); +TfLiteRegistration* Register_NOT_EQUAL(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -168,6 +170,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_TILE, Register_TILE()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); + AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); + AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 8d8d74adfb..d78b6eae90 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -689,6 +689,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_GREATER_EQUAL: case BuiltinOperator_LESS: case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_EQUAL: + case BuiltinOperator_NOT_EQUAL: case BuiltinOperator_SELECT: { break; } diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index d27ab0c033..605ce7d6fc 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -494,6 +494,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_TILE: case tflite::BuiltinOperator_EXPAND_DIMS: case tflite::BuiltinOperator_SPARSE_TO_DENSE: + case tflite::BuiltinOperator_EQUAL: + case tflite::BuiltinOperator_NOT_EQUAL: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 7dbb36c864..d12a96df1c 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -148,6 +148,8 @@ enum BuiltinOperator : byte { SPARSE_TO_DENSE = 68, TILE = 69, EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, } // Options for the builtin operators. @@ -204,6 +206,8 @@ union BuiltinOptions { SparseToDenseOptions, TileOptions, ExpandDimsOptions, + EqualOptions, + NotEqualOptions, } enum Padding : byte { SAME, VALID } @@ -478,6 +482,12 @@ table SparseToDenseOptions { validate_indices:bool; } +table EqualOptions { +} + +table NotEqualOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index b1beb39b28..8ddd2f1438 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -187,6 +187,12 @@ struct ExpandDimsOptionsT; struct SparseToDenseOptions; struct SparseToDenseOptionsT; +struct EqualOptions; +struct EqualOptionsT; + +struct NotEqualOptions; +struct NotEqualOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -317,11 +323,13 @@ enum BuiltinOperator { BuiltinOperator_SPARSE_TO_DENSE = 68, BuiltinOperator_TILE = 69, BuiltinOperator_EXPAND_DIMS = 70, + BuiltinOperator_EQUAL = 71, + BuiltinOperator_NOT_EQUAL = 72, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_EXPAND_DIMS + BuiltinOperator_MAX = BuiltinOperator_NOT_EQUAL }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[72] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -392,7 +400,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] { BuiltinOperator_TRANSPOSE_CONV, BuiltinOperator_SPARSE_TO_DENSE, BuiltinOperator_TILE, - BuiltinOperator_EXPAND_DIMS + BuiltinOperator_EXPAND_DIMS, + BuiltinOperator_EQUAL, + BuiltinOperator_NOT_EQUAL }; return values; } @@ -470,6 +480,8 @@ inline const char **EnumNamesBuiltinOperator() { "SPARSE_TO_DENSE", "TILE", "EXPAND_DIMS", + "EQUAL", + "NOT_EQUAL", nullptr }; return names; @@ -534,11 +546,13 @@ enum BuiltinOptions { BuiltinOptions_SparseToDenseOptions = 50, BuiltinOptions_TileOptions = 51, BuiltinOptions_ExpandDimsOptions = 52, + BuiltinOptions_EqualOptions = 53, + BuiltinOptions_NotEqualOptions = 54, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_ExpandDimsOptions + BuiltinOptions_MAX = BuiltinOptions_NotEqualOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -592,7 +606,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] { BuiltinOptions_TransposeConvOptions, BuiltinOptions_SparseToDenseOptions, BuiltinOptions_TileOptions, - BuiltinOptions_ExpandDimsOptions + BuiltinOptions_ExpandDimsOptions, + BuiltinOptions_EqualOptions, + BuiltinOptions_NotEqualOptions }; return values; } @@ -652,6 +668,8 @@ inline const char **EnumNamesBuiltinOptions() { "SparseToDenseOptions", "TileOptions", "ExpandDimsOptions", + "EqualOptions", + "NotEqualOptions", nullptr }; return names; @@ -874,6 +892,14 @@ template<> struct BuiltinOptionsTraits<ExpandDimsOptions> { static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; }; +template<> struct BuiltinOptionsTraits<EqualOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template<> struct BuiltinOptionsTraits<NotEqualOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1321,6 +1347,22 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_ExpandDimsOptions ? reinterpret_cast<const ExpandDimsOptionsT *>(value) : nullptr; } + EqualOptionsT *AsEqualOptions() { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast<EqualOptionsT *>(value) : nullptr; + } + const EqualOptionsT *AsEqualOptions() const { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast<const EqualOptionsT *>(value) : nullptr; + } + NotEqualOptionsT *AsNotEqualOptions() { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast<NotEqualOptionsT *>(value) : nullptr; + } + const NotEqualOptionsT *AsNotEqualOptions() const { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast<const NotEqualOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4781,6 +4823,86 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions( flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct EqualOptionsT : public flatbuffers::NativeTable { + typedef EqualOptions TableType; + EqualOptionsT() { + } +}; + +struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef EqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<EqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + EqualOptionsBuilder &operator=(const EqualOptionsBuilder &); + flatbuffers::Offset<EqualOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<EqualOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<EqualOptions> CreateEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + EqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NotEqualOptionsT : public flatbuffers::NativeTable { + typedef NotEqualOptions TableType; + NotEqualOptionsT() { + } +}; + +struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NotEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<NotEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NotEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &); + flatbuffers::Offset<NotEqualOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<NotEqualOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + NotEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -5068,6 +5190,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast<const ExpandDimsOptions *>(builtin_options()) : nullptr; } + const EqualOptions *builtin_options_as_EqualOptions() const { + return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast<const EqualOptions *>(builtin_options()) : nullptr; + } + const NotEqualOptions *builtin_options_as_NotEqualOptions() const { + return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast<const NotEqualOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -5302,6 +5430,14 @@ template<> inline const ExpandDimsOptions *Operator::builtin_options_as<ExpandDi return builtin_options_as_ExpandDimsOptions(); } +template<> inline const EqualOptions *Operator::builtin_options_as<EqualOptions>() const { + return builtin_options_as_EqualOptions(); +} + +template<> inline const NotEqualOptions *Operator::builtin_options_as<NotEqualOptions>() const { + return builtin_options_as_NotEqualOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -7196,6 +7332,52 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flat _validate_indices); } +inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new EqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<EqualOptions> EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateEqualOptions( + _fbb); +} + +inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new NotEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<NotEqualOptions> NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNotEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNotEqualOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7590,6 +7772,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast<const EqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast<const NotEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7816,6 +8006,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast<const EqualOptions *>(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast<const NotEqualOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -8030,6 +8228,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const ExpandDimsOptionsT *>(value); return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast<const EqualOptionsT *>(value); + return CreateEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast<const NotEqualOptionsT *>(value); + return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -8244,6 +8450,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new ExpandDimsOptionsT(*reinterpret_cast<ExpandDimsOptionsT *>(u.value)); break; } + case BuiltinOptions_EqualOptions: { + value = new EqualOptionsT(*reinterpret_cast<EqualOptionsT *>(u.value)); + break; + } + case BuiltinOptions_NotEqualOptions: { + value = new NotEqualOptionsT(*reinterpret_cast<NotEqualOptionsT *>(u.value)); + break; + } default: break; } @@ -8511,6 +8725,16 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast<EqualOptionsT *>(value); + delete ptr; + break; + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast<NotEqualOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 351187f520..723b6ae057 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2165,6 +2165,74 @@ def make_arg_max_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_equal_tests(zip_path): + """Make a set of tests to do equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_not_equal_tests(zip_path): + """Make a set of tests to do not equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the not euqal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.not_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_greater_tests(zip_path): """Make a set of tests to do greater.""" diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 99f0c81a1b..76ce1c5802 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1938,6 +1938,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast<const RandomUniformOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowEqual) { + ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowNotEqual) { + ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph); } else if (src_op.type == OperatorType::kTensorFlowGreater) { ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 64096fb069..92d283ca2c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -60,6 +60,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowLessEqual: case OperatorType::kTensorFlowGreater: case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kTensorFlowEqual: + case OperatorType::kTensorFlowNotEqual: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index adb241da32..9e4262223e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1563,6 +1563,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowMaximum: case OperatorType::kTensorFlowMinimum: case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kTensorFlowEqual: + case OperatorType::kTensorFlowNotEqual: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index b9ebf66ff2..b13a88a9eb 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1908,6 +1908,12 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, ConvertSimpleOperator<SelectOperator, 3>(node, tf_import_flags, model); } else if (node.op() == "SparseToDense") { ConvertSparseToDenseOperator(node, tf_import_flags, model); + } else if (node.op() == "Equal") { + ConvertSimpleOperator<TensorFlowEqualOperator, 2>(node, tf_import_flags, + model); + } else if (node.op() == "NotEqual") { + ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>(node, tf_import_flags, + model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 1a4f87e363..81beb29372 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -136,6 +136,8 @@ enum class OperatorType { kReorderAxes, kSelect, kSparseToDense, + kTensorFlowEqual, + kTensorFlowNotEqual, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1358,6 +1360,22 @@ struct TensorFlowGreaterEqualOperator : Operator { : Operator(OperatorType::kTensorFlowGreaterEqual) {} }; +// TensorFlow Equal equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowEqualOperator : Operator { + TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {} +}; + +// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for +// details. +struct TensorFlowNotEqualOperator : Operator { + TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {} +}; + // Global max reduction: computes the max of all of entries in the input array. // Thus the output is "0-dimensional": it consists of a single scalar value. // diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index a8518adefc..8bfd76db6e 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1118,6 +1118,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back( new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice)); ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin)); + ops.emplace_back(new SimpleOperator<TensorFlowEqualOperator>( + "EQUAL", OperatorType::kTensorFlowEqual)); + ops.emplace_back(new SimpleOperator<TensorFlowNotEqualOperator>( + "NOT_EQUAL", OperatorType::kTensorFlowNotEqual)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index d63c99a5f9..06bbe53516 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -119,6 +119,10 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect); CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice); CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin); + CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", + OperatorType::kTensorFlowEqual); + CheckSimpleOperator<TensorFlowNotEqualOperator>( + "NOT_EQUAL", OperatorType::kTensorFlowNotEqual); } TEST_F(OperatorTest, BuiltinAdd) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index fe7bed885d..5a82be3939 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -394,6 +394,8 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) HANDLE_OPERATORTYPENAME_CASE(SparseToDense) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE |