diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-06-01 16:27:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-01 16:30:28 -0700 |
commit | b31498a054d55ce328a2820fd403af764c482500 (patch) | |
tree | 91b8513149a36ae042e2a1b51f9e284701bbdcec /tensorflow/contrib/lite/schema | |
parent | 73ec24e8b75ba4f73a06756502d8bf86b2a6828b (diff) |
Support 5-input LSTM kernel in TFLite (float only).
PiperOrigin-RevId: 198943559
Diffstat (limited to 'tensorflow/contrib/lite/schema')
-rw-r--r-- | tensorflow/contrib/lite/schema/schema.fbs | 12 | ||||
-rwxr-xr-x | tensorflow/contrib/lite/schema/schema_generated.h | 52 |
2 files changed, 60 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 7d76134e3d..7dbb36c864 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -315,11 +315,23 @@ table LocalResponseNormalizationOptions { beta:float; } +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell table LSTMOptions { + // Parameters for LSTM version 1 or above. fused_activation_function:ActivationFunctionType; cell_clip: float; // Optional, 0.0 means no clipping proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; } table ResizeBilinearOptions { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 0a60fcd3d0..b1beb39b28 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -1428,6 +1428,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { return EnumNamesLSHProjectionType()[index]; } +enum LSTMKernelType { + LSTMKernelType_FULL = 0, + LSTMKernelType_BASIC = 1, + LSTMKernelType_MIN = LSTMKernelType_FULL, + LSTMKernelType_MAX = LSTMKernelType_BASIC +}; + +inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] { + static LSTMKernelType values[] = { + LSTMKernelType_FULL, + LSTMKernelType_BASIC + }; + return values; +} + +inline const char **EnumNamesLSTMKernelType() { + static const char *names[] = { + "FULL", + "BASIC", + nullptr + }; + return names; +} + +inline const char *EnumNameLSTMKernelType(LSTMKernelType e) { + const size_t index = static_cast<int>(e); + return 
EnumNamesLSTMKernelType()[index]; +} + enum CombinerType { CombinerType_SUM = 0, CombinerType_MEAN = 1, @@ -2865,10 +2894,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; float cell_clip; float proj_clip; + LSTMKernelType kernel_type; LSTMOptionsT() : fused_activation_function(ActivationFunctionType_NONE), cell_clip(0.0f), - proj_clip(0.0f) { + proj_clip(0.0f), + kernel_type(LSTMKernelType_FULL) { } }; @@ -2877,7 +2908,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, - VT_PROJ_CLIP = 8 + VT_PROJ_CLIP = 8, + VT_KERNEL_TYPE = 10 }; ActivationFunctionType fused_activation_function() const { return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -2888,11 +2920,15 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { float proj_clip() const { return GetField<float>(VT_PROJ_CLIP, 0.0f); } + LSTMKernelType kernel_type() const { + return static_cast<LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField<float>(verifier, VT_CELL_CLIP) && VerifyField<float>(verifier, VT_PROJ_CLIP) && + VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) && verifier.EndTable(); } LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2912,6 +2948,9 @@ struct LSTMOptionsBuilder { void add_proj_clip(float proj_clip) { fbb_.AddElement<float>(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); } + void add_kernel_type(LSTMKernelType kernel_type) { + fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0); + } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2928,10 +2967,12 @@ inline 
flatbuffers::Offset<LSTMOptions> CreateLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, float cell_clip = 0.0f, - float proj_clip = 0.0f) { + float proj_clip = 0.0f, + LSTMKernelType kernel_type = LSTMKernelType_FULL) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_kernel_type(kernel_type); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -6226,6 +6267,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; { auto _e = cell_clip(); _o->cell_clip = _e; }; { auto _e = proj_clip(); _o->proj_clip = _e; }; + { auto _e = kernel_type(); _o->kernel_type = _e; }; } inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6239,11 +6281,13 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe auto _fused_activation_function = _o->fused_activation_function; auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; + auto _kernel_type = _o->kernel_type; return tflite::CreateLSTMOptions( _fbb, _fused_activation_function, _cell_clip, - _proj_clip); + _proj_clip, + _kernel_type); } inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { |