author     2018-10-05 01:22:02 -0700
committer  2018-10-05 01:29:25 -0700
commit     3b94d75a9e10ef8ef33760d0ef6aad326e1353ba (patch)
tree       402934b406e63ccd9cff0faec8a83aba6d58abf3 /tensorflow/contrib/lite/kernels/internal
parent     57d31aa599c83014397a22bbb8f1a27a33b0ade3 (diff)
Merge the different LSTM EvalFloat/EvalHybrid calls into a single file.
PiperOrigin-RevId: 215870962
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.cc  598
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.h   184
2 files changed, 0 insertions(+), 782 deletions(-)
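Both files below only lose code: the float and hybrid (int8-quantized) LstmStep / LstmStepWithAuxInput kernels leave kernel_utils as part of consolidating the EvalFloat/EvalHybrid paths into one place. As a reference for what the removed float path computes, here is a minimal, self-contained C++ sketch of one LSTM step for a single batch element. It illustrates the math only; the names LstmWeights, Gate, and LstmStepSketch are invented for this sketch, it is not the optimized tensor_utils-based kernel being deleted, and it omits the CIFG, peephole, auxiliary-input, cell-clipping, and projection variations the real kernel handles.

#include <cmath>
#include <vector>

namespace {

float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float Tanh(float x) { return std::tanh(x); }

// gate[c] = act(b[c] + sum_i W[c][i] * x[i] + sum_o R[c][o] * h[o])
std::vector<float> Gate(const std::vector<float>& W, const std::vector<float>& R,
                        const std::vector<float>& b, const std::vector<float>& x,
                        const std::vector<float>& h, float (*act)(float)) {
  const int n_cell = static_cast<int>(b.size());
  const int n_input = static_cast<int>(x.size());
  const int n_output = static_cast<int>(h.size());
  std::vector<float> g(n_cell);
  for (int c = 0; c < n_cell; ++c) {
    float acc = b[c];
    for (int i = 0; i < n_input; ++i) acc += W[c * n_input + i] * x[i];
    for (int o = 0; o < n_output; ++o) acc += R[c * n_output + o] * h[o];
    g[c] = act(acc);
  }
  return g;
}

}  // namespace

// Hypothetical weight container: W* are [n_cell x n_input], R* are
// [n_cell x n_output], b* are [n_cell], all row-major.
struct LstmWeights {
  std::vector<float> Wi, Wf, Wc, Wo, Ri, Rf, Rc, Ro, bi, bf, bc, bo;
};

// One step, updating cell_state and output_state in place, much as the removed
// kernel does through cell_state_ptr and output_state_ptr. Without a
// projection layer, n_output == n_cell and the output equals the output state.
void LstmStepSketch(const LstmWeights& w, const std::vector<float>& x,
                    std::vector<float>* cell_state,
                    std::vector<float>* output_state) {
  const std::vector<float>& h = *output_state;
  const std::vector<float> i = Gate(w.Wi, w.Ri, w.bi, x, h, Sigmoid);  // input gate
  const std::vector<float> f = Gate(w.Wf, w.Rf, w.bf, x, h, Sigmoid);  // forget gate
  const std::vector<float> g = Gate(w.Wc, w.Rc, w.bc, x, h, Tanh);     // cell candidate
  const std::vector<float> o = Gate(w.Wo, w.Ro, w.bo, x, h, Sigmoid);  // output gate
  for (size_t c = 0; c < cell_state->size(); ++c) {
    (*cell_state)[c] = f[c] * (*cell_state)[c] + i[c] * g[c];
    (*output_state)[c] = o[c] * std::tanh((*cell_state)[c]);
  }
}

In the deleted code the same algebra appears as VectorBatchVectorAssign for the bias terms followed by MatrixBatchVectorMultiplyAccumulate for the input, auxiliary-input, and recurrent contributions, with the nonlinearities applied batch-wide afterwards.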
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 56e9367878..083e5839bd 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -169,603 +169,5 @@ void RnnBatchStep( hidden_state_ptr_batch); } -void LstmStep( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* output_ptr_batch) { - LstmStepWithAuxInput( - input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, - input_to_cell_weights_ptr, input_to_output_weights_ptr, - /*aux_input_ptr_batch=*/nullptr, - /*aux_input_to_input_weights_ptr=*/nullptr, - /*aux_input_to_forget_weights_ptr=*/nullptr, - /*aux_input_to_cell_weights_ptr=*/nullptr, - /*aux_input_to_output_weights_ptr=*/nullptr, - recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, - recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, - cell_to_input_weights_ptr, cell_to_forget_weights_ptr, - cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, - cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0, - n_output, output_state_ptr, cell_state_ptr, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); -} - -void LstmStepWithAuxInput( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch, - const float* aux_input_to_input_weights_ptr, - const float* aux_input_to_forget_weights_ptr, - const float* aux_input_to_cell_weights_ptr, - const float* aux_input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, - float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* 
output_ptr_batch) { - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - output_gate_scratch, /*result_stride=*/1); - - // If auxiliary input is available then compute aux_input_weight * aux_input - if (aux_input_ptr_batch != nullptr) { - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_aux_input, - aux_input_ptr_batch, n_batch, input_gate_scratch, - /*result_stride=*/1); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_aux_input, - aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_aux_input, - aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1); - } - - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, forget_gate_scratch, - /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, output_gate_scratch, - /*result_stride=*/1); - - // For each batch and cell: update input gate. 
- if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, - n_batch * n_cell, cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, - params->cell_clip, cell_state_ptr); - } - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. 
- const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_batch); - } else { - tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, - output_ptr_batch, /*result_stride=*/1); - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, - params->proj_clip, output_ptr_batch); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_batch); - } - tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, - output_state_ptr); -} - -void LstmStep( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_output, float* input_gate_scratch, float* forget_gate_scratch, - float* cell_scratch, float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, - int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr_batch) { - LstmStepWithAuxInput( - input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - /*aux_input_ptr_batch=*/nullptr, - /*aux_input_to_input_weights_ptr=*/nullptr, - /*aux_input_to_input_weights_scale=*/0.0f, - /*aux_input_to_forget_weights_ptr=*/nullptr, - /*aux_input_to_forget_weights_scale=*/0.0f, - /*aux_input_to_cell_weights_ptr=*/nullptr, - /*aux_input_to_cell_weights_scale=*/0.0f, - /*aux_input_to_output_weights_ptr=*/nullptr, - /*aux_input_to_output_weights_scale=*/0.0f, - recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, - recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, - recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, - recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, - 
cell_to_input_weights_ptr, cell_to_input_weights_scale, - cell_to_forget_weights_ptr, cell_to_forget_weights_scale, - cell_to_output_weights_ptr, cell_to_output_weights_scale, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, params, n_batch, n_cell, n_input, - /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, scaling_factors, - product_scaling_factors, recovered_cell_weights, - quantized_input_ptr_batch, - /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr, - quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, - output_ptr_batch); - } - - void LstmStepWithAuxInput( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, - float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, const float* aux_input_ptr_batch, - const int8_t* aux_input_to_input_weights_ptr, - float aux_input_to_input_weights_scale, - const int8_t* aux_input_to_forget_weights_ptr, - float aux_input_to_forget_weights_scale, - const int8_t* aux_input_to_cell_weights_ptr, - float aux_input_to_cell_weights_scale, - const int8_t* aux_input_to_output_weights_ptr, - float aux_input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, - float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_aux_input, int n_output, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, - float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, - int8_t* quantized_aux_input_ptr_batch, - int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, - float* output_state_ptr, float* cell_state_ptr, - float* output_ptr_batch) { - // Since we have already checked that weights are all there or none, we - // can check the existense of only one to the get the condition. - const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - // Initialize scratch buffers with bias. 
- if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, - n_batch, input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, - n_batch, output_gate_scratch); - - if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_input; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, n_input, - quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - forget_gate_scratch, - /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - output_gate_scratch, - /*result_stride=*/1); - } - - if (aux_input_ptr_batch != nullptr && - !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_input; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr_batch + offset, n_input, - quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute input_weight * input. 
- if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); - } - - if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_output; - tensor_utils::SymmetricQuantizeFloats( - output_state_ptr + offset, n_output, - quantized_output_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); - } - - // Save quantization and matmul computation for all zero input. - bool is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - - // For each batch and cell: update input gate. 
- if (!use_cifg) { - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, - cell_to_input_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, - cell_to_forget_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, - cell_state_ptr, n_batch * n_cell, - cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, - cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, - params->cell_clip, cell_state_ptr); - } - - is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - // For each batch and cell: update the output gate. - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, - cell_to_output_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, - output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_batch); - } else { - tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); - } - if (!tensor_utils::IsZeroVector(output_gate_scratch, - n_batch * n_cell)) { - // Save quantization and matmul computation for all zero input. 
- float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_cell; - tensor_utils::SymmetricQuantizeFloats( - output_gate_scratch + offset, n_cell, - quantized_cell_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * projection_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, - quantized_cell_state_ptr, product_scaling_factors, n_batch, - output_ptr_batch, - /*result_stride=*/1); - } - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, - params->proj_clip, output_ptr_batch); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_batch); - } - tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, - output_state_ptr); - } - } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index b5558cce55..74e0a4a53d 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -76,190 +76,6 @@ void RnnBatchStep( int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, float* hidden_state_ptr_batch, float* output_ptr_batch); -// Performs an LSTM batch inference step for input specified by input_ptr_batch. -// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and -// biases (*_bias_ptr), and buffers (*_scratch), along with additional -// parameters: -// - params: various LSTM params including activation, clipping, etc., -// - n_batch: size of batch, -// - n_cell: number of cells (or units), -// - n_input: the input size, -// - n_output: the output size. -// -// The pointers to the cell and output state and the output are updated. -// -// The pointers with the suffix "_batch" point to data aligned in batch_major -// order, and each step processes batch_size many inputs from input_ptr_batch, -// and updates batch_size many cell and output states. -void LstmStep( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* output_ptr_batch); - -// Same as above but includes an auxiliary input with the corresponding weights. 
-void LstmStepWithAuxInput( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch, - const float* aux_input_to_input_weights_ptr, - const float* aux_input_to_forget_weights_ptr, - const float* aux_input_to_cell_weights_ptr, - const float* aux_input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, - float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* output_ptr_batch); - -// Same as above but with quantized weight matrices. In detail: -// Input of size 'n_batch * n_input': -// input_ptr_batch -// -// LSTM weights: -// Quantized input weights of size 'n_cell * n_input': -// input_to_input_weights - optional (can be nullptr) -// input_to_forget_weights -// input_to_cell_weights -// input_to_input_weights -// Quantized recurrent weights of size 'n_cell * n_output': -// recurrent_to_input_weights - optional -// recurrent_to_forget_weights -// recurrent_to_cell_weights -// recurrent_to_input_weights -// Quantized peephole weights of size 'n_cell', representing diagonal matrices. -// cell_to_input_weights - optional -// cell_to_cell_weights - optional -// cell_to_output_weights - optional -// Quantized projection weights of size 'n_output * n_cell' -// projection_weights_ptr - optional -// Weight scales (scalars) for each of the weights above. 
-// input_to_input_weights_scale - optional -// input_to_forget_weights_scale -// input_to_cell_weights_scale -// input_to_output_weights_scale -// recurrent_to_input_weights_scale - optional -// recurrent_to_forget_weights_scale -// recurrent_to_cell_weights_scale -// recurrent_to_output_weights_scale -// cell_to_input_weights_scale, -// cell_to_forget_weights_scale, -// cell_to_output_weights_scale, -// projection_weights_scale - optional -// Gate biases of size 'n_cell': -// input_gate_bias_ptr - optional -// forget_gate_bias_ptr -// cell_gate_bias_ptr -// output_gate_bias_ptr -// -// Temporary pre-allocated storage for quantized values: -// quantized_input_ptr_batch (same size as input_ptr_batch) -// quantized_output_state_ptr (same size as output_state_ptr) -// quantized_cell_state_ptr (same size as cell_state_ptr) -// Temporary pre-allocated storage for recovered values: -// recovered_cell_weights (same size as cell_to_*_weights) -// -// Outputs: -// output_state_ptr - size 'n_batch * n_output' -// cell_state_ptr - size 'n_batch * n_cell' -// output_ptr_batch - size 'n_batch * n_output' -void LstmStep( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_output, float* input_gate_scratch, float* forget_gate_scratch, - float* cell_scratch, float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, - int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr_batch); - -void LstmStepWithAuxInput( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, const float* aux_input_ptr_batch, - const int8_t* aux_input_to_input_weights_ptr, - float aux_input_to_input_weights_scale, - const int8_t* aux_input_to_forget_weights_ptr, - float aux_input_to_forget_weights_scale, - const int8_t* aux_input_to_cell_weights_ptr, - float aux_input_to_cell_weights_scale, - const int8_t* aux_input_to_output_weights_ptr, - float 
aux_input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_aux_input, int n_output, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* scaling_factors, float* product_scaling_factors, - float* recovered_cell_weights, int8_t* quantized_input_ptr_batch, - int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr, - int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr_batch); - } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ |
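For reference, the removed hybrid variants dequantize on the fly: each batch row of the float input is symmetrically quantized to int8 with its own scale (SymmetricQuantizeFloats), an int8 matrix-vector product accumulates into the float gate scratch, and the result is rescaled by scaling_factors[b] times the weight matrix's scale, the product_scaling_factors seen in the deleted code; if a whole tensor is zero (IsZeroVector) the quantization and matmul are skipped. Below is a minimal sketch of that scheme with invented helper names; it is an illustration under those assumptions, not the tensor_utils implementation.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantizes one row of floats to int8 with a symmetric scale so that
// float_value ~= quantized_value * scale; returns the scale.
float SymmetricQuantizeRow(const float* input, int size, int8_t* output) {
  float max_abs = 0.0f;
  for (int i = 0; i < size; ++i) max_abs = std::max(max_abs, std::abs(input[i]));
  if (max_abs == 0.0f) {
    // The real kernel avoids this case entirely by checking IsZeroVector first.
    std::fill(output, output + size, static_cast<int8_t>(0));
    return 1.0f;
  }
  const float scale = max_abs / 127.0f;
  for (int i = 0; i < size; ++i) {
    const int v = static_cast<int>(std::round(input[i] / scale));
    output[i] = static_cast<int8_t>(std::min(127, std::max(-127, v)));
  }
  return scale;
}

// result[b][c] += (sum_i W_q[c][i] * x_q[b][i]) * (input_scale[b] * weight_scale),
// i.e. the per-batch "product_scaling_factors" rescaling from the deleted code.
void QuantizedMatVecAccumulate(const int8_t* weights_q, float weight_scale,
                               int n_cell, int n_input,
                               const int8_t* input_q, const float* input_scales,
                               int n_batch, float* result) {
  for (int b = 0; b < n_batch; ++b) {
    const float product_scale = input_scales[b] * weight_scale;
    for (int c = 0; c < n_cell; ++c) {
      int32_t acc = 0;
      for (int i = 0; i < n_input; ++i) {
        acc += static_cast<int32_t>(weights_q[c * n_input + i]) *
               static_cast<int32_t>(input_q[b * n_input + i]);
      }
      result[b * n_cell + c] += static_cast<float>(acc) * product_scale;
    }
  }
}

The gate nonlinearities, cell update, clipping, and projection then proceed in float exactly as in the float path; the peephole terms additionally pass through VectorScalarMultiply to recover float cell weights from their int8 form before the elementwise product.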