diff options
author | 2018-01-18 13:31:04 -0800 | |
---|---|---|
committer | 2018-01-18 13:34:51 -0800 | |
commit | 45c47cabe7150386420b182a8026699ff704b8f4 (patch) | |
tree | 675246e399521be14e0f369c3507ad3d3bd5f673 | |
parent | 053470bc1a06b5f1cc65605bf21d48b3e92d6857 (diff) |
Supports Squeeze in Tf Lite.
PiperOrigin-RevId: 182429180
18 files changed, 501 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index b9d88128c2..0c333f9e8c 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -214,6 +214,13 @@ typedef struct { bool keep_dims; } TfLiteMeanParams; +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index b30b28a6ad..7e9644f36c 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -102,6 +102,7 @@ cc_library( "skip_gram.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "squeeze.cc", "sub.cc", "svdf.cc", "transpose.cc", @@ -492,6 +493,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "squeeze_test", + size = "small", + srcs = ["squeeze_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index fa63846020..45ad5f1890 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -56,6 +56,7 @@ TfLiteRegistration* Register_SPACE_TO_DEPTH(); TfLiteRegistration* Register_GATHER(); TfLiteRegistration* Register_TRANSPOSE(); TfLiteRegistration* Register_MEAN(); +TfLiteRegistration* Register_SQUEEZE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -98,6 +99,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MEAN, Register_MEAN()); AddBuiltin(BuiltinOperator_DIV, Register_DIV()); AddBuiltin(BuiltinOperator_SUB, Register_SUB()); + AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc new file mode 100644 index 0000000000..29447ab021 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/squeeze.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <string.h> +#include <vector> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace squeeze { + +struct SqueezeContext { + SqueezeContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteSqueezeParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + SqueezeContext op_context(context, node); + int input_num_dims = NumDimensions(op_context.input); + int num_squeeze_dims = op_context.params->num_squeeze_dims; + + // Determines number of dimensions of output tensor after squeeze. + const TfLiteIntArray* input_dims = op_context.input->dims; + const int* squeeze_dims = op_context.params->squeeze_dims; + TF_LITE_ENSURE(context, input_num_dims <= 8); + bool should_squeeze[8] = {false}; + int num_squeezed_dims = 0; + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < input_num_dims; ++idx) { + if (input_dims->data[idx] == 1) { + should_squeeze[idx] = true; + ++num_squeezed_dims; + } + } + } else { + for (int idx = 0; idx < num_squeeze_dims; ++idx) { + int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + input_num_dims + : squeeze_dims[idx]; + TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims && + input_dims->data[current] == 1); + if (!should_squeeze[current]) ++num_squeezed_dims; + should_squeeze[current] = true; + } + } + // Sets output dimensions. + TfLiteIntArray* output_dims = + TfLiteIntArrayCreate(input_num_dims - num_squeezed_dims); + for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx) { + if (!should_squeeze[in_idx]) { + output_dims->data[out_idx++] = input_dims->data[in_idx]; + } + } + return context->ResizeTensor(context, op_context.output, output_dims); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + SqueezeContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes); + memcpy(op_context.output->data.raw, op_context.input->data.raw, + op_context.input->bytes); + return kTfLiteOk; +} + +} // namespace squeeze + +TfLiteRegistration* Register_SQUEEZE() { + static TfLiteRegistration r = {nullptr, nullptr, squeeze::Prepare, + squeeze::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc new file mode 100644 index 0000000000..409227b626 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseSqueezeOpModel : public SingleOpModel { + public: + BaseSqueezeOpModel(const TensorData& input, const TensorData& output, + std::initializer_list<int> axis) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp( + BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions, + CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis)) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input() { return input_; } + + protected: + int input_; + int output_; +}; + +class FloatSqueezeOpModel : public BaseSqueezeOpModel { + public: + using BaseSqueezeOpModel::BaseSqueezeOpModel; + + void SetInput(std::initializer_list<float> data) { + PopulateTensor(input_, data); + } + + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } +}; + +TEST(FloatSqueezeOpTest, SqueezeAll) { + std::initializer_list<float> data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, + {TensorType_FLOAT32, {24}}, {}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +TEST(FloatSqueezeOpTest, SqueezeSelectedAxis) { + std::initializer_list<float> data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, + {TensorType_FLOAT32, {24}}, {2}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) { + std::initializer_list<float> data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, + {TensorType_FLOAT32, {24}}, {-1, 0}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 09fc9b9613..86e613736d 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -596,6 +596,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_SQUEEZE: { + auto* params = MallocPOD<TfLiteSqueezeParams>(); + if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { + const auto& squeeze_dims = schema_params->squeeze_dims(); + FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, + params->squeeze_dims, error_reporter); + params->num_squeeze_dims = squeeze_dims->Length(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 468c78dcce..b3602f799e 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -337,6 +337,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_MEAN: case tflite::BuiltinOperator_DIV: case tflite::BuiltinOperator_SUB: + case tflite::BuiltinOperator_SQUEEZE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 80c05f34cb..f5251031b3 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -116,6 +116,7 @@ enum BuiltinOperator : byte { MEAN = 40, SUB = 41, DIV = 42, + SQUEEZE = 43, } // Options for the builtin operators. @@ -149,6 +150,7 @@ union BuiltinOptions { MeanOptions, SubOptions, DivOptions, + SqueezeOptions, } enum Padding : byte { SAME, VALID } @@ -326,6 +328,10 @@ table MeanOptions { keep_dims: bool; } +table SqueezeOptions { + squeeze_dims:[int]; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 06c62604e4..a2ec8e40e9 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -114,6 +114,9 @@ struct TransposeOptionsT; struct MeanOptions; struct MeanOptionsT; +struct SqueezeOptions; +struct SqueezeOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -199,11 +202,12 @@ enum BuiltinOperator { BuiltinOperator_MEAN = 40, BuiltinOperator_SUB = 41, BuiltinOperator_DIV = 42, + BuiltinOperator_SQUEEZE = 43, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_DIV + BuiltinOperator_MAX = BuiltinOperator_SQUEEZE }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[40] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[41] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -244,7 +248,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[40] { BuiltinOperator_TRANSPOSE, BuiltinOperator_MEAN, BuiltinOperator_SUB, - BuiltinOperator_DIV}; + BuiltinOperator_DIV, + BuiltinOperator_SQUEEZE}; return values; } @@ -292,6 +297,7 @@ inline const char **EnumNamesBuiltinOperator() { "MEAN", "SUB", "DIV", + "SQUEEZE", nullptr}; return names; } @@ -332,11 +338,12 @@ enum BuiltinOptions { BuiltinOptions_MeanOptions = 27, BuiltinOptions_SubOptions = 28, BuiltinOptions_DivOptions = 29, + BuiltinOptions_SqueezeOptions = 30, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_DivOptions + BuiltinOptions_MAX = BuiltinOptions_SqueezeOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[30] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -367,7 +374,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[30] { BuiltinOptions_TransposeOptions, BuiltinOptions_MeanOptions, BuiltinOptions_SubOptions, - BuiltinOptions_DivOptions}; + BuiltinOptions_DivOptions, + BuiltinOptions_SqueezeOptions}; return values; } @@ -402,6 +410,7 @@ inline const char **EnumNamesBuiltinOptions() { "MeanOptions", "SubOptions", "DivOptions", + "SqueezeOptions", nullptr}; return names; } @@ -565,6 +574,11 @@ struct BuiltinOptionsTraits<DivOptions> { static const BuiltinOptions enum_value = BuiltinOptions_DivOptions; }; +template <> +struct BuiltinOptionsTraits<SqueezeOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -902,6 +916,16 @@ struct BuiltinOptionsUnion { ? reinterpret_cast<const DivOptionsT *>(value) : nullptr; } + SqueezeOptionsT *AsSqueezeOptions() { + return type == BuiltinOptions_SqueezeOptions + ? reinterpret_cast<SqueezeOptionsT *>(value) + : nullptr; + } + const SqueezeOptionsT *AsSqueezeOptions() const { + return type == BuiltinOptions_SqueezeOptions + ? reinterpret_cast<const SqueezeOptionsT *>(value) + : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, @@ -3348,6 +3372,71 @@ flatbuffers::Offset<MeanOptions> CreateMeanOptions( flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct SqueezeOptionsT : public flatbuffers::NativeTable { + typedef SqueezeOptions TableType; + std::vector<int32_t> squeeze_dims; + SqueezeOptionsT() {} +}; + +struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SqueezeOptionsT NativeTableType; + enum { VT_SQUEEZE_DIMS = 4 }; + const flatbuffers::Vector<int32_t> *squeeze_dims() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SQUEEZE_DIMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SQUEEZE_DIMS) && + verifier.Verify(squeeze_dims()) && verifier.EndTable(); + } + SqueezeOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + SqueezeOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<SqueezeOptions> Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SqueezeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_squeeze_dims( + flatbuffers::Offset<flatbuffers::Vector<int32_t>> squeeze_dims) { + fbb_.AddOffset(SqueezeOptions::VT_SQUEEZE_DIMS, squeeze_dims); + } + explicit SqueezeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SqueezeOptionsBuilder &operator=(const SqueezeOptionsBuilder &); + flatbuffers::Offset<SqueezeOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<SqueezeOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> squeeze_dims = 0) { + SqueezeOptionsBuilder builder_(_fbb); + builder_.add_squeeze_dims(squeeze_dims); + return builder_.Finish(); +} + +inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *squeeze_dims = nullptr) { + return tflite::CreateSqueezeOptions( + _fbb, squeeze_dims ? _fbb.CreateVector<int32_t>(*squeeze_dims) : 0); +} + +flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -3622,6 +3711,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { ? static_cast<const DivOptions *>(builtin_options()) : nullptr; } + const SqueezeOptions *builtin_options_as_SqueezeOptions() const { + return builtin_options_type() == BuiltinOptions_SqueezeOptions + ? static_cast<const SqueezeOptions *>(builtin_options()) + : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -3817,6 +3911,12 @@ inline const DivOptions *Operator::builtin_options_as<DivOptions>() const { return builtin_options_as_DivOptions(); } +template <> +inline const SqueezeOptions *Operator::builtin_options_as<SqueezeOptions>() + const { + return builtin_options_as_SqueezeOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5744,6 +5844,51 @@ inline flatbuffers::Offset<MeanOptions> CreateMeanOptions( return tflite::CreateMeanOptions(_fbb, _axis, _keep_dims); } +inline SqueezeOptionsT *SqueezeOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SqueezeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SqueezeOptions::UnPackTo( + SqueezeOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = squeeze_dims(); + if (_e) { + _o->squeeze_dims.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->squeeze_dims[_i] = _e->Get(_i); + } + } + }; +} + +inline flatbuffers::Offset<SqueezeOptions> SqueezeOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSqueezeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<SqueezeOptions> CreateSqueezeOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const SqueezeOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _squeeze_dims = + _o->squeeze_dims.size() ? _fbb.CreateVector(_o->squeeze_dims) : 0; + return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims); +} + inline OperatorCodeT *OperatorCode::UnPack( const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); @@ -6248,6 +6393,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast<const DivOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast<const SqueezeOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } @@ -6388,6 +6537,10 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast<const DivOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast<const SqueezeOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -6515,6 +6668,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast<const DivOptionsT *>(value); return CreateDivOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast<const SqueezeOptionsT *>(value); + return CreateSqueezeOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -6655,6 +6812,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) value = new DivOptionsT(*reinterpret_cast<DivOptionsT *>(u.value)); break; } + case BuiltinOptions_SqueezeOptions: { + value = + new SqueezeOptionsT(*reinterpret_cast<SqueezeOptionsT *>(u.value)); + break; + } default: break; } @@ -6807,6 +6969,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast<SqueezeOptionsT *>(value); + delete ptr; + break; + } default: break; } diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 48a7536d41..933da11353 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -45,6 +45,7 @@ gen_zipped_test_files( "softmax.zip", "space_to_batch_nd.zip", "space_to_depth.zip", + "squeeze.zip", "sub.zip", "transpose.zip", ], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 9e17c2a370..29bf2cd7fc 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1346,6 +1346,40 @@ def make_transpose_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_squeeze_tests(zip_path): + """Make a set of tests to do squeeze.""" + + test_parameters = [{ + "dtype": [tf.int32, tf.float32, tf.int64], + "input_shape": [[1, 2, 1, 3, 1, 4, 1, 1]], + "axis": [ + None, [], [0, 2], [4, 7], [-1, 0, 2, 0, 7, -6], [1], [2, 3, 2], + [-1, -2, -4, -6, -8], [0, 2, 4, 6, 7], [7, 6, 4, 2, 0], [6, 6], + [0, 1, 2, 3, 4, 5, 6, 7], [-2, -3, 1, 0, 7, -5] + ], + }, { + "dtype": [tf.int32, tf.float32, tf.int64], + "input_shape": [[1]], + "axis": [None, [], [0], [-1]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=parameters["dtype"], + name="input", + shape=parameters["input_shape"]) + out = tf.squeeze(input_tensor, axis=parameters["axis"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_l2_pool(input_tensor, ksize, strides, padding, data_format): """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" return tf.sqrt(tf.nn.avg_pool( @@ -1403,6 +1437,7 @@ def main(unused_args): "space_to_depth.zip": make_space_to_depth_tests, "transpose.zip": make_transpose_tests, "mean.zip": make_mean_tests, + "squeeze.zip": make_squeeze_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 3c01302c43..c8a6e07abd 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -262,6 +262,7 @@ INSTANTIATE_TESTS(sub) INSTANTIATE_TESTS(div) INSTANTIATE_TESTS(transpose) INSTANTIATE_TESTS(mean) +INSTANTIATE_TESTS(squeeze) } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index cea5b4e92d..967e304742 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -219,11 +219,11 @@ cc_library( "graph_transformations/resolve_reshape_attributes.cc", "graph_transformations/resolve_slice_attributes.cc", "graph_transformations/resolve_space_to_batch_nd_attributes.cc", + "graph_transformations/resolve_squeeze_attributes.cc", "graph_transformations/resolve_strided_slice_attributes.cc", "graph_transformations/resolve_tensorflow_concat.cc", "graph_transformations/resolve_tensorflow_matmul.cc", "graph_transformations/resolve_tensorflow_merge.cc", - "graph_transformations/resolve_tensorflow_squeeze.cc", "graph_transformations/resolve_tensorflow_switch.cc", "graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 9300ab53a7..9ec9f92c90 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -146,7 +146,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) -DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze) +DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index 1d3f42b5ec..dd3e73635a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -25,15 +25,13 @@ limitations under the License. namespace toco { -bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) { - const auto squeeze_it = model->operators.begin() + op_index; - const auto* squeeze_op = squeeze_it->get(); +bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) { + auto* squeeze_op = model->operators[op_index].get(); if (squeeze_op->type != OperatorType::kSqueeze) { return false; } - - CHECK_EQ(squeeze_op->inputs.size(), 1); - CHECK_EQ(squeeze_op->outputs.size(), 1); + DCHECK_EQ(squeeze_op->inputs.size(), 1); + DCHECK_EQ(squeeze_op->outputs.size(), 1); // If the output is consumed by a reshape op, it's a trivial squeeze. if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) { @@ -47,7 +45,6 @@ bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) { return RemoveTrivialPassthroughOp(this, model, op_index); } } - return false; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 5b98b71155..0111e1ed92 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -584,6 +584,27 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions, } }; +class Squeeze + : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions, + ::tflite::BuiltinOptions_SqueezeOptions> { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto squeeze_dims = builder->CreateVector(op.squeeze_dims); + return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->squeeze_dims.insert(op->squeeze_dims.end(), + options.squeeze_dims()->begin(), + options.squeeze_dims()->end()); + } +}; + class Split : public CustomOperator<TensorFlowSplitOperator> { public: using CustomOperator::CustomOperator; @@ -754,6 +775,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { OperatorType::kTranspose)); ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.emplace_back( + new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); // Custom Operators. ops.emplace_back(new Cast("CAST", OperatorType::kCast)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index debce63760..77c70847d1 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -389,6 +389,15 @@ TEST_F(OperatorTest, Transpose) { EXPECT_EQ(op.perm, output_toco_op->perm); } +TEST_F(OperatorTest, Squeeze) { + SqueezeOperator op; + op.squeeze_dims = {-2, -3, 4, 1, 4}; + + auto output_toco_op = SerializeAndDeserialize( + GetOperator("SQUEEZE", OperatorType::kSqueeze), op); + EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 0bcf4596de..94b4d14696 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -74,7 +74,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveConstantStridedSlice); transformations->Add(new ResolveConstantUnaryOperator); transformations->Add(new ResolveTensorFlowMerge); - transformations->Add(new ResolveTensorFlowSqueeze); + transformations->Add(new ResolveSqueezeAttributes); transformations->Add(new ResolveTensorFlowSwitch); transformations->Add(new ResolveTensorFlowTile); transformations->Add(new ResolveTensorFlowConcat); |