aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-26 19:35:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 19:38:26 -0700
commite41e70ed9827b81a07c42f68def80f3f61b70375 (patch)
tree04fadd1f072d787c13f145ebf9c9f4a17008e9ab /tensorflow/contrib/lite/kernels
parent84b3322931fd6fd73ce4ab250a1bd3cdd6e138f6 (diff)
Implement floor operator
PiperOrigin-RevId: 194490433
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD14
-rw-r--r--tensorflow/contrib/lite/kernels/floor.cc58
-rw-r--r--tensorflow/contrib/lite/kernels/floor_test.cc83
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
4 files changed, 157 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 80cefe83b2..689f9bfa71 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -145,6 +145,7 @@ cc_library(
"embedding_lookup.cc",
"embedding_lookup_sparse.cc",
"exp.cc",
+ "floor.cc",
"fully_connected.cc",
"gather.cc",
"hashtable_lookup.cc",
@@ -438,6 +439,19 @@ tf_cc_test(
)
tf_cc_test(
+ name = "floor_test",
+ size = "small",
+ srcs = ["floor_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "unidirectional_sequence_lstm_test",
size = "small",
srcs = ["unidirectional_sequence_lstm_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
new file mode 100644
index 0000000000..4b4395f711
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace floor {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ output->type = input->type;
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(output), GetTensorDims(output));
+ return kTfLiteOk;
+}
+} // namespace floor
+
+TfLiteRegistration* Register_FLOOR() {
+ static TfLiteRegistration r = {/*init=*/nullptr,
+ /*free=*/nullptr, floor::Prepare, floor::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/floor_test.cc b/tensorflow/contrib/lite/kernels/floor_test.cc
new file mode 100644
index 0000000000..b71e0400b6
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FloorOpModel : public SingleOpModel {
+ public:
+ FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
+ BuildInterpreter({
+ input_shape,
+ });
+ }
+
+ int input() { return input_; }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(FloorOpTest, SingleDim) {
+ FloorOpModel model({2}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input(), {8.5, 0.0});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+TEST(FloorOpTest, MultiDims) {
+ FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32);
+ model.PopulateTensor<float>(model.input(), {
+ 0.0001,
+ 8.0001,
+ 0.9999,
+ 9.9999,
+ 0.5,
+ -0.0001,
+ -8.0001,
+ -0.9999,
+ -9.9999,
+ -0.5,
+ });
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index b07e7b6ff3..f91d188ffa 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -80,6 +80,7 @@ TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
TfLiteRegistration* Register_LESS();
+TfLiteRegistration* Register_FLOOR();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -141,6 +142,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
+ AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.