diff options
author | Yu-Cheng Ling <ycling@google.com> | 2017-12-14 21:26:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-14 21:31:06 -0800 |
commit | dbcb1ffcca6a3c52e3c109a1739018350bc41925 (patch) | |
tree | aa4f440d1f8f4382cea7a1b4e4cfc623da2a76f0 /tensorflow | |
parent | f806269602219d5095265d036f294cc9a6260971 (diff) |
Support BatchToSpaceND in TFLite
The internal implementation only support 4D tensors for now.
The dimension has to be 1 batch + 2 spatial + 1 other.
The most common format within this restriction is NHWC.
Cropping is not supported by the internal implementation.
PiperOrigin-RevId: 179143332
Diffstat (limited to 'tensorflow')
19 files changed, 688 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 548864a1e9..5c6f3016b1 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -105,6 +105,17 @@ typedef struct { } TfLiteAddParams; typedef struct { + // Number of spatial dimensions. + // For now only NHWC is supported, and the value should always be 2. + int num_spatial_dimensions; + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int block_shape[2]; + int before_crops[2]; + int after_crops[2]; +} TfLiteBatchToSpaceNDParams; + +typedef struct { TfLiteFusedActivation activation; } TfLiteMulParams; diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 3908960c33..cc02cddb3d 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -77,6 +77,7 @@ cc_library( "activations.cc", "add.cc", "basic_rnn.cc", + "batch_to_space_nd.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", @@ -157,6 +158,18 @@ tf_cc_test( ) tf_cc_test( + name = "batch_to_space_nd_test", + size = "small", + srcs = ["batch_to_space_nd_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( name = "concatenation_test", size = "small", srcs = ["concatenation_test.cc"], diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc new file mode 100644 index 0000000000..0eed680fdc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <string.h> +#include <vector> +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace batch_to_space_nd { + +// This file has two implementations of BatchToSpaceND. +enum KernelType { + kReference, + kGenericOptimized, +}; + +struct BatchToSpaceNDContext { + BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast<TfLiteBatchToSpaceNDParams*>(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteBatchToSpaceNDParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +// Currently, only 4D NHWC input/output op_context are supported. +// The 4D array need to have exactly 2 spatial dimensions. +// TODO(ycling): Support arbitrary dimension in BatchToSpaceND. +const int kInputDimensionNum = 4; +const int kOutputDimensionNum = 4; +const int kSpatialDimensionNum = 2; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now. + TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + const TfLiteIntArray* input_size = op_context.input->dims; + const int* block_shape = op_context.params->block_shape; + + // Number of batch must be multiple of (block_shape[0] * block_shape[1]). + TF_LITE_ENSURE_EQ(context, + input_size->data[0] % (block_shape[0] * block_shape[1]), 0); + + const int output_batch_size = + input_size->data[0] / (block_shape[0] * block_shape[1]); + const int output_height = input_size->data[1] * block_shape[0]; + const int output_width = input_size->data[2] * block_shape[1]; + const int output_channel_size = input_size->data[3]; + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + output_size->data[0] = output_batch_size; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = output_channel_size; + + return context->ResizeTensor(context, op_context.output, output_size); +} + +template <KernelType kernel_type> +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + BatchToSpaceNDContext op_context(context, node); + + int block_shape_dims_array[1] = {kSpatialDimensionNum}; + Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \ + GetTensorDims(op_context.input), \ + op_context.params->block_shape, block_shape_dims, \ + GetTensorData<scalar>(op_context.output), \ + GetTensorDims(op_context.output)) + switch (op_context.input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by BatchToSpace."); + return kTfLiteError; + } +#undef TF_LITE_BATCH_TO_SPACE_ND + return kTfLiteOk; +} + +} // namespace batch_to_space_nd + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval<batch_to_space_nd::kReference>}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND() { + return Register_BATCH_TO_SPACE_ND_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc new file mode 100644 index 0000000000..3ec4efbebc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BatchToSpaceNDOpModel : public SingleOpModel { + public: + BatchToSpaceNDOpModel(std::initializer_list<int> input_shape, + std::initializer_list<int> block_shape, + std::initializer_list<int> before_crops, + std::initializer_list<int> after_crops) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions( + builder_, builder_.CreateVector<int>(block_shape), + builder_.CreateVector<int>(before_crops), + builder_.CreateVector<int>(after_crops)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list<float> data) { + PopulateTensor<float>(input_, data); + } + + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(BatchToSpaceNDOpTest, SimpleTest) { + BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { + EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}), + "Cannot allocate tensors"); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 3d1edeef01..d4e7503f48 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -40,6 +40,7 @@ TfLiteRegistration* Register_HASHTABLE_LOOKUP(); TfLiteRegistration* Register_SOFTMAX(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); TfLiteRegistration* Register_MUL(); TfLiteRegistration* Register_L2_NORMALIZATION(); TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); @@ -75,6 +76,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND()); AddBuiltin(BuiltinOperator_MUL, Register_MUL()); AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 4ef2c942c1..94e22b2659 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -518,6 +518,24 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_BATCH_TO_SPACE_ND: { + auto* params = MallocPOD<TfLiteBatchToSpaceNDParams>(); + if (auto* schema_params = + op->builtin_options_as_BatchToSpaceNDOptions()) { + const auto& block_shape = schema_params->block_shape(); + FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, + params->block_shape, error_reporter); + const auto& before_crops = schema_params->before_crops(); + FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops, + params->before_crops, error_reporter); + const auto& after_crops = schema_params->after_crops(); + FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops, + params->after_crops, error_reporter); + params->num_spatial_dimensions = block_shape->Length(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 6b93a70bff..5cb0afcea0 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -307,6 +307,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SKIP_GRAM: case tflite::BuiltinOperator_RELU1: case tflite::BuiltinOperator_GATHER: + case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 8b48543fc8..cc31e03dfc 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -107,6 +107,7 @@ enum BuiltinOperator : byte { PAD = 34, UNIDIRECTIONAL_SEQUENCE_RNN = 35, GATHER = 36, + BATCH_TO_SPACE_ND = 37, } // Options for the builtin operators. @@ -134,6 +135,7 @@ union BuiltinOptions { MulOptions, PadOptions, GatherOptions, + BatchToSpaceNDOptions, } enum Padding : byte { SAME, VALID } @@ -258,6 +260,12 @@ table ReshapeOptions { new_shape:[int]; } +table BatchToSpaceNDOptions { + block_shape:[int]; + before_crops:[int]; + after_crops:[int]; +} + table SkipGramOptions { ngram_size: int; max_skip_size: int; diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 7de205e1e4..aa169198fe 100644..100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -1,3 +1,4 @@ + /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -85,6 +86,9 @@ struct PadOptionsT; struct ReshapeOptions; struct ReshapeOptionsT; +struct BatchToSpaceNDOptions; +struct BatchToSpaceNDOptionsT; + struct SkipGramOptions; struct SkipGramOptionsT; @@ -176,11 +180,12 @@ enum BuiltinOperator { BuiltinOperator_PAD = 34, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35, BuiltinOperator_GATHER = 36, + BuiltinOperator_BATCH_TO_SPACE_ND = 37, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_GATHER + BuiltinOperator_MAX = BuiltinOperator_BATCH_TO_SPACE_ND }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[34] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[35] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -215,7 +220,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[34] { BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, BuiltinOperator_PAD, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOperator_GATHER}; + BuiltinOperator_GATHER, + BuiltinOperator_BATCH_TO_SPACE_ND}; return values; } @@ -257,6 +263,7 @@ inline const char **EnumNamesBuiltinOperator() { "PAD", "UNIDIRECTIONAL_SEQUENCE_RNN", "GATHER", + "BATCH_TO_SPACE_ND", nullptr}; return names; } @@ -291,11 +298,12 @@ enum BuiltinOptions { BuiltinOptions_MulOptions = 21, BuiltinOptions_PadOptions = 22, BuiltinOptions_GatherOptions = 23, + BuiltinOptions_BatchToSpaceNDOptions = 24, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_GatherOptions + BuiltinOptions_MAX = BuiltinOptions_BatchToSpaceNDOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[24] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[25] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -320,7 +328,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[24] { BuiltinOptions_EmbeddingLookupSparseOptions, BuiltinOptions_MulOptions, BuiltinOptions_PadOptions, - BuiltinOptions_GatherOptions}; + BuiltinOptions_GatherOptions, + BuiltinOptions_BatchToSpaceNDOptions}; return values; } @@ -349,6 +358,7 @@ inline const char **EnumNamesBuiltinOptions() { "MulOptions", "PadOptions", "GatherOptions", + "BatchToSpaceNDOptions", nullptr}; return names; } @@ -482,6 +492,11 @@ struct BuiltinOptionsTraits<GatherOptions> { static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; }; +template <> +struct BuiltinOptionsTraits<BatchToSpaceNDOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -759,6 +774,16 @@ struct BuiltinOptionsUnion { ? reinterpret_cast<const GatherOptionsT *>(value) : nullptr; } + BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() { + return type == BuiltinOptions_BatchToSpaceNDOptions + ? reinterpret_cast<BatchToSpaceNDOptionsT *>(value) + : nullptr; + } + const BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() const { + return type == BuiltinOptions_BatchToSpaceNDOptions + ? reinterpret_cast<const BatchToSpaceNDOptionsT *>(value) + : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, @@ -2512,6 +2537,101 @@ flatbuffers::Offset<ReshapeOptions> CreateReshapeOptions( flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct BatchToSpaceNDOptionsT : public flatbuffers::NativeTable { + typedef BatchToSpaceNDOptions TableType; + std::vector<int32_t> block_shape; + std::vector<int32_t> before_crops; + std::vector<int32_t> after_crops; + BatchToSpaceNDOptionsT() {} +}; + +struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS + : private flatbuffers::Table { + typedef BatchToSpaceNDOptionsT NativeTableType; + enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_CROPS = 6, VT_AFTER_CROPS = 8 }; + const flatbuffers::Vector<int32_t> *block_shape() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BLOCK_SHAPE); + } + const flatbuffers::Vector<int32_t> *before_crops() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BEFORE_CROPS); + } + const flatbuffers::Vector<int32_t> *after_crops() const { + return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_AFTER_CROPS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BLOCK_SHAPE) && + verifier.Verify(block_shape()) && + VerifyOffset(verifier, VT_BEFORE_CROPS) && + verifier.Verify(before_crops()) && + VerifyOffset(verifier, VT_AFTER_CROPS) && + verifier.Verify(after_crops()) && verifier.EndTable(); + } + BatchToSpaceNDOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + BatchToSpaceNDOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<BatchToSpaceNDOptions> Pack( + flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BatchToSpaceNDOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_block_shape( + flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_shape) { + fbb_.AddOffset(BatchToSpaceNDOptions::VT_BLOCK_SHAPE, block_shape); + } + void add_before_crops( + flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_crops) { + fbb_.AddOffset(BatchToSpaceNDOptions::VT_BEFORE_CROPS, before_crops); + } + void add_after_crops( + flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_crops) { + fbb_.AddOffset(BatchToSpaceNDOptions::VT_AFTER_CROPS, after_crops); + } + explicit BatchToSpaceNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BatchToSpaceNDOptionsBuilder &operator=(const BatchToSpaceNDOptionsBuilder &); + flatbuffers::Offset<BatchToSpaceNDOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BatchToSpaceNDOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_shape = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_crops = 0, + flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_crops = 0) { + BatchToSpaceNDOptionsBuilder builder_(_fbb); + builder_.add_after_crops(after_crops); + builder_.add_before_crops(before_crops); + builder_.add_block_shape(block_shape); + return builder_.Finish(); +} + +inline flatbuffers::Offset<BatchToSpaceNDOptions> +CreateBatchToSpaceNDOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<int32_t> *block_shape = nullptr, + const std::vector<int32_t> *before_crops = nullptr, + const std::vector<int32_t> *after_crops = nullptr) { + return tflite::CreateBatchToSpaceNDOptions( + _fbb, block_shape ? _fbb.CreateVector<int32_t>(*block_shape) : 0, + before_crops ? _fbb.CreateVector<int32_t>(*before_crops) : 0, + after_crops ? _fbb.CreateVector<int32_t>(*after_crops) : 0); +} + +flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions( + flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct SkipGramOptionsT : public flatbuffers::NativeTable { typedef SkipGramOptions TableType; int32_t ngram_size; @@ -3000,6 +3120,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { ? static_cast<const GatherOptions *>(builtin_options()) : nullptr; } + const BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() + const { + return builtin_options_type() == BuiltinOptions_BatchToSpaceNDOptions + ? static_cast<const BatchToSpaceNDOptions *>(builtin_options()) + : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -3162,6 +3288,12 @@ inline const GatherOptions *Operator::builtin_options_as<GatherOptions>() return builtin_options_as_GatherOptions(); } +template <> +inline const BatchToSpaceNDOptions * +Operator::builtin_options_as<BatchToSpaceNDOptions>() const { + return builtin_options_as_BatchToSpaceNDOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -4614,6 +4746,74 @@ inline flatbuffers::Offset<ReshapeOptions> CreateReshapeOptions( return tflite::CreateReshapeOptions(_fbb, _new_shape); } +inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BatchToSpaceNDOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void BatchToSpaceNDOptions::UnPackTo( + BatchToSpaceNDOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = block_shape(); + if (_e) { + _o->block_shape.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->block_shape[_i] = _e->Get(_i); + } + } + }; + { + auto _e = before_crops(); + if (_e) { + _o->before_crops.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->before_crops[_i] = _e->Get(_i); + } + } + }; + { + auto _e = after_crops(); + if (_e) { + _o->after_crops.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->after_crops[_i] = _e->Get(_i); + } + } + }; +} + +inline flatbuffers::Offset<BatchToSpaceNDOptions> BatchToSpaceNDOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBatchToSpaceNDOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions( + flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const BatchToSpaceNDOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _block_shape = + _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0; + auto _before_crops = + _o->before_crops.size() ? _fbb.CreateVector(_o->before_crops) : 0; + auto _after_crops = + _o->after_crops.size() ? _fbb.CreateVector(_o->after_crops) : 0; + return tflite::CreateBatchToSpaceNDOptions(_fbb, _block_shape, _before_crops, + _after_crops); +} + inline SkipGramOptionsT *SkipGramOptions::UnPack( const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SkipGramOptionsT(); @@ -5265,6 +5465,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast<const GatherOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast<const BatchToSpaceNDOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } @@ -5381,6 +5585,10 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast<const GatherOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast<const BatchToSpaceNDOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -5484,6 +5692,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast<const GatherOptionsT *>(value); return CreateGatherOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast<const BatchToSpaceNDOptionsT *>(value); + return CreateBatchToSpaceNDOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -5597,6 +5809,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) value = new GatherOptionsT(*reinterpret_cast<GatherOptionsT *>(u.value)); break; } + case BuiltinOptions_BatchToSpaceNDOptions: { + value = new BatchToSpaceNDOptionsT( + *reinterpret_cast<BatchToSpaceNDOptionsT *>(u.value)); + break; + } default: break; } @@ -5719,6 +5936,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast<BatchToSpaceNDOptionsT *>(value); + delete ptr; + break; + } default: break; } diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index b63c0c058c..96800304e5 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -18,6 +18,7 @@ gen_zipped_test_files( files = [ "add.zip", "avg_pool.zip", + "batch_to_space_nd.zip", "concat.zip", "constant.zip", "control_dep.zip", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 4c01fedb1e..02f59438cd 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -96,6 +96,10 @@ KNOWN_BUGS = { r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", # Gather doesn't support int64 indices. r"gather.*indices_dtype=int64": "XXXX", + # BatchToSpaceND doesn't support cropping. + r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634", + # BatchToSpaceND only supports 4D tensors. + r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", } @@ -1198,6 +1202,43 @@ def make_space_to_depth_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_batch_to_space_nd_tests(zip_path): + """Make a set of tests to do batch_to_space_nd.""" + + test_parameters = [ + { + "dtype": [tf.float32, tf.int64, tf.int32], + "input_shape": [[12, 2, 2, 1]], + "block_shape": [[1, 4], [2, 2], [3, 4]], + "crops": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]], + }, + # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others. + { + "dtype": [tf.float32], + "input_shape": [[8, 2, 2, 2, 1, 1]], + "block_shape": [[2, 2, 2]], + "crops": [[[0, 0], [0, 0], [0, 0]]], + }, + ] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=parameters["dtype"], + name="input", + shape=parameters["input_shape"]) + out = tf.batch_to_space_nd(input_tensor, parameters["block_shape"], + parameters["crops"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_l2_pool(input_tensor, ksize, strides, padding, data_format): """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" return tf.sqrt(tf.nn.avg_pool( @@ -1226,6 +1267,7 @@ def main(unused_args): dispatch = { "control_dep.zip": make_control_dep_tests, "add.zip": make_add_tests, + "batch_to_space_nd.zip": make_batch_to_space_nd_tests, "conv.zip": make_conv_tests, "constant.zip": make_constant_tests, "depthwiseconv.zip": make_depthwiseconv_tests, diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 29f0c68ba4..4c05979e24 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -241,6 +241,7 @@ TEST_P(OpsTest, RunStuff) { INSTANTIATE_TESTS(add) INSTANTIATE_TESTS(avg_pool) +INSTANTIATE_TESTS(batch_to_space_nd) INSTANTIATE_TESTS(concat) INSTANTIATE_TESTS(constant) INSTANTIATE_TESTS(control_dep) diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 78c036fa77..7556a402f9 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -202,6 +202,7 @@ cc_library( "graph_transformations/remove_trivial_reshape.cc", "graph_transformations/remove_unused_op.cc", "graph_transformations/resolve_batch_normalization.cc", + "graph_transformations/resolve_batch_to_space_nd_attributes.cc", "graph_transformations/resolve_constant_binary.cc", "graph_transformations/resolve_constant_concatenation.cc", "graph_transformations/resolve_constant_fake_quant.cc", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index c1dc41170c..2eb244ee08 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -152,6 +152,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant) DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc new file mode 100644 index 0000000000..a4f198e92f --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + if (op_it->get()->type != OperatorType::kBatchToSpaceND) return false; + + auto* op = static_cast<BatchToSpaceNDOperator*>(op_it->get()); + + // The attributes are resolved only when the 3 attributes (block_shape, + // before_crops, after_crops) are all constant. + if (!op->block_shape.empty()) { + return false; + } + + CHECK_EQ(op->inputs.size(), 3); + if (!IsConstantParameterArray(*model, op->inputs[1]) or + !IsConstantParameterArray(*model, op->inputs[2])) + return false; + + // Handling block_shape. + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + if (!block_shape_array.has_shape()) return false; + const std::vector<int>& block_shape_dims = block_shape_array.shape().dims(); + CHECK_EQ(block_shape_dims.size(), 1); + std::vector<int> block_shape_buffer = + block_shape_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < block_shape_dims[0]; ++i) { + op->block_shape.push_back(block_shape_buffer[i]); + } + + // Handling crops. + const auto& crops_array = *model->arrays[op->inputs[2]]; + if (!crops_array.has_shape()) return false; + const std::vector<int>& crops_dims = crops_array.shape().dims(); + CHECK_EQ(crops_dims.size(), 2); + std::vector<int> crops_buffer = + crops_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < crops_dims[0]; ++i) { + op->before_crops.push_back(crops_buffer[i * 2]); + op->after_crops.push_back(crops_buffer[i * 2 + 1]); + } + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d155d2bb5c..7305f858da 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -1261,6 +1261,10 @@ struct SpaceToBatchNDOperator : Operator { // TensorFlow equivalent: BatchToSpaceND struct BatchToSpaceNDOperator : Operator { BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {} + + std::vector<int> block_shape; + std::vector<int> before_crops; + std::vector<int> after_crops; }; // Mean operator. diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 7a68c6dbc9..ede6df88ab 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -130,6 +130,37 @@ class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions, } }; +class BatchToSpaceND + : public BuiltinOperator<BatchToSpaceNDOperator, + ::tflite::BatchToSpaceNDOptions, + ::tflite::BuiltinOptions_BatchToSpaceNDOptions> { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto block_shape = builder->CreateVector(op.block_shape); + auto before_crops = builder->CreateVector(op.before_crops); + auto after_crops = builder->CreateVector(op.after_crops); + return ::tflite::CreateBatchToSpaceNDOptions(*builder, block_shape, + before_crops, after_crops); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->block_shape.insert(op->block_shape.end(), + options.block_shape()->begin(), + options.block_shape()->end()); + op->before_crops.insert(op->before_crops.end(), + options.before_crops()->begin(), + options.before_crops()->end()); + op->after_crops.insert(op->after_crops.end(), + options.after_crops()->begin(), + options.after_crops()->end()); + } +}; + class Cast : public CustomOperator<CastOperator> { public: using CustomOperator::CustomOperator; @@ -571,6 +602,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool)); + ops.emplace_back( + new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND, + OperatorType::kBatchToSpaceND)); ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation)); ops.emplace_back( diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index caecbd0325..735eea4ddc 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -119,6 +119,19 @@ TEST_F(OperatorTest, BuiltinAdd) { output_toco_op->fused_activation_function); } +TEST_F(OperatorTest, BuiltinBatchToSpaceND) { + BatchToSpaceNDOperator op; + op.block_shape = {2, 2}; + op.before_crops = {1, 2}; + op.after_crops = {3, 4}; + + auto output_toco_op = SerializeAndDeserialize( + GetOperator("BATCH_TO_SPACE_ND", OperatorType::kBatchToSpaceND), op); + EXPECT_EQ(op.block_shape, output_toco_op->block_shape); + EXPECT_EQ(op.before_crops, output_toco_op->before_crops); + EXPECT_EQ(op.after_crops, output_toco_op->after_crops); +} + TEST_F(OperatorTest, CustomCast) { CastOperator op; op.src_data_type = ArrayDataType::kFloat; diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 7e50c2207f..d6652b7a41 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -78,6 +78,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new IdentifyRelu1); transformations->Add(new RemoveTrivialBinaryOperator); transformations->Add(new ReadFakeQuantMinMax); + transformations->Add(new ResolveBatchToSpaceNDAttributes); transformations->Add(new ResolvePadAttributes); transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); |