diff options
author | 2018-02-05 18:45:13 -0800 | |
---|---|---|
committer | 2018-02-05 18:48:59 -0800 | |
commit | 1a0b637df8d082301118dd0f85ec63704f862aeb (patch) | |
tree | 9638dd8ae3174a20c456dd7d1569eddeea7330ba /tensorflow/contrib | |
parent | 09138338ed81b56aeaff88b9ee5a11c3f6d77b70 (diff) |
Export align_corners to TF Lite
PiperOrigin-RevId: 184622482
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/lite/builtin_op_data.h | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/resize_bilinear.cc | 12 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/schema/schema.fbs | 3 | ||||
-rwxr-xr-x | tensorflow/contrib/lite/schema/schema_generated.h | 40 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 22 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator_test.cc | 10 |
8 files changed, 65 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index a1037a525c..5dbeadd165 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -151,6 +151,7 @@ typedef struct { } TfLiteLSTMParams; typedef struct { + bool align_corners; } TfLiteResizeBilinearParams; typedef struct { diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index 4a2101f246..c5d60cae3a 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -75,6 +75,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { template <KernelType kernel_type> TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* size = GetInput(context, node, kSizeTensor); @@ -86,10 +89,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \ - GetTensorData<int32>(size), GetTensorDims(size), \ - GetTensorData<float>(output), GetTensorDims(output)) +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \ + GetTensorData<int32>(size), GetTensorDims(size), \ + GetTensorData<float>(output), GetTensorDims(output), \ + params->align_corners) if (kernel_type == kReference) { TF_LITE_RESIZE_BILINEAR(reference_ops); diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index b36bfcef84..14b6709964 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -469,6 +469,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, auto* params = MallocPOD<TfLiteResizeBilinearParams>(); if (auto* schema_params = op->builtin_options_as_ResizeBilinearOptions()) { + params->align_corners = schema_params->align_corners(); } builtin_data = reinterpret_cast<void*>(params); break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index c0b220e872..36cc2724eb 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -273,6 +273,9 @@ table LSTMOptions { } table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; } // A call operation options diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 29f3a17be7..e2ac0b9d1e 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -2626,28 +2626,36 @@ flatbuffers::Offset<LSTMOptions> CreateLSTMOptions( struct ResizeBilinearOptionsT : public flatbuffers::NativeTable { typedef ResizeBilinearOptions TableType; - ResizeBilinearOptionsT() {} + bool align_corners; + ResizeBilinearOptionsT() + : align_corners(false) { + } }; -struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ResizeBilinearOptionsT NativeTableType; + enum { + VT_ALIGN_CORNERS = 8 + }; + bool align_corners() const { + return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && verifier.EndTable(); + return VerifyTableStart(verifier) && + VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS) && + verifier.EndTable(); } - ResizeBilinearOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ResizeBilinearOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset<ResizeBilinearOptions> Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<ResizeBilinearOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ResizeBilinearOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; + void add_align_corners(bool align_corners) { + fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast<uint8_t>(align_corners), 0); + } explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2661,14 +2669,14 @@ struct ResizeBilinearOptionsBuilder { }; inline flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb) { + flatbuffers::FlatBufferBuilder &_fbb, + bool align_corners = false) { ResizeBilinearOptionsBuilder builder_(_fbb); + builder_.add_align_corners(align_corners); return builder_.Finish(); } -flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct CallOptionsT : public flatbuffers::NativeTable { typedef CallOptions TableType; diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index e8b425a592..5ea3e21f6a 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -89,9 +89,6 @@ std::map<string, string> kBrokenTests = { // ResizeBilinear looks completely incompatible with Tensorflow {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"}, - {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"}, // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 461494fd99..04aaedd59d 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -540,6 +540,24 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions, } }; +class ResizeBilinear + : public BuiltinOperator<ResizeBilinearOperator, + ::tflite::ResizeBilinearOptions, + ::tflite::BuiltinOptions_ResizeBilinearOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->align_corners = options.align_corners(); + } +}; + class Squeeze : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions, ::tflite::BuiltinOptions_SqueezeOptions> { @@ -755,6 +773,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { OperatorType::kTranspose)); ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, + OperatorType::kResizeBilinear)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, @@ -787,8 +807,6 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1)); ops.emplace_back( new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6)); - ops.emplace_back(new SimpleOperator<ResizeBilinearOperator>( - "RESIZE_BILINEAR", OperatorType::kResizeBilinear)); ops.emplace_back(new SimpleOperator<LogisticOperator>( "LOGISTIC", OperatorType::kLogistic)); ops.emplace_back( diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 6daa296282..796534be53 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -104,8 +104,6 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu); CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1); CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6); - CheckSimpleOperator<ResizeBilinearOperator>("RESIZE_BILINEAR", - OperatorType::kResizeBilinear); CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic); CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh); } @@ -331,6 +329,14 @@ TEST_F(OperatorTest, BuiltinMul) { output_toco_op->fused_activation_function); } +TEST_F(OperatorTest, ResizeBilinear) { + ResizeBilinearOperator op; + op.align_corners = true; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op); + EXPECT_EQ(op.align_corners, output_toco_op->align_corners); +} + TEST_F(OperatorTest, Svdf) { SvdfOperator op; op.fused_activation_function = FusedActivationFunctionType::kRelu; |