author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-01-19 12:08:47 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-19 12:13:53 -0800
commit    83b751621439cc2b8a85450972414cf2f92a58cf (patch)
tree      860f7f40a104388113121c79f7c6cb51cf7b1198
parent    39ae44e9822ed76639bae3ccf800b36039d1da55 (diff)
Add support for time_major shape format to the sequential RNN Op in TF Lite.
When time_major is true, the inputs and outputs use the [max_time, batch_size, depth] shape format; when false (the default), they use [batch_size, max_time, depth].

PiperOrigin-RevId: 182569507
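
For reference (a minimal sketch, not part of this change), the flat-array index of element (batch b, timestep s, feature i) under each layout is:

// Sketch only: contrasts the two layouts named in the commit message.
// batch-major: [batch_size, max_time, depth]; time-major: [max_time, batch_size, depth].
#include <cstddef>

inline size_t ElementIndex(bool time_major, int b, int s, int i,
                           int batch_size, int max_time, int depth) {
  return time_major
             ? (static_cast<size_t>(s) * batch_size + b) * depth + i  // [max_time, batch_size, depth]
             : (static_cast<size_t>(b) * max_time + s) * depth + i;   // [batch_size, max_time, depth]
}

This matches the pointer arithmetic the kernel's Eval uses in each mode below.
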
-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h                            |   5
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc       | 139
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc  | 102
-rw-r--r--  tensorflow/contrib/lite/model.cc                                     |  12
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs                            |   7
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h                    | 176
6 files changed, 377 insertions(+), 64 deletions(-)
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 0c333f9e8c..3b43a1fd5d 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -83,6 +83,11 @@ typedef struct {
TfLiteFusedActivation activation;
} TfLiteRNNParams;
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams;
typedef enum {
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 85e09049ee..f5f1ec2cf3 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -34,7 +34,7 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int KHiddenStateTensor = 0;
+constexpr int kHiddenStateTensor = 0;
constexpr int kOutputTensor = 1;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
@@ -51,8 +51,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check all the parameters of tensor match within themselves and match the
// input configuration.
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int num_units = input_weights->dims->data[0];
TF_LITE_ASSERT_EQ(input->dims->data[2], input_weights->dims->data[1]);
TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
@@ -60,7 +64,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ &context->tensors[node->outputs->data[kHiddenStateTensor]];
TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
// Resize state.
@@ -75,8 +79,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
- output_size_array->data[0] = batch_size;
- output_size_array->data[1] = max_time;
+ output_size_array->data[0] = (time_major) ? max_time : batch_size;
+ output_size_array->data[1] = (time_major) ? batch_size : max_time;
output_size_array->data[2] = num_units;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output,
output_size_array));
@@ -84,8 +88,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+namespace {
+void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int num_units, int input_weights_stride,
+ int recurrent_weights_stride, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ // Output = bias
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] = bias_ptr[o];
+ }
+
+ // Output += input * input_weights
+ for (int o = 0; o < num_units; o++) {
+ for (int i = 0; i < input_size; i++) {
+ output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
+ }
+ input_weights_ptr += input_weights_stride;
+ }
+
+ // Output += recurrent_weights * hidden_state
+ for (int o = 0; o < num_units; o++) {
+ for (int h = 0; h < num_units; h++) {
+ output_ptr_batch[o] +=
+ hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
+ }
+ recurrent_weights_ptr += recurrent_weights_stride;
+ }
+
+ // Output = activation(Output) and update hidden_state
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]);
+ hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ }
+}
+} // namespace
+
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
+ auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
TfLiteTensor* input_weights =
@@ -94,61 +134,60 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
&context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
TfLiteTensor* hidden_state =
- &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ &context->tensors[node->outputs->data[kHiddenStateTensor]];
TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
// Initialize the pointer bias.
const float* bias_ptr = bias->data.f;
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[2];
const int input_weights_stride = input_weights->dims->data[1];
const int recurrent_weights_stride = recurrent_weights->dims->data[1];
- // For each batch
- for (int b = 0; b < batch_size; b++) {
- // Initialize the pointer to hidden state.
- float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
- for (int s = 0; s < max_time; s++) {
- // Initialize the pointer to input and output.
- const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
- float* output_ptr_batch =
- output->data.f + b * num_units * max_time + s * num_units;
-
- // Initialize input_weights and recurrent_weights.
- const float* input_weights_ptr = input_weights->data.f;
- const float* recurrent_weights_ptr = recurrent_weights->data.f;
-
- // Output = bias
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = bias_ptr[o];
- }
+ // Initialize input_weights and recurrent_weights.
+ const float* input_weights_ptr = input_weights->data.f;
+ const float* recurrent_weights_ptr = recurrent_weights->data.f;
- // Output += input * input_weights
- for (int o = 0; o < num_units; o++) {
- for (int i = 0; i < input_size; i++) {
- output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
- }
- input_weights_ptr += input_weights_stride;
- }
-
- // Output += recurrent_weights * hidden_state
- for (int o = 0; o < num_units; o++) {
- for (int h = 0; h < num_units; h++) {
- output_ptr_batch[o] +=
- hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
- }
- recurrent_weights_ptr += recurrent_weights_stride;
+ if (time_major) {
+ // Unroll the sequence
+ for (int s = 0; s < max_time; s++) {
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size + b * input_size;
+ float* output_ptr_batch =
+ output->data.f + s * num_units * batch_size + b * num_units;
+
+ RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
+ bias_ptr, input_size, num_units, input_weights_stride,
+ recurrent_weights_stride, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
}
-
- // Output = activation(Output) and update hidden_state
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] =
- (ActivationFunctor(params->activation))(output_ptr_batch[o]);
- hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ }
+ } else {
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+ for (int s = 0; s < max_time; s++) {
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ float* output_ptr_batch =
+ output->data.f + b * num_units * max_time + s * num_units;
+
+ RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
+ bias_ptr, input_size, num_units, input_weights_stride,
+ recurrent_weights_stride, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
}
}
}
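
Aside (illustration only, not part of the diff): because the hidden state is indexed only by batch, swapping the (batch, time) loop nest above cannot change the results; only the storage order of the outputs differs. A toy standalone check, with hypothetical scalar weights and tanh standing in for the fused activation:

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const int B = 2, T = 3;  // batches, timesteps
  const float x[B][T] = {{1, 2, 3}, {4, 5, 6}};
  const float wx = 0.5f, wh = 0.25f, bias = 0.1f;
  float h1[B] = {0, 0}, h2[B] = {0, 0};
  std::vector<float> out1, out2;
  // Batch-major loop order: batch outer, timestep inner.
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < T; ++s)
      out1.push_back(h1[b] = std::tanh(wx * x[b][s] + wh * h1[b] + bias));
  // Time-major loop order: timestep outer, batch inner.
  for (int s = 0; s < T; ++s)
    for (int b = 0; b < B; ++b)
      out2.push_back(h2[b] = std::tanh(wx * x[b][s] + wh * h2[b] + bias));
  // Same values, just stored in the layout matching each loop order.
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < T; ++s)
      assert(out1[b * T + s] == out2[s * B + b]);
  return 0;
}
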
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index a1c1eda160..82c680ec3d 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Unit test for TFLite RNN op.
+// Unit test for TFLite Sequential RNN op.
#include <vector>
#include <iomanip>
@@ -125,7 +125,8 @@ static float rnn_golden_output[] = {
class UnidirectionalRNNOpModel : public SingleOpModel {
public:
- UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size)
+ UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size,
+ bool time_major)
: batches_(batches),
sequence_len_(sequence_len),
units_(units),
@@ -136,13 +137,22 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(
- BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_RNNOptions,
- CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, sequence_len_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
+ BuiltinOptions_SequenceRNNOptions,
+ CreateSequenceRNNOptions(builder_, time_major,
+ ActivationFunctionType_RELU)
+ .Union());
+ if (time_major) {
+ BuildInterpreter({{sequence_len_, batches_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ } else {
+ BuildInterpreter({{batches_, sequence_len_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ }
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -195,7 +205,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
// TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell.
TEST(FullyConnectedOpTest, BlackBoxTest) {
- UnidirectionalRNNOpModel rnn(2, 16, 16, 8);
+ UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8, /*time_major=*/false);
rnn.SetWeights(
{0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
@@ -260,6 +271,77 @@ TEST(FullyConnectedOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
+TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) {
+ UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8, /*time_major=*/true);
+ rnn.SetWeights(
+ {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818});
+
+ rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
+ -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
+ 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
+ -0.37609905});
+
+ rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1});
+
+ rnn.ResetHiddenState();
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ // The two batches are identical.
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
+ rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+ }
+
+ rnn.Invoke();
+
+ std::vector<float> expected;
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* golden_batch_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_batch_end = golden_batch_start + rnn.num_units();
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ }
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+}
+
} // namespace
} // namespace tflite
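
Aside (illustration only, not part of the test): the time-major test above writes the two identical batch copies at offsets 2*i and 2*i+1 because, in [max_time, batch_size, depth] order, the batch entries for one timestep sit adjacent in memory. A hypothetical helper for converting an existing batch-major buffer would look like:

#include <vector>

// Transpose [batches, max_time, depth] -> [max_time, batches, depth].
std::vector<float> ToTimeMajor(const std::vector<float>& batch_major,
                               int batches, int max_time, int depth) {
  std::vector<float> time_major(batch_major.size());
  for (int b = 0; b < batches; ++b)
    for (int t = 0; t < max_time; ++t)
      for (int d = 0; d < depth; ++d)
        time_major[(t * batches + b) * depth + d] =
            batch_major[(b * max_time + t) * depth + d];
  return time_major;
}
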
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 86e613736d..4b0c853f77 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -339,7 +339,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
+ TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ if (auto* sequence_rnn_params =
+ op->builtin_options_as_SequenceRNNOptions()) {
+ params->activation =
+ parse_activation(sequence_rnn_params->fused_activation_function());
+ params->time_major = sequence_rnn_params->time_major();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_RNN: {
TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index f5251031b3..260a87c93b 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -151,6 +151,7 @@ union BuiltinOptions {
SubOptions,
DivOptions,
SqueezeOptions,
+ SequenceRNNOptions,
}
enum Padding : byte { SAME, VALID }
@@ -214,6 +215,12 @@ table RNNOptions {
fused_activation_function:ActivationFunctionType;
}
+// An implementation of TensorFlow dynamic_rnn with RNNCell.
+table SequenceRNNOptions {
+ time_major:bool;
+ fused_activation_function:ActivationFunctionType;
+}
+
// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
table FullyConnectedOptions {
fused_activation_function:ActivationFunctionType;
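
Aside (a sketch using the generated API from schema_generated.h below, mirroring the unit test above): serializing the new options amounts to:

#include "tensorflow/contrib/lite/schema/schema_generated.h"

void BuildSequenceRNNOptionsSketch() {
  flatbuffers::FlatBufferBuilder builder;
  // time_major = true selects the [max_time, batch_size, depth] layout.
  auto options = tflite::CreateSequenceRNNOptions(
      builder, /*time_major=*/true, tflite::ActivationFunctionType_RELU);
  // An Operator stores options.Union() in its builtin_options field,
  // tagged with BuiltinOptions_SequenceRNNOptions.
  (void)options;
}
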
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index a2ec8e40e9..fd98be8f70 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -48,6 +48,9 @@ struct SVDFOptionsT;
struct RNNOptions;
struct RNNOptionsT;
+struct SequenceRNNOptions;
+struct SequenceRNNOptionsT;
+
struct FullyConnectedOptions;
struct FullyConnectedOptionsT;
@@ -339,11 +342,12 @@ enum BuiltinOptions {
BuiltinOptions_SubOptions = 28,
BuiltinOptions_DivOptions = 29,
BuiltinOptions_SqueezeOptions = 30,
+ BuiltinOptions_SequenceRNNOptions = 31,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SqueezeOptions
+ BuiltinOptions_MAX = BuiltinOptions_SequenceRNNOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[32] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -375,7 +379,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[31] {
BuiltinOptions_MeanOptions,
BuiltinOptions_SubOptions,
BuiltinOptions_DivOptions,
- BuiltinOptions_SqueezeOptions};
+ BuiltinOptions_SqueezeOptions,
+ BuiltinOptions_SequenceRNNOptions};
return values;
}
@@ -411,6 +416,7 @@ inline const char **EnumNamesBuiltinOptions() {
"SubOptions",
"DivOptions",
"SqueezeOptions",
+ "SequenceRNNOptions",
nullptr};
return names;
}
@@ -579,6 +585,11 @@ struct BuiltinOptionsTraits<SqueezeOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions;
};
+template <>
+struct BuiltinOptionsTraits<SequenceRNNOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -926,6 +937,16 @@ struct BuiltinOptionsUnion {
? reinterpret_cast<const SqueezeOptionsT *>(value)
: nullptr;
}
+ SequenceRNNOptionsT *AsSequenceRNNOptions() {
+ return type == BuiltinOptions_SequenceRNNOptions
+ ? reinterpret_cast<SequenceRNNOptionsT *>(value)
+ : nullptr;
+ }
+ const SequenceRNNOptionsT *AsSequenceRNNOptions() const {
+ return type == BuiltinOptions_SequenceRNNOptions
+ ? reinterpret_cast<const SequenceRNNOptionsT *>(value)
+ : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
@@ -1886,6 +1907,77 @@ flatbuffers::Offset<RNNOptions> CreateRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SequenceRNNOptionsT : public flatbuffers::NativeTable {
+ typedef SequenceRNNOptions TableType;
+ bool time_major;
+ ActivationFunctionType fused_activation_function;
+ SequenceRNNOptionsT()
+ : time_major(false),
+ fused_activation_function(ActivationFunctionType_NONE) {}
+};
+
+struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SequenceRNNOptionsT NativeTableType;
+ enum { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 };
+ bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; }
+ ActivationFunctionType fused_activation_function() const {
+ return static_cast<ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ verifier.EndTable();
+ }
+ SequenceRNNOptionsT *UnPack(
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(
+ SequenceRNNOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SequenceRNNOptions> Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SequenceRNNOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_time_major(bool time_major) {
+ fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_TIME_MAJOR,
+ static_cast<uint8_t>(time_major), 0);
+ }
+ void add_fused_activation_function(
+ ActivationFunctionType fused_activation_function) {
+ fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SequenceRNNOptionsBuilder &operator=(const SequenceRNNOptionsBuilder &);
+ flatbuffers::Offset<SequenceRNNOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SequenceRNNOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false,
+ ActivationFunctionType fused_activation_function =
+ ActivationFunctionType_NONE) {
+ SequenceRNNOptionsBuilder builder_(_fbb);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_time_major(time_major);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct FullyConnectedOptionsT : public flatbuffers::NativeTable {
typedef FullyConnectedOptions TableType;
ActivationFunctionType fused_activation_function;
@@ -3716,6 +3808,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
? static_cast<const SqueezeOptions *>(builtin_options())
: nullptr;
}
+ const SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const {
+ return builtin_options_type() == BuiltinOptions_SequenceRNNOptions
+ ? static_cast<const SequenceRNNOptions *>(builtin_options())
+ : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -3917,6 +4014,12 @@ inline const SqueezeOptions *Operator::builtin_options_as<SqueezeOptions>()
return builtin_options_as_SqueezeOptions();
}
+template <>
+inline const SequenceRNNOptions *
+Operator::builtin_options_as<SequenceRNNOptions>() const {
+ return builtin_options_as_SequenceRNNOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -4841,6 +4944,51 @@ inline flatbuffers::Offset<RNNOptions> CreateRNNOptions(
return tflite::CreateRNNOptions(_fbb, _fused_activation_function);
}
+inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(
+ const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SequenceRNNOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SequenceRNNOptions::UnPackTo(
+ SequenceRNNOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ {
+ auto _e = time_major();
+ _o->time_major = _e;
+ }
+ {
+ auto _e = fused_activation_function();
+ _o->fused_activation_function = _e;
+ }
+}
+
+inline flatbuffers::Offset<SequenceRNNOptions> SequenceRNNOptions::Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSequenceRNNOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SequenceRNNOptions> CreateSequenceRNNOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs {
+ flatbuffers::FlatBufferBuilder *__fbb;
+ const SequenceRNNOptionsT *__o;
+ const flatbuffers::rehasher_function_t *__rehasher;
+ } _va = {&_fbb, _o, _rehasher};
+ (void)_va;
+ auto _time_major = _o->time_major;
+ auto _fused_activation_function = _o->fused_activation_function;
+ return tflite::CreateSequenceRNNOptions(_fbb, _time_major,
+ _fused_activation_function);
+}
+
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(
const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new FullyConnectedOptionsT();
@@ -6397,6 +6545,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
auto ptr = reinterpret_cast<const SqueezeOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default:
return false;
}
@@ -6541,6 +6693,10 @@ inline void *BuiltinOptionsUnion::UnPack(
auto ptr = reinterpret_cast<const SqueezeOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const SequenceRNNOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default:
return nullptr;
}
@@ -6672,6 +6828,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
auto ptr = reinterpret_cast<const SqueezeOptionsT *>(value);
return CreateSqueezeOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const SequenceRNNOptionsT *>(value);
+ return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
+ }
default:
return 0;
}
@@ -6817,6 +6977,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
new SqueezeOptionsT(*reinterpret_cast<SqueezeOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ value = new SequenceRNNOptionsT(
+ *reinterpret_cast<SequenceRNNOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -6974,6 +7139,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_SequenceRNNOptions: {
+ auto ptr = reinterpret_cast<SequenceRNNOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default:
break;
}