diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-11 21:18:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 21:22:09 -0700 |
commit | 6a21e1386e3e68cf752af861b9b1b950bda8a130 (patch) | |
tree | 0ccc9df1c589ffb059b707c58d8c7c094f6f9de0 | |
parent | cadd6b42bf6b01c2668420463b0986acd7fd9009 (diff) |
Implementation of square.
PiperOrigin-RevId: 212577288
-rw-r--r-- | tensorflow/contrib/lite/build_def.bzl | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/elementwise.cc | 12 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/elementwise_test.cc | 9 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/register.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator_test.cc | 2 |
7 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 0210428026..e9c02cdbee 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -283,6 +283,7 @@ def generated_test_models(): "sparse_to_dense", "split", "sqrt", + "square", "squeeze", "strided_slice", "strided_slice_1d_exhaustive", diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 04995d70dd..8c624b3208 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); } +TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) { + return EvalNumeric(context, node, [](float f) { return f * f; }); +} + TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { return EvalLogical(context, node, [](bool v) { return !v; }); } @@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() { return &r; } +TfLiteRegistration* Register_SQUARE() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare<elementwise::IsNumericSupportedType>, + elementwise::SquareEval}; + return &r; +} + TfLiteRegistration* Register_LOGICAL_NOT() { static TfLiteRegistration r = { /*init=*/nullptr, /*free=*/nullptr, diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index b9d7d73c52..5dd89a0eae 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Square) { + ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1}); + m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<float>(m.output()), + ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + TEST(ElementWise, LogicalNot) { ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1}); m.PopulateTensor<bool>(m.input(), {true, false, true, false}); diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index c66959fdf4..14296d3a9f 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -118,6 +118,7 @@ TfLiteRegistration* Register_LOGICAL_AND(); TfLiteRegistration* Register_LOGICAL_NOT(); TfLiteRegistration* Register_UNPACK(); TfLiteRegistration* Register_FLOOR_DIV(); +TfLiteRegistration* Register_SQUARE(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -243,6 +244,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK()); AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV()); + AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 812385e706..5d0895c72f 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2882,6 +2882,11 @@ def make_rsqrt_tests(zip_path): return _make_elementwise_tests(tf.rsqrt)(zip_path) +def make_square_tests(zip_path): + """Make a set of tests to do square.""" + return _make_elementwise_tests(tf.square)(zip_path) + + def make_where_tests(zip_path): """Make a set of tests to do where.""" diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index eb0f7c443a..5486012176 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1488,6 +1488,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( "SQRT", OperatorType::kSqrt)); ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>( "RSQRT", OperatorType::kRsqrt)); + ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>( + "SQUARE", OperatorType::kSquare)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 519a3a4e01..72e50a9aed 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -144,6 +144,8 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT", OperatorType::kLogicalNot); CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv); + CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE", + OperatorType::kSquare); } TEST_F(OperatorTest, BuiltinAdd) { |