aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index ec9cf38b83..89d57e4599 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -431,7 +431,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params =
+ reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
@@ -482,6 +484,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ // Copy out the LSTM specific params so they can be passed in the function.
+ TfLiteLSTMParams lstm_params;
+ lstm_params.activation = params->activation;
+ lstm_params.cell_clip = params->cell_clip;
+ lstm_params.proj_clip = params->proj_clip;
+
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
return lstm_eval::EvalFloat(
@@ -496,7 +504,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
- projection_bias, params, /*forward_sequence=*/true,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
output);
}
@@ -523,7 +531,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
- projection_bias, params, /*forward_sequence=*/true,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,