diff options
author | Jared Duke <jdduke@google.com> | 2018-07-26 10:53:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-26 10:56:43 -0700 |
commit | 6e658c0a5ca77677a954a34fb98f241c592c970d (patch) | |
tree | b645103887539af5232b3f70d80a2eb9b77ed63a /tensorflow/contrib/lite/schema | |
parent | 0a3155f7fbf56df5e81c7cbf35afd45173359635 (diff) |
Add one_hot op support to TFLite
PiperOrigin-RevId: 206185190
Diffstat (limited to 'tensorflow/contrib/lite/schema')
-rw-r--r-- | tensorflow/contrib/lite/schema/schema.fbs | 6 | ||||
-rwxr-xr-x | tensorflow/contrib/lite/schema/schema_generated.h | 141 |
2 files changed, 141 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index a285bf9919..8ed98ddaf4 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -166,6 +166,7 @@ enum BuiltinOperator : byte { REDUCE_MAX = 82, PACK = 83, LOGICAL_OR = 84, + ONE_HOT = 85, } // Options for the builtin operators. @@ -230,6 +231,7 @@ union BuiltinOptions { FakeQuantOptions, PackOptions, LogicalOrOptions, + OneHotOptions, } enum Padding : byte { SAME, VALID } @@ -549,6 +551,10 @@ table PackOptions { table LogicalOrOptions { } +table OneHotOptions { + axis:int; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 8c1d6d6a36..4402f89b85 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -211,6 +211,9 @@ struct PackOptionsT; struct LogicalOrOptions; struct LogicalOrOptionsT; +struct OneHotOptions; +struct OneHotOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -361,11 +364,12 @@ enum BuiltinOperator { BuiltinOperator_REDUCE_MAX = 82, BuiltinOperator_PACK = 83, BuiltinOperator_LOGICAL_OR = 84, + BuiltinOperator_ONE_HOT = 85, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_LOGICAL_OR + BuiltinOperator_MAX = BuiltinOperator_ONE_HOT }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -450,7 +454,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] { BuiltinOperator_REDUCE_PROD, BuiltinOperator_REDUCE_MAX, BuiltinOperator_PACK, - BuiltinOperator_LOGICAL_OR + BuiltinOperator_LOGICAL_OR, + BuiltinOperator_ONE_HOT }; return values; } @@ -542,6 +547,7 @@ inline const char **EnumNamesBuiltinOperator() { "REDUCE_MAX", "PACK", "LOGICAL_OR", + "ONE_HOT", nullptr }; return names; @@ -614,11 +620,12 @@ enum BuiltinOptions { BuiltinOptions_FakeQuantOptions = 58, BuiltinOptions_PackOptions = 59, BuiltinOptions_LogicalOrOptions = 60, + BuiltinOptions_OneHotOptions = 61, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_LogicalOrOptions + BuiltinOptions_MAX = BuiltinOptions_OneHotOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -680,7 +687,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] { BuiltinOptions_ArgMinOptions, BuiltinOptions_FakeQuantOptions, BuiltinOptions_PackOptions, - BuiltinOptions_LogicalOrOptions + BuiltinOptions_LogicalOrOptions, + BuiltinOptions_OneHotOptions }; return values; } @@ -748,6 +756,7 @@ inline const char **EnumNamesBuiltinOptions() { "FakeQuantOptions", "PackOptions", "LogicalOrOptions", + "OneHotOptions", nullptr }; return names; @@ -1002,6 +1011,10 @@ template<> struct BuiltinOptionsTraits<LogicalOrOptions> { static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; }; +template<> struct BuiltinOptionsTraits<OneHotOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1513,6 +1526,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_LogicalOrOptions ? reinterpret_cast<const LogicalOrOptionsT *>(value) : nullptr; } + OneHotOptionsT *AsOneHotOptions() { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast<OneHotOptionsT *>(value) : nullptr; + } + const OneHotOptionsT *AsOneHotOptions() const { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast<const OneHotOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -5452,6 +5473,60 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions( flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct OneHotOptionsT : public flatbuffers::NativeTable { + typedef OneHotOptions TableType; + int32_t axis; + OneHotOptionsT() + : axis(0) { + } +}; + +struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OneHotOptionsT NativeTableType; + enum { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField<int32_t>(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int32_t>(verifier, VT_AXIS) && + verifier.EndTable(); + } + OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<OneHotOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OneHotOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0); + } + explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &); + flatbuffers::Offset<OneHotOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<OneHotOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { + OneHotOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -5765,6 +5840,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const { return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast<const LogicalOrOptions *>(builtin_options()) : nullptr; } + const OneHotOptions *builtin_options_as_OneHotOptions() const { + return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast<const OneHotOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -6036,6 +6114,10 @@ template<> inline const LogicalOrOptions *Operator::builtin_options_as<LogicalOr return builtin_options_as_LogicalOrOptions(); } +template<> inline const OneHotOptions *Operator::builtin_options_as<OneHotOptions>() const { + return builtin_options_as_OneHotOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -8151,6 +8233,32 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers: _fbb); } +inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OneHotOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; }; +} + +inline flatbuffers::Offset<OneHotOptions> OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOneHotOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + return tflite::CreateOneHotOptions( + _fbb, + _axis); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -8580,6 +8688,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast<const OneHotOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -8838,6 +8950,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast<const OneHotOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -9084,6 +9200,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const LogicalOrOptionsT *>(value); return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast<const OneHotOptionsT *>(value); + return CreateOneHotOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -9330,6 +9450,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new LogicalOrOptionsT(*reinterpret_cast<LogicalOrOptionsT *>(u.value)); break; } + case BuiltinOptions_OneHotOptions: { + value = new OneHotOptionsT(*reinterpret_cast<OneHotOptionsT *>(u.value)); + break; + } default: break; } @@ -9637,6 +9761,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast<OneHotOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; |