diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index ec9cf38b83..89d57e4599 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -431,7 +431,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); + const auto* params = + reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>( + node->builtin_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_to_input_weights = @@ -482,6 +484,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Copy out the LSTM specific params so they can be passed in the function. + TfLiteLSTMParams lstm_params; + lstm_params.activation = params->activation; + lstm_params.cell_clip = params->cell_clip; + lstm_params.proj_clip = params->proj_clip; + switch (input_to_output_weights->type) { case kTfLiteFloat32: { return lstm_eval::EvalFloat( @@ -496,7 +504,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, - projection_bias, params, /*forward_sequence=*/true, + projection_bias, &lstm_params, /*forward_sequence=*/true, /*output_offset=*/0, scratch_buffer, activation_state, cell_state, output); } @@ -523,7 +531,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { /*aux_input_to_cell_weights=*/nullptr, /*aux_input_to_output_weights=*/nullptr, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, - projection_bias, params, /*forward_sequence=*/true, + projection_bias, &lstm_params, /*forward_sequence=*/true, /*output_offset=*/0, scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, |