diff options
author | 2018-06-20 11:48:15 -0700 | |
---|---|---|
committer | 2018-06-20 11:51:26 -0700 | |
commit | 4efefb90391b12c95339ed3b46a02b62ea5e195d (patch) | |
tree | bb3f9bb986b89287983ea8e7c35827993aad7206 /tensorflow/contrib/lite/kernels | |
parent | e51df5918020cdfada26022240091e5529f7da60 (diff) |
Implement TFLite Shape operator
PiperOrigin-RevId: 201389618
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/register.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/shape.cc | 93 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/shape_test.cc | 95 |
4 files changed, 205 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index bb5558443b..a77897a173 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -168,6 +168,7 @@ cc_library( "reshape.cc", "resize_bilinear.cc", "select.cc", + "shape.cc", "skip_gram.cc", "slice.cc", "space_to_batch_nd.cc", @@ -994,6 +995,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "shape_test", + size = "small", + srcs = ["shape_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 07a7ee9115..67f6caea67 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -100,6 +100,7 @@ TfLiteRegistration* Register_EQUAL(); TfLiteRegistration* Register_NOT_EQUAL(); TfLiteRegistration* Register_SQRT(); TfLiteRegistration* Register_RSQRT(); +TfLiteRegistration* Register_SHAPE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -181,6 +182,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); + AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc new file mode 100644 index 0000000000..dbcd2ef004 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace shape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +template <typename OutType> +void ExtractShape(const TfLiteTensor* input, OutType* output_data) { + for (int i = 0; i < NumDimensions(input); ++i) { + output_data[i] = SizeOfDimension(input, i); + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data); + switch (params->out_type) { + case kTfLiteInt32: + output->type = kTfLiteInt32; + break; + case kTfLiteInt64: + output->type = kTfLiteInt64; + break; + default: + context->ReportError(context, "Unknown shape output data type: %d", + params->out_type); + return kTfLiteError; + } + + // Shape always produces a 1-dimensional output tensor, where each output + // element is the length of the corresponding input tensor's dimension. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(1); + output_size->data[0] = NumDimensions(input); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TFLITE_DCHECK_EQ(NumDimensions(output), 1); + TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input)); + + switch (output->type) { + case kTfLiteInt32: + ExtractShape(input, GetTensorData<int32_t>(output)); + break; + case kTfLiteInt64: + ExtractShape(input, GetTensorData<int64_t>(output)); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace shape + +TfLiteRegistration* Register_SHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/contrib/lite/kernels/shape_test.cc new file mode 100644 index 0000000000..27b48f4e99 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <initializer_list> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template <typename T> +class ShapeOpModel : public SingleOpModel { + public: + ShapeOpModel(std::initializer_list<int> input_shape, TensorType input_type, + TensorType output_type) { + input_ = AddInput(input_type); + output_ = AddOutput(output_type); + SetBuiltinOp(BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions, + CreateShapeOptions(builder_, output_type).Union()); + BuildInterpreter({input_shape}); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int input() { return input_; } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector<T> GetOutput() { return ExtractVector<T>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ShapeOpTest, OutTypeInt) { + ShapeOpModel<int32_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, OutTypeInt64) { + ShapeOpModel<int64_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT64); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, ScalarTensor) { + ShapeOpModel<int32_t> model({}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_EQ(model.GetOutputSize(), 0); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(ShapeOpTest, EmptyTensor) { + ShapeOpModel<int32_t> model({1, 0}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} |