aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/schema
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-06-01 16:27:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 16:30:28 -0700
commitb31498a054d55ce328a2820fd403af764c482500 (patch)
tree91b8513149a36ae042e2a1b51f9e284701bbdcec /tensorflow/contrib/lite/schema
parent73ec24e8b75ba4f73a06756502d8bf86b2a6828b (diff)
Support 5-inputs LSTM kernel in TFLite (float only).
PiperOrigin-RevId: 198943559
Diffstat (limited to 'tensorflow/contrib/lite/schema')
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs12
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h52
2 files changed, 60 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 7d76134e3d..7dbb36c864 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -315,11 +315,23 @@ table LocalResponseNormalizationOptions {
beta:float;
}
+enum LSTMKernelType : byte {
+ // Full LSTM kernel which supports peephole and projection.
+ FULL = 0,
+ // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
+ BASIC = 1,
+}
+
// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
table LSTMOptions {
+ // Parameters for LSTM version 1 or above.
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
proj_clip: float; // Optional, 0.0 means no clipping
+
+ // Parameters for LSTM version 2 or above.
+ // Basic kernel is only supported in version 2 or above.
+ kernel_type: LSTMKernelType = FULL;
}
table ResizeBilinearOptions {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 0a60fcd3d0..b1beb39b28 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -1428,6 +1428,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) {
return EnumNamesLSHProjectionType()[index];
}
+enum LSTMKernelType {
+ LSTMKernelType_FULL = 0,
+ LSTMKernelType_BASIC = 1,
+ LSTMKernelType_MIN = LSTMKernelType_FULL,
+ LSTMKernelType_MAX = LSTMKernelType_BASIC
+};
+
+inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
+ static LSTMKernelType values[] = {
+ LSTMKernelType_FULL,
+ LSTMKernelType_BASIC
+ };
+ return values;
+}
+
+inline const char **EnumNamesLSTMKernelType() {
+ static const char *names[] = {
+ "FULL",
+ "BASIC",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameLSTMKernelType(LSTMKernelType e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesLSTMKernelType()[index];
+}
+
enum CombinerType {
CombinerType_SUM = 0,
CombinerType_MEAN = 1,
@@ -2865,10 +2894,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable {
ActivationFunctionType fused_activation_function;
float cell_clip;
float proj_clip;
+ LSTMKernelType kernel_type;
LSTMOptionsT()
: fused_activation_function(ActivationFunctionType_NONE),
cell_clip(0.0f),
- proj_clip(0.0f) {
+ proj_clip(0.0f),
+ kernel_type(LSTMKernelType_FULL) {
}
};
@@ -2877,7 +2908,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
enum {
VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6,
- VT_PROJ_CLIP = 8
+ VT_PROJ_CLIP = 8,
+ VT_KERNEL_TYPE = 10
};
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@@ -2888,11 +2920,15 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
float proj_clip() const {
return GetField<float>(VT_PROJ_CLIP, 0.0f);
}
+ LSTMKernelType kernel_type() const {
+ return static_cast<LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+ VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
verifier.EndTable();
}
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2912,6 +2948,9 @@ struct LSTMOptionsBuilder {
void add_proj_clip(float proj_clip) {
fbb_.AddElement<float>(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
}
+ void add_kernel_type(LSTMKernelType kernel_type) {
+ fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
+ }
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2928,10 +2967,12 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
flatbuffers::FlatBufferBuilder &_fbb,
ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
float cell_clip = 0.0f,
- float proj_clip = 0.0f) {
+ float proj_clip = 0.0f,
+ LSTMKernelType kernel_type = LSTMKernelType_FULL) {
LSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip);
+ builder_.add_kernel_type(kernel_type);
builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish();
}
@@ -6226,6 +6267,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
{ auto _e = cell_clip(); _o->cell_clip = _e; };
{ auto _e = proj_clip(); _o->proj_clip = _e; };
+ { auto _e = kernel_type(); _o->kernel_type = _e; };
}
inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -6239,11 +6281,13 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
auto _fused_activation_function = _o->fused_activation_function;
auto _cell_clip = _o->cell_clip;
auto _proj_clip = _o->proj_clip;
+ auto _kernel_type = _o->kernel_type;
return tflite::CreateLSTMOptions(
_fbb,
_fused_activation_function,
_cell_clip,
- _proj_clip);
+ _proj_clip,
+ _kernel_type);
}
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {