diff options
author | 2018-01-19 12:08:47 -0800 | |
---|---|---|
committer | 2018-01-19 12:13:53 -0800 | |
commit | 83b751621439cc2b8a85450972414cf2f92a58cf (patch) | |
tree | 860f7f40a104388113121c79f7c6cb51cf7b1198 /tensorflow | |
parent | 39ae44e9822ed76639bae3ccf800b36039d1da55 (diff) |
Add support for time_major shape format to the sequential RNN Op in TF Lite.
This option, if set, changes the shape format of the inputs and outputs to
[max_time, batch_size, depth]. If false, it uses [batch_size, max_time, depth].
By default, it is set to false.
PiperOrigin-RevId: 182569507
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/lite/builtin_op_data.h | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 139 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | 102 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 12 | ||||
-rw-r--r-- | tensorflow/contrib/lite/schema/schema.fbs | 7 | ||||
-rwxr-xr-x | tensorflow/contrib/lite/schema/schema_generated.h | 176 |
6 files changed, 377 insertions, 64 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 0c333f9e8c..3b43a1fd5d 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -83,6 +83,11 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteRNNParams; +typedef struct { + bool time_major; + TfLiteFusedActivation activation; +} TfLiteSequenceRNNParams; + typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; typedef enum { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 85e09049ee..f5f1ec2cf3 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -34,7 +34,7 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kRecurrentWeightsTensor = 2; constexpr int kBiasTensor = 3; -constexpr int KHiddenStateTensor = 0; +constexpr int kHiddenStateTensor = 0; constexpr int kOutputTensor = 1; TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -51,8 +51,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check all the parameters of tensor match within themselves and match the // input configuration. - const int batch_size = input->dims->data[0]; - const int max_time = input->dims->data[1]; + auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; const int num_units = input_weights->dims->data[0]; TF_LITE_ASSERT_EQ(input->dims->data[2], input_weights->dims->data[1]); TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); @@ -60,7 +64,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); TfLiteTensor* hidden_state = - &context->tensors[node->outputs->data[KHiddenStateTensor]]; + &context->tensors[node->outputs->data[kHiddenStateTensor]]; TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; // Resize state. @@ -75,8 +79,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Resize output. TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3); - output_size_array->data[0] = batch_size; - output_size_array->data[1] = max_time; + output_size_array->data[0] = (time_major) ? max_time : batch_size; + output_size_array->data[1] = (time_major) ? batch_size : max_time; output_size_array->data[2] = num_units; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); @@ -84,8 +88,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +namespace { +void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int input_weights_stride, + int recurrent_weights_stride, TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch) { + // Output = bias + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = bias_ptr[o]; + } + + // Output += input * input_weights + for (int o = 0; o < num_units; o++) { + for (int i = 0; i < input_size; i++) { + output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; + } + input_weights_ptr += input_weights_stride; + } + + // Output += recurrent_weights * hidden_state + for (int o = 0; o < num_units; o++) { + for (int h = 0; h < num_units; h++) { + output_ptr_batch[o] += + hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; + } + recurrent_weights_ptr += recurrent_weights_stride; + } + + // Output = activation(Output) and update hidden_state + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]); + hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } +} +} // namespace + TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data); + auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; TfLiteTensor* input_weights = @@ -94,61 +134,60 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; TfLiteTensor* hidden_state = - &context->tensors[node->outputs->data[KHiddenStateTensor]]; + &context->tensors[node->outputs->data[kHiddenStateTensor]]; TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; // Initialize the pointer bias. const float* bias_ptr = bias->data.f; - const int batch_size = input->dims->data[0]; - const int max_time = input->dims->data[1]; + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[2]; const int input_weights_stride = input_weights->dims->data[1]; const int recurrent_weights_stride = recurrent_weights->dims->data[1]; - // For each batch - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to hidden state. - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - for (int s = 0; s < max_time; s++) { - // Initialize the pointer to input and output. - const float* input_ptr_batch = - input->data.f + b * input_size * max_time + s * input_size; - float* output_ptr_batch = - output->data.f + b * num_units * max_time + s * num_units; - - // Initialize input_weights and recurrent_weights. - const float* input_weights_ptr = input_weights->data.f; - const float* recurrent_weights_ptr = recurrent_weights->data.f; - - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; + if (time_major) { + // Unroll the sequence + for (int s = 0; s < max_time; s++) { + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + s * input_size * batch_size + b * input_size; + float* output_ptr_batch = + output->data.f + s * num_units * batch_size + b * num_units; + + RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, + bias_ptr, input_size, num_units, input_weights_stride, + recurrent_weights_stride, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = - (ActivationFunctor(params->activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } + } else { + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + for (int s = 0; s < max_time; s++) { + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + output->data.f + b * num_units * max_time + s * num_units; + + RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, + bias_ptr, input_size, num_units, input_weights_stride, + recurrent_weights_stride, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } } } diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc index a1c1eda160..82c680ec3d 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Unit test for TFLite RNN op. +// Unit test for TFLite Sequential RNN op. #include <vector> #include <iomanip> @@ -125,7 +125,8 @@ static float rnn_golden_output[] = { class UnidirectionalRNNOpModel : public SingleOpModel { public: - UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size) + UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, + bool time_major) : batches_(batches), sequence_len_(sequence_len), units_(units), @@ -136,13 +137,22 @@ class UnidirectionalRNNOpModel : public SingleOpModel { bias_ = AddInput(TensorType_FLOAT32); hidden_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_RNNOptions, - CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); - BuildInterpreter({{batches_, sequence_len_, input_size_}, - {units_, input_size_}, - {units_, units_}, - {units_}}); + SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_SequenceRNNOptions, + CreateSequenceRNNOptions(builder_, time_major, + ActivationFunctionType_RELU) + .Union()); + if (time_major) { + BuildInterpreter({{sequence_len_, batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } else { + BuildInterpreter({{batches_, sequence_len_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } } void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); } @@ -195,7 +205,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel { // TODO(mirkov): add another test which directly compares to TF once TOCO // supports the conversion from dynamic_rnn with BasicRNNCell. TEST(FullyConnectedOpTest, BlackBoxTest) { - UnidirectionalRNNOpModel rnn(2, 16, 16, 8); + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/false); rnn.SetWeights( {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, @@ -260,6 +271,77 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } +TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/true); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector<float> expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_batch_start = rnn_golden_output + i * rnn.num_units(); + float* golden_batch_end = golden_batch_start + rnn.num_units(); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + } + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 86e613736d..4b0c853f77 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -339,7 +339,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>(); + if (auto* sequence_rnn_params = + op->builtin_options_as_SequenceRNNOptions()) { + params->activation = + parse_activation(sequence_rnn_params->fused_activation_function()); + params->time_major = sequence_rnn_params->time_major(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_RNN: { TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index f5251031b3..260a87c93b 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -151,6 +151,7 @@ union BuiltinOptions { SubOptions, DivOptions, SqueezeOptions, + SequenceRNNOptions, } enum Padding : byte { SAME, VALID } @@ -214,6 +215,12 @@ table RNNOptions { fused_activation_function:ActivationFunctionType; } +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; +} + // An implementation of TensorFlow fully_connected (a.k.a Dense) layer. table FullyConnectedOptions { fused_activation_function:ActivationFunctionType; diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index a2ec8e40e9..fd98be8f70 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -48,6 +48,9 @@ struct SVDFOptionsT; struct RNNOptions; struct RNNOptionsT; +struct SequenceRNNOptions; +struct SequenceRNNOptionsT; + struct FullyConnectedOptions; struct FullyConnectedOptionsT; @@ -339,11 +342,12 @@ enum BuiltinOptions { BuiltinOptions_SubOptions = 28, BuiltinOptions_DivOptions = 29, BuiltinOptions_SqueezeOptions = 30, + BuiltinOptions_SequenceRNNOptions = 31, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_SqueezeOptions + BuiltinOptions_MAX = BuiltinOptions_SequenceRNNOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -375,7 +379,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] { BuiltinOptions_MeanOptions, BuiltinOptions_SubOptions, BuiltinOptions_DivOptions, - BuiltinOptions_SqueezeOptions}; + BuiltinOptions_SqueezeOptions, + BuiltinOptions_SequenceRNNOptions}; return values; } @@ -411,6 +416,7 @@ inline const char **EnumNamesBuiltinOptions() { "SubOptions", "DivOptions", "SqueezeOptions", + "SequenceRNNOptions", nullptr}; return names; } @@ -579,6 +585,11 @@ struct BuiltinOptionsTraits<SqueezeOptions> { static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; }; +template <> +struct BuiltinOptionsTraits<SequenceRNNOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -926,6 +937,16 @@ struct BuiltinOptionsUnion { ? reinterpret_cast<const SqueezeOptionsT *>(value) : nullptr; } + SequenceRNNOptionsT *AsSequenceRNNOptions() { + return type == BuiltinOptions_SequenceRNNOptions + ? reinterpret_cast<SequenceRNNOptionsT *>(value) + : nullptr; + } + const SequenceRNNOptionsT *AsSequenceRNNOptions() const { + return type == BuiltinOptions_SequenceRNNOptions + ? reinterpret_cast<const SequenceRNNOptionsT *>(value) + : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, @@ -1886,6 +1907,77 @@ flatbuffers::Offset<RNNOptions> CreateRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct SequenceRNNOptionsT : public flatbuffers::NativeTable { + typedef SequenceRNNOptions TableType; + bool time_major; + ActivationFunctionType fused_activation_function; + SequenceRNNOptionsT() + : time_major(false), + fused_activation_function(ActivationFunctionType_NONE) {} +}; + +struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SequenceRNNOptionsT NativeTableType; + enum { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; + bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; } + ActivationFunctionType fused_activation_function() const { + return static_cast<ActivationFunctionType>( + GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + SequenceRNNOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + SequenceRNNOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<SequenceRNNOptions> Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SequenceRNNOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_time_major(bool time_major) { + fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_TIME_MAJOR, + static_cast<uint8_t>(time_major), 0); + } + void add_fused_activation_function( + ActivationFunctionType fused_activation_function) { + fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, + static_cast<int8_t>(fused_activation_function), 0); + } + explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SequenceRNNOptionsBuilder &operator=(const SequenceRNNOptionsBuilder &); + flatbuffers::Offset<SequenceRNNOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<SequenceRNNOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, + ActivationFunctionType fused_activation_function = + ActivationFunctionType_NONE) { + SequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_time_major(time_major); + return builder_.Finish(); +} + +flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct FullyConnectedOptionsT : public flatbuffers::NativeTable { typedef FullyConnectedOptions TableType; ActivationFunctionType fused_activation_function; @@ -3716,6 +3808,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { ? static_cast<const SqueezeOptions *>(builtin_options()) : nullptr; } + const SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const { + return builtin_options_type() == BuiltinOptions_SequenceRNNOptions + ? static_cast<const SequenceRNNOptions *>(builtin_options()) + : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -3917,6 +4014,12 @@ inline const SqueezeOptions *Operator::builtin_options_as<SqueezeOptions>() return builtin_options_as_SqueezeOptions(); } +template <> +inline const SequenceRNNOptions * +Operator::builtin_options_as<SequenceRNNOptions>() const { + return builtin_options_as_SequenceRNNOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -4841,6 +4944,51 @@ inline flatbuffers::Offset<RNNOptions> CreateRNNOptions( return tflite::CreateRNNOptions(_fbb, _fused_activation_function); } +inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SequenceRNNOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SequenceRNNOptions::UnPackTo( + SequenceRNNOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = time_major(); + _o->time_major = _e; + } + { + auto _e = fused_activation_function(); + _o->fused_activation_function = _e; + } +} + +inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSequenceRNNOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const SequenceRNNOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _time_major = _o->time_major; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateSequenceRNNOptions(_fbb, _time_major, + _fused_activation_function); +} + inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack( const flatbuffers::resolver_function_t *_resolver) const { auto _o = new FullyConnectedOptionsT(); @@ -6397,6 +6545,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast<const SqueezeOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } @@ -6541,6 +6693,10 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast<const SqueezeOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -6672,6 +6828,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast<const SqueezeOptionsT *>(value); return CreateSqueezeOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast<const SequenceRNNOptionsT *>(value); + return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -6817,6 +6977,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) new SqueezeOptionsT(*reinterpret_cast<SqueezeOptionsT *>(u.value)); break; } + case BuiltinOptions_SequenceRNNOptions: { + value = new SequenceRNNOptionsT( + *reinterpret_cast<SequenceRNNOptionsT *>(u.value)); + break; + } default: break; } @@ -6974,6 +7139,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast<SequenceRNNOptionsT *>(value); + delete ptr; + break; + } default: break; } |