diff options
author | 2018-10-06 21:00:57 -0700 | |
---|---|---|
committer | 2018-10-06 21:06:15 -0700 | |
commit | 7fa6a6b42bc9d562e2b1cc765ca78d281b51f734 (patch) | |
tree | 19398616e50223b0876e6e9a987cc9908d404592 /tensorflow/contrib/lite/kernels | |
parent | e93a18954689b6d522560f5273f6d3320d545b2e (diff) |
Add SequenceLSTMOptions to schema to decouple the sequential Op from the LSTM.
PiperOrigin-RevId: 216066634
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 14 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc | 11 |
2 files changed, 17 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index ec9cf38b83..89d57e4599 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -431,7 +431,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); + const auto* params = + reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>( + node->builtin_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_to_input_weights = @@ -482,6 +484,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Copy out the LSTM specific params so they can be passed in the function. + TfLiteLSTMParams lstm_params; + lstm_params.activation = params->activation; + lstm_params.cell_clip = params->cell_clip; + lstm_params.proj_clip = params->proj_clip; + switch (input_to_output_weights->type) { case kTfLiteFloat32: { return lstm_eval::EvalFloat( @@ -496,7 +504,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, - projection_bias, params, /*forward_sequence=*/true, + projection_bias, &lstm_params, /*forward_sequence=*/true, /*output_offset=*/0, scratch_buffer, activation_state, cell_state, output); } @@ -523,7 +531,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, - projection_bias, params, /*forward_sequence=*/true, + projection_bias, &lstm_params, /*forward_sequence=*/true, /*output_offset=*/0, scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc index cd3aac0532..c97b0fdd61 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -110,11 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, - BuiltinOptions_LSTMOptions, - CreateLSTMOptions(builder_, ActivationFunctionType_TANH, - cell_clip, proj_clip) - .Union()); + SetBuiltinOp( + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_UnidirectionalSequenceLSTMOptions, + CreateUnidirectionalSequenceLSTMOptions( + builder_, ActivationFunctionType_TANH, cell_clip, proj_clip) + .Union()); BuildInterpreter(input_shapes); } |