diff options
author | 2018-04-26 19:35:10 -0700 | |
---|---|---|
committer | 2018-04-26 19:38:26 -0700 | |
commit | e41e70ed9827b81a07c42f68def80f3f61b70375 (patch) | |
tree | 04fadd1f072d787c13f145ebf9c9f4a17008e9ab /tensorflow/contrib/lite/kernels | |
parent | 84b3322931fd6fd73ce4ab250a1bd3cdd6e138f6 (diff) |
Implement floor operator
PiperOrigin-RevId: 194490433
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/floor.cc | 58 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/floor_test.cc | 83 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/register.cc | 2 |
4 files changed, 157 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 80cefe83b2..689f9bfa71 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -145,6 +145,7 @@ cc_library( "embedding_lookup.cc", "embedding_lookup_sparse.cc", "exp.cc", + "floor.cc", "fully_connected.cc", "gather.cc", "hashtable_lookup.cc", @@ -438,6 +439,19 @@ tf_cc_test( ) tf_cc_test( + name = "floor_test", + size = "small", + srcs = ["floor_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( name = "unidirectional_sequence_lstm_test", size = "small", srcs = ["unidirectional_sequence_lstm_test.cc"], diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc new file mode 100644 index 0000000000..4b4395f711 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/floor.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace floor { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + output->type = input->type; + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input), + GetTensorData<float>(output), GetTensorDims(output)); + return kTfLiteOk; +} +} // namespace floor + +TfLiteRegistration* Register_FLOOR() { + static TfLiteRegistration r = {/*init=*/nullptr, + /*free=*/nullptr, floor::Prepare, floor::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/floor_test.cc b/tensorflow/contrib/lite/kernels/floor_test.cc new file mode 100644 index 0000000000..b71e0400b6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/floor_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class FloorOpModel : public SingleOpModel { + public: + FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0); + BuildInterpreter({ + input_shape, + }); + } + + int input() { return input_; } + + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(FloorOpTest, SingleDim) { + FloorOpModel model({2}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input(), {8.5, 0.0}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(FloorOpTest, MultiDims) { + FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input(), { + 0.0001, + 8.0001, + 0.9999, + 9.9999, + 0.5, + -0.0001, + -8.0001, + -0.9999, + -9.9999, + -0.5, + }); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index b07e7b6ff3..f91d188ffa 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -80,6 +80,7 @@ TfLiteRegistration* Register_MAXIMUM(); TfLiteRegistration* Register_MINIMUM(); TfLiteRegistration* Register_ARG_MAX(); TfLiteRegistration* Register_LESS(); +TfLiteRegistration* Register_FLOOR(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -141,6 +142,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); AddBuiltin(BuiltinOperator_LESS, Register_LESS()); + AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. |