aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-11 21:18:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 21:22:09 -0700
commit6a21e1386e3e68cf752af861b9b1b950bda8a130 (patch)
tree0ccc9df1c589ffb059b707c58d8c7c094f6f9de0
parentcadd6b42bf6b01c2668420463b0986acd7fd9009 (diff)
Implementation of square.
PiperOrigin-RevId: 212577288
-rw-r--r--tensorflow/contrib/lite/build_def.bzl1
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise_test.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py5
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc2
7 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 0210428026..e9c02cdbee 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -283,6 +283,7 @@ def generated_test_models():
"sparse_to_dense",
"split",
"sqrt",
+ "square",
"squeeze",
"strided_slice",
"strided_slice_1d_exhaustive",
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index 04995d70dd..8c624b3208 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
}
+TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, [](float f) { return f * f; });
+}
+
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; });
}
@@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() {
return &r;
}
+TfLiteRegistration* Register_SQUARE() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SquareEval};
+ return &r;
+}
+
TfLiteRegistration* Register_LOGICAL_NOT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index b9d7d73c52..5dd89a0eae 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
+TEST(ElementWise, Square) {
+ ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
TEST(ElementWise, LogicalNot) {
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
m.PopulateTensor<bool>(m.input(), {true, false, true, false});
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index c66959fdf4..14296d3a9f 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();
TfLiteRegistration* Register_UNPACK();
TfLiteRegistration* Register_FLOOR_DIV();
+TfLiteRegistration* Register_SQUARE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
+ AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 812385e706..5d0895c72f 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2882,6 +2882,11 @@ def make_rsqrt_tests(zip_path):
return _make_elementwise_tests(tf.rsqrt)(zip_path)
+def make_square_tests(zip_path):
+ """Make a set of tests to do square."""
+ return _make_elementwise_tests(tf.square)(zip_path)
+
+
def make_where_tests(zip_path):
"""Make a set of tests to do where."""
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index eb0f7c443a..5486012176 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1488,6 +1488,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
"SQRT", OperatorType::kSqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
"RSQRT", OperatorType::kRsqrt));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
+ "SQUARE", OperatorType::kSquare));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 519a3a4e01..72e50a9aed 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
OperatorType::kLogicalNot);
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
+ CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
+ OperatorType::kSquare);
}
TEST_F(OperatorTest, BuiltinAdd) {