author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-05 01:22:02 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-05 01:29:25 -0700
commit     3b94d75a9e10ef8ef33760d0ef6aad326e1353ba (patch)
tree       402934b406e63ccd9cff0faec8a83aba6d58abf3 /tensorflow/contrib/lite/kernels/internal
parent     57d31aa599c83014397a22bbb8f1a27a33b0ade3 (diff)
Merge the different LSTM EvalFloat/EvalHybrid calls into a single file.
PiperOrigin-RevId: 215870962
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.cc | 598 ---
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.h  | 184 ---
2 files changed, 0 insertions, 782 deletions
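
For orientation, the float kernels deleted below implement the standard LSTM cell update: gate pre-activations are initialized with the biases, accumulated with the input and recurrent matmuls, squashed by sigmoids, and combined with a tanh cell activation. A minimal single-batch sketch of that core math follows; all names are illustrative only (no CIFG, peephole, or projection, and tanh is assumed as the cell activation, which in the real code is params->activation).

#include <cmath>
#include <vector>

// Hypothetical stand-in: gate_in holds the four gate pre-activations
// ([input, forget, cell, output], each n_cell wide), already accumulated as
// bias + input_weights*input + recurrent_weights*output_state, which is what
// the tensor_utils calls in the diff below compute.
void LstmCellUpdateSketch(const std::vector<float>& gate_in,
                          std::vector<float>& cell_state,   // size n_cell
                          std::vector<float>& output) {     // size n_cell
  const int n_cell = static_cast<int>(cell_state.size());
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  for (int i = 0; i < n_cell; ++i) {
    const float input_gate  = sigmoid(gate_in[0 * n_cell + i]);
    const float forget_gate = sigmoid(gate_in[1 * n_cell + i]);
    const float cell_update = std::tanh(gate_in[2 * n_cell + i]);
    const float output_gate = sigmoid(gate_in[3 * n_cell + i]);
    // cell_state <- forget_gate * cell_state + input_gate * cell_update
    cell_state[i] = forget_gate * cell_state[i] + input_gate * cell_update;
    // output <- output_gate * tanh(cell_state)
    output[i] = output_gate * std::tanh(cell_state[i]);
  }
}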
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 56e9367878..083e5839bd 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -169,603 +169,5 @@ void RnnBatchStep(
hidden_state_ptr_batch);
}
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
- n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
-}
-
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // If an auxiliary input is available, compute aux_input_weight * aux_input.
- if (aux_input_ptr_batch != nullptr) {
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, input_gate_scratch,
- /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
- }
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, forget_gate_scratch,
- /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, output_gate_scratch,
- /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
- output_ptr_batch, /*result_stride=*/1);
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
-
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_input_weights_scale=*/0.0f,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_scale=*/0.0f,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_scale=*/0.0f,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_scale=*/0.0f,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input,
- /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors,
- product_scaling_factors, recovered_cell_weights,
- quantized_input_ptr_batch,
- /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
-
- void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr,
- float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr,
- float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch,
- float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch,
- int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
- float* output_state_ptr, float* cell_state_ptr,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we
- // can check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
- n_batch, input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
- n_batch, forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
- n_batch, output_gate_scratch);
-
- if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
- // Skip quantization and matmul computation when the input is all zeros.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- input_ptr_batch + offset, n_input,
- quantized_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch,
- /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch,
- /*result_stride=*/1);
- }
-
- if (aux_input_ptr_batch != nullptr &&
- !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
- // Skip quantization and matmul computation when the aux input is all zeros.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- aux_input_ptr_batch + offset, n_input,
- quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
- // Skip quantization and matmul computation when the output state is all zeros.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_output;
- tensor_utils::SymmetricQuantizeFloats(
- output_state_ptr + offset, n_output,
- quantized_output_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- // Checking for an all-zero cell state lets us skip the peephole work below.
- bool is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
- cell_state_ptr, n_batch * n_cell,
- cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell,
- cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update the output gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell,
- output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- if (!tensor_utils::IsZeroVector(output_gate_scratch,
- n_batch * n_cell)) {
- // Skip quantization and matmul computation when the gate output is all zeros.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_cell;
- tensor_utils::SymmetricQuantizeFloats(
- output_gate_scratch + offset, n_cell,
- quantized_cell_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * projection_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell,
- quantized_cell_state_ptr, product_scaling_factors, n_batch,
- output_ptr_batch,
- /*result_stride=*/1);
- }
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
- }
-
} // namespace kernel_utils
} // namespace tflite
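
The hybrid (int8-weight) path above quantizes each batch row of the float activations symmetrically, then folds the activation scale and the weight scale into a single per-batch product scale when accumulating the int8 matmul back into the float scratch buffers. A hedged sketch of that scheme, with illustrative helpers standing in for tensor_utils::SymmetricQuantizeFloats and MatrixBatchVectorMultiplyAccumulate:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Map [-max_abs, max_abs] onto the int8 range [-127, 127]; the scaling
// factor recovers floats as quantized * scaling_factor.
void SymmetricQuantizeRowSketch(const float* values, int size,
                                int8_t* quantized, float* scaling_factor) {
  float max_abs = 0.0f;
  for (int i = 0; i < size; ++i)
    max_abs = std::max(max_abs, std::fabs(values[i]));
  *scaling_factor = max_abs / 127.0f;
  const float inv = max_abs > 0.0f ? 127.0f / max_abs : 0.0f;
  for (int i = 0; i < size; ++i)
    quantized[i] = static_cast<int8_t>(std::round(values[i] * inv));
}

// accum += dequantized(weights * input): each int32 dot product is scaled
// back to float by input_scale * weight_scale, which is exactly the role of
// product_scaling_factors in the deleted kernel above.
void MatVecAccumulateSketch(const int8_t* weights, int n_rows, int n_cols,
                            const int8_t* input, float product_scale,
                            float* accum) {
  for (int r = 0; r < n_rows; ++r) {
    int32_t dot = 0;
    for (int c = 0; c < n_cols; ++c)
      dot += static_cast<int32_t>(weights[r * n_cols + c]) * input[c];
    accum[r] += product_scale * static_cast<float>(dot);
  }
}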
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index b5558cce55..74e0a4a53d 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -76,190 +76,6 @@ void RnnBatchStep(
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch);
-// Performs an LSTM batch inference step for input specified by input_ptr_batch.
-// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
-// biases (*_bias_ptr), and scratch buffers (*_scratch), along with additional
-// parameters:
-// - params: various LSTM params including activation, clipping, etc.,
-// - n_batch: size of batch,
-// - n_cell: number of cells (or units),
-// - n_input: the input size,
-// - n_output: the output size.
-//
-// The pointers to the cell and output state and the output are updated.
-//
-// The pointers with the suffix "_batch" point to data laid out in batch-major
-// order; each step processes n_batch inputs from input_ptr_batch and updates
-// n_batch cell and output states.
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above but includes an auxiliary input with the corresponding weights.
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above but with quantized weight matrices. In detail:
-// Input of size 'n_batch * n_input':
-// input_ptr_batch
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-// input_to_input_weights - optional (can be nullptr)
-// input_to_forget_weights
-// input_to_cell_weights
-// input_to_output_weights
-// Quantized recurrent weights of size 'n_cell * n_output':
-// recurrent_to_input_weights - optional
-// recurrent_to_forget_weights
-// recurrent_to_cell_weights
-// recurrent_to_output_weights
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-// cell_to_input_weights - optional
-// cell_to_forget_weights - optional
-// cell_to_output_weights - optional
-// Quantized projection weights of size 'n_output * n_cell'
-// projection_weights_ptr - optional
-// Weight scales (scalars) for each of the weights above.
-// input_to_input_weights_scale - optional
-// input_to_forget_weights_scale
-// input_to_cell_weights_scale
-// input_to_output_weights_scale
-// recurrent_to_input_weights_scale - optional
-// recurrent_to_forget_weights_scale
-// recurrent_to_cell_weights_scale
-// recurrent_to_output_weights_scale
-// cell_to_input_weights_scale - optional
-// cell_to_forget_weights_scale - optional
-// cell_to_output_weights_scale - optional
-// projection_weights_scale - optional
-// Gate biases of size 'n_cell':
-// input_gate_bias_ptr - optional
-// forget_gate_bias_ptr
-// cell_bias_ptr
-// output_gate_bias_ptr
-//
-// Temporary pre-allocated storage for quantized values:
-// quantized_input_ptr_batch (same size as input_ptr_batch)
-// quantized_output_state_ptr (same size as output_state_ptr)
-// quantized_cell_state_ptr (same size as cell_state_ptr)
-// Temporary pre-allocated storage for recovered values:
-// recovered_cell_weights (same size as cell_to_*_weights)
-//
-// Outputs:
-// output_state_ptr - size 'n_batch * n_output'
-// cell_state_ptr - size 'n_batch * n_cell'
-// output_ptr_batch - size 'n_batch * n_output'
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* scaling_factors, float* product_scaling_factors,
- float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
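
Per the removed header comments, these step functions were designed to be called once per timestep with batch-major slices, updating the caller-owned cell and output state buffers in place. A usage-shape sketch with a trivial stand-in for the step function (all names and sizes illustrative; only the calling convention matters, not the math):

#include <vector>

// Trivial stand-in for the removed LstmStep; state persists across calls.
void StepStandIn(const float* input_t, float* output_state, float* cell_state,
                 float* output_t, int n) {
  for (int i = 0; i < n; ++i) {
    cell_state[i] += input_t[i];      // cell state carries across timesteps
    output_t[i] = cell_state[i];
    output_state[i] = output_t[i];    // output state mirrors the output
  }
}

int main() {
  const int n_batch = 2, n_input = 4, n_output = 4, max_time = 3;
  std::vector<float> input(max_time * n_batch * n_input, 0.5f);
  std::vector<float> output(max_time * n_batch * n_output, 0.0f);
  // State buffers are allocated once and reused for every timestep.
  std::vector<float> output_state(n_batch * n_output, 0.0f);
  std::vector<float> cell_state(n_batch * n_output, 0.0f);
  for (int t = 0; t < max_time; ++t) {
    const float* input_t = input.data() + t * n_batch * n_input;
    float* output_t = output.data() + t * n_batch * n_output;
    StepStandIn(input_t, output_state.data(), cell_state.data(), output_t,
                n_batch * n_output);
  }
  return 0;
}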