about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-05 01:22:02 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-05 01:29:25 -0700
commit 3b94d75a9e10ef8ef33760d0ef6aad326e1353ba (patch)
tree 402934b406e63ccd9cff0faec8a83aba6d58abf3 /tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
parent 57d31aa599c83014397a22bbb8f1a27a33b0ade3 (diff)
Merge the different LSTM EvalFloat/EvalHybrid calls into a single file.
PiperOrigin-RevId: 215870962
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc')
-rw-r--r-- tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 310
1 file changed, 28 insertions(+), 282 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 63817bd886..ec9cf38b83 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -429,273 +430,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// The LSTM Op engine.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* activation_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_forget_weights_ptr, input_to_cell_weights_ptr,
- input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr,
- recurrent_to_output_weights_ptr, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
- cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
- TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_input_weights_scale, input_to_forget_weights_ptr,
- input_to_forget_weights_scale, input_to_cell_weights_ptr,
- input_to_cell_weights_scale, input_to_output_weights_ptr,
- input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
- n_input, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors_ptr,
- prod_scaling_factors_ptr, recovered_cell_weights_ptr,
- quantized_input_ptr, quantized_activation_state_ptr,
- quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -750,15 +484,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params,
- scratch_buffer, activation_state, cell_state, output);
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@@ -771,17 +511,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
+ return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params, scratch_buffer,
- scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, activation_state_quantized, cell_state_quantized,
- activation_state, cell_state, output);
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",