aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/kernel_utils.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc39
1 files changed, 21 insertions, 18 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 360b472c45..b9dd40ddf9 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -203,9 +203,9 @@ void LstmStep(
cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, output_ptr_batch);
+ projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
+ n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
}
void LstmStepWithAuxInput(
@@ -227,8 +227,8 @@ void LstmStepWithAuxInput(
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
// Since we have already checked that weights are all there or none, we can
@@ -268,19 +268,20 @@ void LstmStepWithAuxInput(
if (aux_input_ptr_batch != nullptr) {
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
- n_batch, input_gate_scratch, /*result_stride=*/1);
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
- n_batch, forget_gate_scratch, /*result_stride=*/1);
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
n_batch, cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
- n_batch, output_gate_scratch, /*result_stride=*/1);
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
}
// For each batch and cell: compute recurrent_weight * output_state.
@@ -432,10 +433,11 @@ void LstmStep(
cell_to_output_weights_ptr, cell_to_output_weights_scale,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors, product_scaling_factors,
- recovered_cell_weights, quantized_input_ptr_batch,
+ projection_bias_ptr, params, n_batch, n_cell, n_input,
+ /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, scaling_factors,
+ product_scaling_factors, recovered_cell_weights,
+ quantized_input_ptr_batch,
/*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
output_ptr_batch);
@@ -476,8 +478,9 @@ void LstmStep(
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
float projection_weights_scale, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* scaling_factors,
float* product_scaling_factors, float* recovered_cell_weights,
int8_t* quantized_input_ptr_batch,
int8_t* quantized_aux_input_ptr_batch,