author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-07 02:05:06 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-07 02:07:45 -0700
commit     c70b7128bfb9f0283c60bbec8fd7b0c12f741d95 (patch)
tree       49a75161cb036b87817436d2bad9b79bfbb61425 /tensorflow
parent     c2368f875b53e9144a1803a3e67c5a61aa9c5862 (diff)
Implementation of TensorFlowEqual and TensorFlowNotEqual.
PiperOrigin-RevId: 199602232
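
For reference, the new builtins correspond one-to-one to `tf.equal` and `tf.not_equal`. Below is a minimal sketch (illustration only, not part of this change) of the kind of graph that would now convert to the EQUAL and NOT_EQUAL TFLite builtins, assuming the contemporaneous TF 1.x `tf.placeholder`/`tf.Session` API and reusing the values from the float kernel tests in this diff:

```
# Illustration only: an element-wise comparison graph of the kind this change
# teaches TFLite to handle (EQUAL and NOT_EQUAL builtins, bool outputs).
import tensorflow as tf

a = tf.placeholder(tf.float32, shape=[1, 1, 1, 4], name="a")
b = tf.placeholder(tf.float32, shape=[1, 1, 1, 4], name="b")
eq = tf.equal(a, b)        # would lower to the new EQUAL builtin
ne = tf.not_equal(a, b)    # would lower to the new NOT_EQUAL builtin

with tf.Session() as sess:
    out_eq, out_ne = sess.run(
        [eq, ne],
        feed_dict={a: [[[[0.1, 0.9, 0.7, 0.3]]]],
                   b: [[[[0.1, 0.2, 0.6, 0.5]]]]})
    # out_eq -> roughly [[[[ True False False False]]]]
    # out_ne -> roughly [[[[False  True  True  True]]]]
```

The kernels registered in comparisons.cc below produce exactly these boolean tensors, with broadcasting when the two input shapes differ.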
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl                                            |   2
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h                                            |   2
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md                            |  30
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons.cc                                   |  66
-rw-r--r--  tensorflow/contrib/lite/kernels/comparisons_test.cc                              | 333
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h               |  12
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc                                      |   4
-rw-r--r--  tensorflow/contrib/lite/model.cc                                                 |   2
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc                                        |   2
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs                                        |  10
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h                                | 236
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py                             |  68
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc                                |   4
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc |   2
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc      |   2
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc                                |   6
-rw-r--r--  tensorflow/contrib/lite/toco/model.h                                             |  18
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc                                  |   4
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc                             |   4
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc                                     |   2
20 files changed, 666 insertions, 143 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 66d9a0dd44..13d9a463fb 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -204,6 +204,7 @@ def generated_test_models():
# "conv",
"depthwiseconv",
"div",
+ "equal",
"exp",
"expand_dims",
"floor",
@@ -226,6 +227,7 @@ def generated_test_models():
"minimum",
"mul",
"neg",
+ "not_equal",
"pad",
"padv2",
# "prelu",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index fc6fdd6eef..7b10b69f43 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -96,6 +96,8 @@ typedef enum {
kTfLiteBuiltinSparseToDense = 68,
kTfLiteBuiltinTile = 69,
kTfLiteBuiltinExpandDims = 70,
+ kTfLiteBuiltinEqual = 71,
+ kTfLiteBuiltinNotEqual = 72,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 27e7d25bf1..19145281fa 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -95,11 +95,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph:
* [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide)
* [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args)
* [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars)
-* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater)
-* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal)
* [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity)
-* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less)
-* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal)
* [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum)
* [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum)
* [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply)
@@ -258,6 +254,19 @@ Options {
}
```
+**EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is
+ equal to the corresponding element of the second tensor.
+}
+```
+
**EXP**
```
@@ -491,6 +500,19 @@ Options {
}
```
+**NOT_EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is not
+ equal to the corresponding element of the second tensor.
+}
+```
+
**RELU**
```
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 3b81062cd4..f678f48fa5 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -23,6 +23,7 @@ namespace tflite {
namespace ops {
namespace builtin {
namespace comparisons {
+namespace {
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
@@ -67,6 +68,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<type>(input2), GetTensorDims(input2), \
GetTensorData<bool>(output), GetTensorDims(output));
+TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// TODO(renjieliu): Refactor the logic to avoid duplications.
+TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
@@ -167,8 +219,22 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+} // namespace
} // namespace comparisons
+TfLiteRegistration* Register_EQUAL() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_NOT_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::NotEqualEval};
+ return &r;
+}
+
TfLiteRegistration* Register_GREATER() {
static TfLiteRegistration r = {nullptr, nullptr,
comparisons::ComparisonPrepare,
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index 835d238d36..bb02e1c812 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -21,18 +21,17 @@ limitations under the License.
namespace tflite {
namespace {
-using ::testing::ElementsAreArray;
+using ::testing::ElementsAre;
-class GreaterOpModel : public SingleOpModel {
+class ComparisonOpModel : public SingleOpModel {
public:
- GreaterOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
+ ComparisonOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type, BuiltinOperator op) {
input1_ = AddInput(input_type);
input2_ = AddInput(input_type);
output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions,
- CreateGreaterOptions(builder_).Union());
+ ConfigureBuiltinOp(op);
BuildInterpreter({input1_shape, input2_shape});
}
@@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel {
int input1_;
int input2_;
int output_;
+
+ void ConfigureBuiltinOp(BuiltinOperator op) {
+ switch (op) {
+ case BuiltinOperator_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_EqualOptions,
+ CreateEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_NOT_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_NotEqualOptions,
+ CreateNotEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterOptions,
+ CreateGreaterOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions,
+ CreateGreaterEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS: {
+ SetBuiltinOp(op, BuiltinOptions_LessOptions,
+ CreateLessOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_LessEqualOptions,
+ CreateLessEqualOptions(builder_).Union());
+ break;
+ }
+ default: { FAIL() << "We shouldn't get here."; }
+ }
+ }
};
-TEST(ComparisonsTest, GreaterFloat) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+TEST(ComparisonsTest, EqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterInt) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcast) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcastTwoD) {
- GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false,
+ false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class GreaterEqualOpModel : public SingleOpModel {
- public:
- GreaterEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER_EQUAL,
- BuiltinOptions_GreaterEqualOptions,
- CreateGreaterEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
+TEST(ComparisonsTest, NotEqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
- int input1() { return input1_; }
- int input2() { return input2_; }
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+TEST(ComparisonsTest, NotEqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
- private:
- int input1_;
- int input2_;
- int output_;
-};
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, true, true, true, true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
+
+TEST(ComparisonsTest, GreaterFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
TEST(ComparisonsTest, GreaterEqualFloat) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualInt) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcast) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) {
- GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessOpModel : public SingleOpModel {
- public:
- LessOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape, TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions,
- CreateLessOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
TEST(ComparisonsTest, LessFloat) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessInt) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcast) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcastTwoD) {
- LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessEqualOpModel : public SingleOpModel {
- public:
- LessEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions,
- CreateLessEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
-
TEST(ComparisonsTest, LessEqualFloat) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualInt) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcast) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcastTwoD) {
- LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index ca5a20ad4f..0b644a1fa6 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3866,6 +3866,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
+inline bool EqualFn(T lhs, T rhs) {
+ return lhs == rhs;
+}
+
+template <typename T>
+inline bool NotEqualFn(T lhs, T rhs) {
+ return lhs != rhs;
+}
+
+template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
return lhs > rhs;
}
@@ -4028,6 +4038,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
input2_offset, input2_multiplier, \
input2_shift, output_data, output_dims); \
}
+TFLITE_COMPARISON_OP(Equal);
+TFLITE_COMPARISON_OP(NotEqual);
TFLITE_COMPARISON_OP(Greater);
TFLITE_COMPARISON_OP(GreaterEqual);
TFLITE_COMPARISON_OP(Less);
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 184b02dcec..6c68bb2f31 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -93,6 +93,8 @@ TfLiteRegistration* Register_SIN();
TfLiteRegistration* Register_TRANSPOSE_CONV();
TfLiteRegistration* Register_EXPAND_DIMS();
TfLiteRegistration* Register_SPARSE_TO_DENSE();
+TfLiteRegistration* Register_EQUAL();
+TfLiteRegistration* Register_NOT_EQUAL();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -168,6 +170,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
+ AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
+ AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 8d8d74adfb..d78b6eae90 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -689,6 +689,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_NOT_EQUAL:
case BuiltinOperator_SELECT: {
break;
}
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index d27ab0c033..605ce7d6fc 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -494,6 +494,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_TILE:
case tflite::BuiltinOperator_EXPAND_DIMS:
case tflite::BuiltinOperator_SPARSE_TO_DENSE:
+ case tflite::BuiltinOperator_EQUAL:
+ case tflite::BuiltinOperator_NOT_EQUAL:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 7dbb36c864..d12a96df1c 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -148,6 +148,8 @@ enum BuiltinOperator : byte {
SPARSE_TO_DENSE = 68,
TILE = 69,
EXPAND_DIMS = 70,
+ EQUAL = 71,
+ NOT_EQUAL = 72,
}
// Options for the builtin operators.
@@ -204,6 +206,8 @@ union BuiltinOptions {
SparseToDenseOptions,
TileOptions,
ExpandDimsOptions,
+ EqualOptions,
+ NotEqualOptions,
}
enum Padding : byte { SAME, VALID }
@@ -478,6 +482,12 @@ table SparseToDenseOptions {
validate_indices:bool;
}
+table EqualOptions {
+}
+
+table NotEqualOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index b1beb39b28..8ddd2f1438 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -187,6 +187,12 @@ struct ExpandDimsOptionsT;
struct SparseToDenseOptions;
struct SparseToDenseOptionsT;
+struct EqualOptions;
+struct EqualOptionsT;
+
+struct NotEqualOptions;
+struct NotEqualOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -317,11 +323,13 @@ enum BuiltinOperator {
BuiltinOperator_SPARSE_TO_DENSE = 68,
BuiltinOperator_TILE = 69,
BuiltinOperator_EXPAND_DIMS = 70,
+ BuiltinOperator_EQUAL = 71,
+ BuiltinOperator_NOT_EQUAL = 72,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_EXPAND_DIMS
+ BuiltinOperator_MAX = BuiltinOperator_NOT_EQUAL
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[72] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -392,7 +400,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] {
BuiltinOperator_TRANSPOSE_CONV,
BuiltinOperator_SPARSE_TO_DENSE,
BuiltinOperator_TILE,
- BuiltinOperator_EXPAND_DIMS
+ BuiltinOperator_EXPAND_DIMS,
+ BuiltinOperator_EQUAL,
+ BuiltinOperator_NOT_EQUAL
};
return values;
}
@@ -470,6 +480,8 @@ inline const char **EnumNamesBuiltinOperator() {
"SPARSE_TO_DENSE",
"TILE",
"EXPAND_DIMS",
+ "EQUAL",
+ "NOT_EQUAL",
nullptr
};
return names;
@@ -534,11 +546,13 @@ enum BuiltinOptions {
BuiltinOptions_SparseToDenseOptions = 50,
BuiltinOptions_TileOptions = 51,
BuiltinOptions_ExpandDimsOptions = 52,
+ BuiltinOptions_EqualOptions = 53,
+ BuiltinOptions_NotEqualOptions = 54,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_ExpandDimsOptions
+ BuiltinOptions_MAX = BuiltinOptions_NotEqualOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -592,7 +606,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] {
BuiltinOptions_TransposeConvOptions,
BuiltinOptions_SparseToDenseOptions,
BuiltinOptions_TileOptions,
- BuiltinOptions_ExpandDimsOptions
+ BuiltinOptions_ExpandDimsOptions,
+ BuiltinOptions_EqualOptions,
+ BuiltinOptions_NotEqualOptions
};
return values;
}
@@ -652,6 +668,8 @@ inline const char **EnumNamesBuiltinOptions() {
"SparseToDenseOptions",
"TileOptions",
"ExpandDimsOptions",
+ "EqualOptions",
+ "NotEqualOptions",
nullptr
};
return names;
@@ -874,6 +892,14 @@ template<> struct BuiltinOptionsTraits<ExpandDimsOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions;
};
+template<> struct BuiltinOptionsTraits<EqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions;
+};
+
+template<> struct BuiltinOptionsTraits<NotEqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1321,6 +1347,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_ExpandDimsOptions ?
reinterpret_cast<const ExpandDimsOptionsT *>(value) : nullptr;
}
+ EqualOptionsT *AsEqualOptions() {
+ return type == BuiltinOptions_EqualOptions ?
+ reinterpret_cast<EqualOptionsT *>(value) : nullptr;
+ }
+ const EqualOptionsT *AsEqualOptions() const {
+ return type == BuiltinOptions_EqualOptions ?
+ reinterpret_cast<const EqualOptionsT *>(value) : nullptr;
+ }
+ NotEqualOptionsT *AsNotEqualOptions() {
+ return type == BuiltinOptions_NotEqualOptions ?
+ reinterpret_cast<NotEqualOptionsT *>(value) : nullptr;
+ }
+ const NotEqualOptionsT *AsNotEqualOptions() const {
+ return type == BuiltinOptions_NotEqualOptions ?
+ reinterpret_cast<const NotEqualOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -4781,6 +4823,86 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(
flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct EqualOptionsT : public flatbuffers::NativeTable {
+ typedef EqualOptions TableType;
+ EqualOptionsT() {
+ }
+};
+
+struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef EqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<EqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct EqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ EqualOptionsBuilder &operator=(const EqualOptionsBuilder &);
+ flatbuffers::Offset<EqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<EqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<EqualOptions> CreateEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ EqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct NotEqualOptionsT : public flatbuffers::NativeTable {
+ typedef NotEqualOptions TableType;
+ NotEqualOptionsT() {
+ }
+};
+
+struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef NotEqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<NotEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct NotEqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &);
+ flatbuffers::Offset<NotEqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<NotEqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ NotEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5068,6 +5190,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const {
return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast<const ExpandDimsOptions *>(builtin_options()) : nullptr;
}
+ const EqualOptions *builtin_options_as_EqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast<const EqualOptions *>(builtin_options()) : nullptr;
+ }
+ const NotEqualOptions *builtin_options_as_NotEqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast<const NotEqualOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5302,6 +5430,14 @@ template<> inline const ExpandDimsOptions *Operator::builtin_options_as<ExpandDi
return builtin_options_as_ExpandDimsOptions();
}
+template<> inline const EqualOptions *Operator::builtin_options_as<EqualOptions>() const {
+ return builtin_options_as_EqualOptions();
+}
+
+template<> inline const NotEqualOptions *Operator::builtin_options_as<NotEqualOptions>() const {
+ return builtin_options_as_NotEqualOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7196,6 +7332,52 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flat
_validate_indices);
}
+inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new EqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<EqualOptions> EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateEqualOptions(
+ _fbb);
+}
+
+inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new NotEqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<NotEqualOptions> NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateNotEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateNotEqualOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -7590,6 +7772,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -7816,6 +8006,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -8030,6 +8228,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const ExpandDimsOptionsT *>(value);
return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptionsT *>(value);
+ return CreateEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptionsT *>(value);
+ return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -8244,6 +8450,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new ExpandDimsOptionsT(*reinterpret_cast<ExpandDimsOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_EqualOptions: {
+ value = new EqualOptionsT(*reinterpret_cast<EqualOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ value = new NotEqualOptionsT(*reinterpret_cast<NotEqualOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -8511,6 +8725,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<EqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<NotEqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 351187f520..723b6ae057 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2165,6 +2165,74 @@ def make_arg_max_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_equal_tests(zip_path):
+ """Make a set of tests to do equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_not_equal_tests(zip_path):
+ """Make a set of tests to do not equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the not equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.not_equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_greater_tests(zip_path):
"""Make a set of tests to do greater."""
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 99f0c81a1b..76ce1c5802 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1938,6 +1938,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertRandomUniformOperator(
model, static_cast<const RandomUniformOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowEqual) {
+ ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTensorFlowNotEqual) {
+ ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
} else if (src_op.type == OperatorType::kTensorFlowGreater) {
ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
} else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 64096fb069..92d283ca2c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -60,6 +60,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTensorFlowLessEqual:
case OperatorType::kTensorFlowGreater:
case OperatorType::kTensorFlowGreaterEqual:
+ case OperatorType::kTensorFlowEqual:
+ case OperatorType::kTensorFlowNotEqual:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index adb241da32..9e4262223e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1563,6 +1563,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTensorFlowMaximum:
case OperatorType::kTensorFlowMinimum:
case OperatorType::kTensorFlowGreaterEqual:
+ case OperatorType::kTensorFlowEqual:
+ case OperatorType::kTensorFlowNotEqual:
ProcessSimpleBinaryOperator(model, op);
break;
case OperatorType::kAddN:
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index b9ebf66ff2..b13a88a9eb 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1908,6 +1908,12 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
ConvertSimpleOperator<SelectOperator, 3>(node, tf_import_flags, model);
} else if (node.op() == "SparseToDense") {
ConvertSparseToDenseOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Equal") {
+ ConvertSimpleOperator<TensorFlowEqualOperator, 2>(node, tf_import_flags,
+ model);
+ } else if (node.op() == "NotEqual") {
+ ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>(node, tf_import_flags,
+ model);
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 1a4f87e363..81beb29372 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -136,6 +136,8 @@ enum class OperatorType {
kReorderAxes,
kSelect,
kSparseToDense,
+ kTensorFlowEqual,
+ kTensorFlowNotEqual,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1358,6 +1360,22 @@ struct TensorFlowGreaterEqualOperator : Operator {
: Operator(OperatorType::kTensorFlowGreaterEqual) {}
};
+// TensorFlow Equal equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowEqualOperator : Operator {
+ TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {}
+};
+
+// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for
+// details.
+struct TensorFlowNotEqualOperator : Operator {
+ TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {}
+};
+
// Global max reduction: computes the max of all of entries in the input array.
// Thus the output is "0-dimensional": it consists of a single scalar value.
//
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index a8518adefc..8bfd76db6e 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1118,6 +1118,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
+ ops.emplace_back(new SimpleOperator<TensorFlowEqualOperator>(
+ "EQUAL", OperatorType::kTensorFlowEqual));
+ ops.emplace_back(new SimpleOperator<TensorFlowNotEqualOperator>(
+ "NOT_EQUAL", OperatorType::kTensorFlowNotEqual));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index d63c99a5f9..06bbe53516 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -119,6 +119,10 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
+ CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL",
+ OperatorType::kTensorFlowEqual);
+ CheckSimpleOperator<TensorFlowNotEqualOperator>(
+ "NOT_EQUAL", OperatorType::kTensorFlowNotEqual);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index fe7bed885d..5a82be3939 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -394,6 +394,8 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
HANDLE_OPERATORTYPENAME_CASE(Select)
HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual)
+ HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE