aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org>, 2018-10-05 01:22:02 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-10-05 01:29:25 -0700
commit: 3b94d75a9e10ef8ef33760d0ef6aad326e1353ba (patch)
tree: 402934b406e63ccd9cff0faec8a83aba6d58abf3 /tensorflow/contrib/lite/kernels
parent: 57d31aa599c83014397a22bbb8f1a27a33b0ade3 (diff)
Merge the different LSTM EvalFloat/EvalHybrid calls into a single file.
PiperOrigin-RevId: 215870962
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- tensorflow/contrib/lite/kernels/BUILD | 13
-rw-r--r-- tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 333
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/kernel_utils.cc | 598
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/kernel_utils.h | 184
-rw-r--r-- tensorflow/contrib/lite/kernels/lstm.cc | 300
-rw-r--r-- tensorflow/contrib/lite/kernels/lstm_eval.cc | 909
-rw-r--r-- tensorflow/contrib/lite/kernels/lstm_eval.h | 79
-rw-r--r-- tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 310
8 files changed, 1061 insertions, 1665 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 95e387814d..68636fb070 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -234,11 +234,11 @@ cc_library(
":activation_functor",
":eigen_support",
":kernel_util",
+ ":lstm_eval",
":op_macros",
":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite:util",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
@@ -255,6 +255,17 @@ cc_library(
)
cc_library(
+ name = "lstm_eval",
+ srcs = ["lstm_eval.cc"],
+ hdrs = ["lstm_eval.h"],
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ ],
+)
+
+cc_library(
name = "builtin_ops",
srcs = ["register.cc"],
hdrs = ["register.h"],
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 0532528f52..a326827b1e 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -694,330 +695,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
- const TfLiteTensor* aux_input_to_input_weights,
- const TfLiteTensor* aux_input_to_forget_weights,
- const TfLiteTensor* aux_input_to_cell_weights,
- const TfLiteTensor* aux_input_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
- TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- float* aux_input_ptr = nullptr;
- float* aux_input_to_input_weights_ptr = nullptr;
- float* aux_input_to_forget_weights_ptr = nullptr;
- float* aux_input_to_cell_weights_ptr = nullptr;
- float* aux_input_to_output_weights_ptr = nullptr;
- if (aux_input_size > 0) {
- aux_input_ptr = aux_input->data.f;
- aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
- aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
- aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
- aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
- }
-
- // Loop through the sequence.
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * output->dims->data[2];
- for (int t = 0; t < max_time; t++) {
- // If this is the forward_sequence, step forward, otherwise step backwards.
- const int t_rel = forward_sequence ? t : max_time - t - 1;
- const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr_time =
- output->data.f + t_rel * output_step + output_offset;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
- input_to_cell_weights->data.f, input_to_output_weights->data.f,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
- aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
- const TfLiteTensor* aux_input_to_input_weights,
- const TfLiteTensor* aux_input_to_forget_weights,
- const TfLiteTensor* aux_input_to_cell_weights,
- const TfLiteTensor* aux_input_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
- TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
- TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
- TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- const float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* output_state_ptr = output_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_aux_input_ptr =
- (aux_input_quantized == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- // Auxiliary input and weights.
- float* aux_input_ptr = nullptr;
- int8_t* aux_input_to_input_weights_ptr = nullptr;
- int8_t* aux_input_to_forget_weights_ptr = nullptr;
- int8_t* aux_input_to_cell_weights_ptr = nullptr;
- int8_t* aux_input_to_output_weights_ptr = nullptr;
- float aux_input_to_input_weights_scale = 0.0f;
- float aux_input_to_forget_weights_scale = 0.0f;
- float aux_input_to_cell_weights_scale = 0.0f;
- float aux_input_to_output_weights_scale = 0.0f;
- if (aux_input_size > 0) {
- aux_input_ptr = aux_input->data.f;
- aux_input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
- aux_input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
- aux_input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
- aux_input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
- aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
- aux_input_to_forget_weights_scale =
- aux_input_to_forget_weights->params.scale;
- aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
- aux_input_to_output_weights_scale =
- aux_input_to_output_weights->params.scale;
- }
-
- // Feed the sequence into the LSTM step-by-step.
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * output->dims->data[2];
- for (int t = 0; t < max_time; t++) {
- // If this is the forward_sequence, step forward, otherwise step backwards.
- const int t_rel = forward_sequence ? t : max_time - t - 1;
- const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr = output->data.f + t_rel * output_step + output_offset;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
- n_input, aux_input_size, n_output, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
- }
-
- return kTfLiteOk;
-}
-
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
@@ -1157,7 +834,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (fw_input_to_output_weights->type) {
case kTfLiteFloat32: {
- TfLiteStatus fw_pass_status = EvalFloat(
+ TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1172,7 +849,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_activation_state, fw_cell_state, fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
- TfLiteStatus bw_pass_status = EvalFloat(
+ TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
@@ -1208,7 +885,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
- TfLiteStatus fw_pass_status = EvalHybrid(
+ TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1226,7 +903,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
- TfLiteStatus bw_pass_status = EvalHybrid(
+ TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 56e9367878..083e5839bd 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -169,603 +169,5 @@ void RnnBatchStep(
hidden_state_ptr_batch);
}
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
- n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
-}
-
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // If auxiliary input is available then compute aux_input_weight * aux_input
- if (aux_input_ptr_batch != nullptr) {
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, input_gate_scratch,
- /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
- }
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, forget_gate_scratch,
- /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, output_gate_scratch,
- /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
- output_ptr_batch, /*result_stride=*/1);
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
-// Hybrid LstmStep without aux input: forwards to LstmStepWithAuxInput.
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_input_weights_scale=*/0.0f,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_scale=*/0.0f,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_scale=*/0.0f,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_scale=*/0.0f,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input,
- /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors,
- product_scaling_factors, recovered_cell_weights,
- quantized_input_ptr_batch,
- /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
-// Hybrid (int8 weights, float inputs) LSTM step with an optional aux input.
- void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr,
- float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr,
- float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch,
- float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch,
- int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
- float* output_state_ptr, float* cell_state_ptr,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we
- // can check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
- n_batch, input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
- n_batch, forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
- n_batch, output_gate_scratch);
-
- if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- input_ptr_batch + offset, n_input,
- quantized_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch,
- /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch,
- /*result_stride=*/1);
- }
-
- if (aux_input_ptr_batch != nullptr &&
- !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_aux_input)) {
- // Aux input is n_aux_input wide (not n_input); skip work if it is all zero.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_aux_input;
- tensor_utils::SymmetricQuantizeFloats(
- aux_input_ptr_batch + offset, n_aux_input,
- quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute aux_input_weight * aux_input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_aux_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_aux_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_output;
- tensor_utils::SymmetricQuantizeFloats(
- output_state_ptr + offset, n_output,
- quantized_output_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- // Save quantization and matmul computation for all zero input.
- bool is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
- cell_state_ptr, n_batch * n_cell,
- cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell,
- cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update the output gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell,
- output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- if (!tensor_utils::IsZeroVector(output_gate_scratch,
- n_batch * n_cell)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_cell;
- tensor_utils::SymmetricQuantizeFloats(
- output_gate_scratch + offset, n_cell,
- quantized_cell_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * projection_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell,
- quantized_cell_state_ptr, product_scaling_factors, n_batch,
- output_ptr_batch,
- /*result_stride=*/1);
- }
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
- }
-
} // namespace kernel_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index b5558cce55..74e0a4a53d 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -76,190 +76,6 @@ void RnnBatchStep(
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch);
-// Performs an LSTM batch inference step for input specified by input_ptr_batch.
-// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
-// biases (*_bias_ptr), and buffers (*_scratch), along with additional
-// parameters:
-// - params: various LSTM params including activation, clipping, etc.,
-// - n_batch: size of batch,
-// - n_cell: number of cells (or units),
-// - n_input: the input size,
-// - n_output: the output size.
-//
-// Updates cell_state_ptr, output_state_ptr and output_ptr_batch in place.
-//
-// The pointers with the suffix "_batch" point to data laid out in batch-major
-// order; each step consumes n_batch inputs from input_ptr_batch and updates
-// n_batch cell and output states.
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above, plus an auxiliary input with its own aux_input_to_* weights.
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above but with quantized weight matrices. In detail:
-// Input of size 'n_batch * n_input':
-// input_ptr_batch
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-// input_to_input_weights - optional (can be nullptr)
-// input_to_forget_weights
-// input_to_cell_weights
-// input_to_output_weights
-// Quantized recurrent weights of size 'n_cell * n_output':
-// recurrent_to_input_weights - optional
-// recurrent_to_forget_weights
-// recurrent_to_cell_weights
-// recurrent_to_output_weights
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-// cell_to_input_weights - optional
-// cell_to_forget_weights - optional
-// cell_to_output_weights - optional
-// Quantized projection weights of size 'n_output * n_cell'
-// projection_weights_ptr - optional
-// Weight scales (scalars) for each of the weights above.
-// input_to_input_weights_scale - optional
-// input_to_forget_weights_scale
-// input_to_cell_weights_scale
-// input_to_output_weights_scale
-// recurrent_to_input_weights_scale - optional
-// recurrent_to_forget_weights_scale
-// recurrent_to_cell_weights_scale
-// recurrent_to_output_weights_scale
-// cell_to_input_weights_scale
-// cell_to_forget_weights_scale
-// cell_to_output_weights_scale
-// projection_weights_scale - optional
-// Gate biases of size 'n_cell':
-// input_gate_bias_ptr - optional
-// forget_gate_bias_ptr
-// cell_bias_ptr
-// output_gate_bias_ptr
-//
-// Temporary pre-allocated storage for quantized values:
-// quantized_input_ptr_batch (same size as input_ptr_batch)
-// quantized_output_state_ptr (same size as output_state_ptr)
-// quantized_cell_state_ptr (same size as cell_state_ptr)
-// Temporary pre-allocated storage for recovered values:
-// recovered_cell_weights (same size as cell_to_*_weights)
-//
-// Outputs:
-// output_state_ptr - size 'n_batch * n_output'
-// cell_state_ptr - size 'n_batch * n_cell'
-// output_ptr_batch - size 'n_batch * n_output'
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-// Quantized-weight (int8) LstmStep variant that also takes an auxiliary input.
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* scaling_factors, float* product_scaling_factors,
- float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 5b996d00bc..16d67a1a93 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -424,263 +425,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// Float LSTM engine: unpacks the tensors and calls kernel_utils::LstmStep once.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* activation_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {  // CIFG: input_gate_scratch stays null, three slices are used.
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {  // Four consecutive slices of n_cell * n_batch floats each.
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Optional tensors: the respective pointers are null when the tensor is unused.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors: dereferenced unconditionally, so they must be non-null.
- const float* input_ptr_batch = input->data.f;
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- activation_state_ptr, cell_state_ptr, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
- TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- const float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_activation_state_ptr, quantized_cell_state_ptr,
- activation_state_ptr, cell_state_ptr, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
@@ -738,15 +482,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(mirkov): add a check that weights are all uint8s or all floats.
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params,
- scratch_buffer, activation_state, cell_state, output);
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@@ -759,17 +509,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
+ return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params, scratch_buffer,
- scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, activation_state_quantized, cell_state_quantized,
- activation_state, cell_state, output);
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc
new file mode 100644
index 0000000000..c6c21eb085
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc
@@ -0,0 +1,909 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
+namespace {
+
+// Performs an LSTM batch inference step for input specified by input_ptr_batch.
+// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
+// biases (*_bias_ptr), and buffers (*_scratch), along with additional
+// parameters:
+// - params: various LSTM params including activation, clipping, etc.,
+// - n_batch: size of batch,
+// - n_cell: number of cells (or units),
+// - n_input: the input size,
+// - n_aux_input: the auxiliary input size,
+// - n_output: the output size.
+//
+// The cell state, output state and output buffers are updated in place.
+//
+// The pointers with the suffix "_batch" point to data aligned in batch_major
+// order, and each step processes n_batch many inputs from input_ptr_batch,
+// and updates n_batch many cell and output states.
+//
+// Float-weight variant. Optional parts are selected by null pointers:
+// - input_to_input_weights_ptr == nullptr: CIFG; the input gate is not
+// computed and the forget gate, after Sub1Vector, is used in its place,
+// - cell_to_output_weights_ptr == nullptr: the peephole terms are skipped,
+// - aux_input_ptr_batch == nullptr: the auxiliary-input terms are skipped,
+// - projection_weights_ptr == nullptr: the gated cell output is copied
+// straight to the output, which assumes n_cell == n_output.
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize each gate's scratch buffer with its bias; the products below
+ // are then accumulated on top of it.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ // For each batch and cell: accumulate input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // If an auxiliary input is available, accumulate aux_input_weight *
+ // aux_input as well. The auxiliary input is n_aux_input wide.
+ if (aux_input_ptr_batch != nullptr) {
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // For each batch and cell: accumulate recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate (optional peephole term from
+ // the current cell state, then sigmoid).
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate (optional peephole term from
+ // the current cell state, then sigmoid).
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell state in place. The old state is
+ // first scaled by the forget gate, then the activated cell input, scaled by
+ // the input gate, is accumulated onto it.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ // CIFG: there is no input gate, so forget_gate_scratch is overwritten by
+ // Sub1Vector and used as the input-gate factor. This must happen after
+ // the forget gate has been applied to the cell state above.
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ // Cell clipping is applied only when cell_clip is positive.
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate. The peephole term here
+ // uses the cell state that was just updated above.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ // cell_scratch is reused to hold the activated cell state, and
+ // output_gate_scratch then holds the gated cell output.
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ // Start from the projection bias (or zeros), then accumulate
+ // projection_weights * gated cell output.
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ // No projection: n_batch * n_output values are copied out of the
+ // n_batch * n_cell gate buffer, so n_cell == n_output is assumed.
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ // The new output also becomes the output state.
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+// input_ptr_batch
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+// input_to_input_weights - optional (can be nullptr)
+// input_to_forget_weights
+// input_to_cell_weights
+// input_to_output_weights
+// Quantized recurrent weights of size 'n_cell * n_output':
+// recurrent_to_input_weights - optional
+// recurrent_to_forget_weights
+// recurrent_to_cell_weights
+// recurrent_to_output_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+// cell_to_input_weights - optional
+// cell_to_forget_weights - optional
+// cell_to_output_weights - optional
+// Quantized projection weights of size 'n_output * n_cell'
+// projection_weights_ptr - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional
+// input_to_forget_weights_scale
+// input_to_cell_weights_scale
+// input_to_output_weights_scale
+// recurrent_to_input_weights_scale - optional
+// recurrent_to_forget_weights_scale
+// recurrent_to_cell_weights_scale
+// recurrent_to_output_weights_scale
+// cell_to_input_weights_scale - optional
+// cell_to_forget_weights_scale - optional
+// cell_to_output_weights_scale - optional
+// projection_weights_scale - optional
+// Gate biases of size 'n_cell':
+// input_gate_bias_ptr - optional
+// forget_gate_bias_ptr
+// cell_bias_ptr
+// output_gate_bias_ptr
+//
+// Temporary pre-allocated storage for quantized values:
+// quantized_input_ptr_batch (same size as input_ptr_batch)
+// quantized_output_state_ptr (same size as output_state_ptr)
+// quantized_cell_state_ptr (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+// recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+// output_state_ptr - size 'n_batch * n_output'
+// cell_state_ptr - size 'n_batch * n_cell'
+// output_ptr_batch - size 'n_batch * n_output'
+//
+// Optional auxiliary input (skipped when aux_input_ptr_batch is nullptr):
+// aux_input_ptr_batch - size 'n_batch * n_aux_input'
+// aux_input_to_{input,forget,cell,output}_weights_ptr - quantized, size
+// 'n_cell * n_aux_input', each with a matching *_scale scalar
+// quantized_aux_input_ptr_batch - scratch, same size as aux_input_ptr_batch
+// The auxiliary input width n_aux_input may differ from n_input.
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* scaling_factors, float* product_scaling_factors,
+ float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we
+ // can check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize each gate's scratch buffer with its bias; the products below
+ // are then accumulated on top of it.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Skip quantization and matmul computation when the input is all zeros.
+ // Each batch row is quantized separately and gets its own scaling factor.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input. Before each
+ // matmul, product_scaling_factors is set to the per-batch input scale
+ // times that weight matrix's scale.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ // The auxiliary input is n_aux_input wide, so every extent in this branch
+ // (zero check, per-batch offset, quantization length, matmul columns) uses
+ // n_aux_input rather than n_input.
+ if (aux_input_ptr_batch != nullptr &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_aux_input)) {
+ // Skip quantization and matmul computation when the auxiliary input is
+ // all zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_aux_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, n_aux_input,
+ quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute aux_input_weight * aux_input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Skip quantization and matmul computation when the output state is all
+ // zeros.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // The peephole terms are skipped when the cell state is all zeros.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate. The quantized peephole
+ // weights are first rescaled to float in recovered_cell_weights.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell state in place. The old state is
+ // first scaled by the forget gate, then the activated cell input, scaled by
+ // the input gate, is accumulated onto it.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ // CIFG: there is no input gate, so forget_gate_scratch is overwritten by
+ // Sub1Vector and used as the input-gate factor.
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ // The cell state has changed, so re-evaluate the all-zeros shortcut.
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ // cell_scratch is reused to hold the activated cell state, and
+ // output_gate_scratch then holds the gated cell output.
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+ // Skip quantization and matmul computation when the gated cell output
+ // is all zeros. quantized_cell_state_ptr is used here as the buffer for
+ // the quantized gated cell output.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ // No projection: n_batch * n_output values are copied out of the
+ // n_batch * n_cell gate buffer, so n_cell == n_output is assumed.
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ // The new output also becomes the output state.
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+} // namespace
+
+// Runs the float LSTM over the whole input sequence by calling
+// LstmStepWithAuxInput once per time step.
+//
+// A 2-D `input` is treated as a single time step; otherwise its leading
+// dimension is the number of steps. Tensors belonging to the input gate, the
+// peephole connections, the projection and the auxiliary input may be nullptr;
+// their raw pointers are then passed to the step function as nullptr.
+// `forward_sequence` selects first-to-last (true) or last-to-first (false)
+// traversal, and `output_offset` is added to the output pointer of each step.
+// Always returns kTfLiteOk.
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  // Batch size and input width are always the last two dimensions of `input`.
+  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+  const int n_batch = input->dims->data[input->dims->size - 2];
+  const int n_input = input->dims->data[input->dims->size - 1];
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Index the scratch buffers pointers to the global scratch buffer. With CIFG
+  // there is no input gate, so only three n_cell * n_batch slices are used.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  const float* input_to_input_weights_ptr =
+      (use_cifg) ? nullptr : input_to_input_weights->data.f;
+  const float* recurrent_to_input_weights_ptr =
+      (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+  const float* input_gate_bias_ptr =
+      (use_cifg) ? nullptr : input_gate_bias->data.f;
+  const float* cell_to_input_weights_ptr =
+      (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+  const float* cell_to_forget_weights_ptr =
+      (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+  const float* cell_to_output_weights_ptr =
+      (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+  const float* projection_weights_ptr =
+      (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Auxiliary input and weights; all stay null when there is no aux input.
+  float* aux_input_ptr = nullptr;
+  float* aux_input_to_input_weights_ptr = nullptr;
+  float* aux_input_to_forget_weights_ptr = nullptr;
+  float* aux_input_to_cell_weights_ptr = nullptr;
+  float* aux_input_to_output_weights_ptr = nullptr;
+  if (aux_input_size > 0) {
+    aux_input_ptr = aux_input->data.f;
+    // Like the other input-gate tensors, the aux input-gate weights may not
+    // have been supplied (presumably the CIFG case), so only dereference the
+    // tensor when it is present instead of crashing on a null pointer.
+    if (aux_input_to_input_weights != nullptr) {
+      aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
+    }
+    aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
+    aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
+    aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
+  }
+
+  // Loop through the sequence.
+  // NOTE(review): aux_input_ptr is not advanced with t_rel, so every step is
+  // handed the start of the aux input buffer - confirm this is intended.
+  const int input_step = n_batch * n_input;
+  const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+  for (int t = 0; t < max_time; t++) {
+    // If this is the forward_sequence, step forward, otherwise step backwards.
+    const int t_rel = forward_sequence ? t : max_time - t - 1;
+    const float* input_ptr = input->data.f + t_rel * input_step;
+    float* output_ptr_time =
+        output->data.f + t_rel * output_step + output_offset;
+
+    LstmStepWithAuxInput(
+        input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
+        input_to_cell_weights->data.f, input_to_output_weights->data.f,
+        aux_input_ptr, aux_input_to_input_weights_ptr,
+        aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
+        aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
+        recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
+        recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+        cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+        input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+        output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+        params, n_batch, n_cell, n_input, aux_input_size, n_output,
+        activation_state->data.f, cell_state->data.f, input_gate_scratch,
+        forget_gate_scratch, cell_scratch, output_gate_scratch,
+        output_ptr_time);
+  }
+  return kTfLiteOk;
+}
+
+// Runs the hybrid LSTM over the whole input sequence by calling
+// LstmStepWithAuxInput once per time step.
+//
+// Weight tensors are read as int8 data through `data.uint8` together with
+// their `params.scale`; biases, states, input and output are read as float.
+// A 2-D `input` is treated as a single time step; otherwise its leading
+// dimension is the number of steps. Tensors belonging to the input gate, the
+// peephole connections, the projection and the auxiliary input may be nullptr.
+// `scaling_factors`, `prod_scaling_factors`, `recovered_cell_weights` and the
+// `*_quantized` tensors are working buffers handed to every step.
+// `forward_sequence` selects first-to-last (true) or last-to-first (false)
+// traversal, and `output_offset` is added to the output pointer of each step.
+// Always returns kTfLiteOk.
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
+  // Batch size and input width are always the last two dimensions of `input`.
+  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+  const int n_batch = input->dims->data[input->dims->size - 2];
+  const int n_input = input->dims->data[input->dims->size - 1];
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Carve the per-gate scratch areas out of the shared scratch buffer. With
+  // CIFG there is no input gate, so only three n_cell * n_batch slices are
+  // used.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+  }
+
+  // Peephole weights; the cell-to-input weights only exist without CIFG.
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+  const float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* output_state_ptr = output_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_aux_input_ptr =
+      (aux_input_quantized == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+  int8_t* quantized_output_state_ptr =
+      reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+  // Auxiliary input and weights.
+  float* aux_input_ptr = nullptr;
+  int8_t* aux_input_to_input_weights_ptr = nullptr;
+  int8_t* aux_input_to_forget_weights_ptr = nullptr;
+  int8_t* aux_input_to_cell_weights_ptr = nullptr;
+  int8_t* aux_input_to_output_weights_ptr = nullptr;
+  float aux_input_to_input_weights_scale = 0.0f;
+  float aux_input_to_forget_weights_scale = 0.0f;
+  float aux_input_to_cell_weights_scale = 0.0f;
+  float aux_input_to_output_weights_scale = 0.0f;
+  if (aux_input_size > 0) {
+    aux_input_ptr = aux_input->data.f;
+    // Like the other input-gate tensors, the aux input-gate weights may not
+    // have been supplied (presumably the CIFG case), so only dereference the
+    // tensor when it is present instead of crashing on a null pointer. The
+    // pointer and scale keep their null / 0.0f defaults otherwise.
+    if (aux_input_to_input_weights != nullptr) {
+      aux_input_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+      aux_input_to_input_weights_scale =
+          aux_input_to_input_weights->params.scale;
+    }
+    aux_input_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+    aux_input_to_cell_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+    aux_input_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+    aux_input_to_forget_weights_scale =
+        aux_input_to_forget_weights->params.scale;
+    aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+    aux_input_to_output_weights_scale =
+        aux_input_to_output_weights->params.scale;
+  }
+
+  // Feed the sequence into the LSTM step-by-step.
+  // NOTE(review): aux_input_ptr is not advanced with t_rel, so every step is
+  // handed the start of the aux input buffer - confirm this is intended.
+  const int input_step = n_batch * n_input;
+  const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+  for (int t = 0; t < max_time; t++) {
+    // If this is the forward_sequence, step forward, otherwise step backwards.
+    const int t_rel = forward_sequence ? t : max_time - t - 1;
+    const float* input_ptr = input->data.f + t_rel * input_step;
+    float* output_ptr = output->data.f + t_rel * output_step + output_offset;
+
+    LstmStepWithAuxInput(
+        input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+        input_to_forget_weights_ptr, input_to_forget_weights_scale,
+        input_to_cell_weights_ptr, input_to_cell_weights_scale,
+        input_to_output_weights_ptr, input_to_output_weights_scale,
+        aux_input_ptr, aux_input_to_input_weights_ptr,
+        aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+        aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+        aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+        aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+        n_input, aux_input_size, n_output, input_gate_scratch,
+        forget_gate_scratch, cell_scratch, output_gate_scratch,
+        scaling_factors_ptr, prod_scaling_factors_ptr,
+        recovered_cell_weights_ptr, quantized_input_ptr,
+        quantized_aux_input_ptr, quantized_output_state_ptr,
+        quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
+  }
+
+  return kTfLiteOk;
+}
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.h b/tensorflow/contrib/lite/kernels/lstm_eval.h
new file mode 100644
index 0000000000..adf8cf0f64
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
+// Evaluates a float LSTM over a full input sequence.
+//
+// `input` is either 2-D (handled as a single time step) or has the number of
+// time steps as its leading dimension; batch size and input width are its
+// last two dimensions. The input-gate tensors, the cell_to_* (peephole)
+// tensors, the projection tensors and the aux_input* tensors may be nullptr.
+// `forward_sequence` selects whether time steps are visited first-to-last
+// (true) or last-to-first (false). `output_offset` is added to the output
+// pointer of every time step. `scratch_buffer` provides the per-gate working
+// memory, and `activation_state` / `cell_state` are handed to every step.
+// The implementation always returns kTfLiteOk.
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output);
+
+// Evaluates a hybrid LSTM over a full input sequence: weight tensors are read
+// as int8 data together with their `params.scale`, while input, biases,
+// states and output are read as float.
+//
+// `input` is either 2-D (handled as a single time step) or has the number of
+// time steps as its leading dimension; batch size and input width are its
+// last two dimensions. The input-gate tensors, the cell_to_* (peephole)
+// tensors, the projection tensors, the aux_input* tensors and
+// `aux_input_quantized` may be nullptr. `scaling_factors`,
+// `prod_scaling_factors`, `recovered_cell_weights` and the `*_quantized`
+// tensors are working buffers handed to every step.
+// `forward_sequence` selects whether time steps are visited first-to-last
+// (true) or last-to-first (false). `output_offset` is added to the output
+// pointer of every time step. The implementation always returns kTfLiteOk.
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 63817bd886..ec9cf38b83 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -429,273 +430,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// The LSTM Op engine.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* activation_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_forget_weights_ptr, input_to_cell_weights_ptr,
- input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr,
- recurrent_to_output_weights_ptr, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
- cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
- TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_input_weights_scale, input_to_forget_weights_ptr,
- input_to_forget_weights_scale, input_to_cell_weights_ptr,
- input_to_cell_weights_scale, input_to_output_weights_ptr,
- input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
- n_input, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors_ptr,
- prod_scaling_factors_ptr, recovered_cell_weights_ptr,
- quantized_input_ptr, quantized_activation_state_ptr,
- quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -750,15 +484,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params,
- scratch_buffer, activation_state, cell_state, output);
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@@ -771,17 +511,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
+ return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params, scratch_buffer,
- scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, activation_state_quantized, cell_state_quantized,
- activation_state, cell_state, output);
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",