diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/kernel_utils.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/kernel_utils.cc | 39 |
1 files changed, 21 insertions, 18 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 360b472c45..b9dd40ddf9 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -203,9 +203,9 @@ void LstmStep( cell_to_input_weights_ptr, cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_batch); + projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0, + n_output, output_state_ptr, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); } void LstmStepWithAuxInput( @@ -227,8 +227,8 @@ void LstmStepWithAuxInput( const float* forget_gate_bias_ptr, const float* cell_bias_ptr, const float* output_gate_bias_ptr, const float* projection_weights_ptr, const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, + int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, + float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch) { // Since we have already checked that weights are all there or none, we can @@ -268,19 +268,20 @@ void LstmStepWithAuxInput( if (aux_input_ptr_batch != nullptr) { if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_input, aux_input_ptr_batch, - n_batch, input_gate_scratch, /*result_stride=*/1); + aux_input_to_input_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, input_gate_scratch, + /*result_stride=*/1); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_input, aux_input_ptr_batch, - n_batch, forget_gate_scratch, /*result_stride=*/1); + aux_input_to_forget_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_input, aux_input_ptr_batch, + aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch, n_batch, cell_scratch, /*result_stride=*/1); tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_input, aux_input_ptr_batch, - n_batch, output_gate_scratch, /*result_stride=*/1); + aux_input_to_output_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1); } // For each batch and cell: compute recurrent_weight * output_state. @@ -432,10 +433,11 @@ void LstmStep( cell_to_output_weights_ptr, cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, scaling_factors, product_scaling_factors, - recovered_cell_weights, quantized_input_ptr_batch, + projection_bias_ptr, params, n_batch, n_cell, n_input, + /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors, + product_scaling_factors, recovered_cell_weights, + quantized_input_ptr_batch, /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr_batch); @@ -476,8 +478,9 @@ void LstmStep( const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, float projection_weights_scale, const float* projection_bias_ptr, const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_output, float* input_gate_scratch, float* forget_gate_scratch, - float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + int n_aux_input, int n_output, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, + float* output_gate_scratch, float* scaling_factors, float* product_scaling_factors, float* recovered_cell_weights, int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch, |