author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-05 12:11:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-06-05 12:13:42 -0700
commit    920df27282b3f5d03d79f54ef05cea305c2a30d7 (patch)
tree      1ed1f26e6c000bb28ff82fd720b7427ac0c9dfac /tensorflow/contrib/lite/kernels/lstm.cc
parent    62a70dd873bc8488b10df5ad55254119173a5d0c (diff)
Implementation of the symmetrically quantized LSTM TFLite Op.
PiperOrigin-RevId: 199337082
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc | 454
1 file changed, 384 insertions(+), 70 deletions(-)
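The diff below adds a hybrid evaluation path in which float activations are symmetrically quantized to int8 on the fly and multiplied against symmetrically quantized uint8-stored weights. As orientation, here is a minimal standalone sketch of that quantization scheme; it is illustrative only, not code from this change, and the function name is hypothetical.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Symmetric quantization: scale = max|x| / 127, so zero maps exactly to zero
// and no zero point is stored; each value is rounded into [-127, 127].
void SymmetricQuantize(const std::vector<float>& values,
                       std::vector<int8_t>* quantized, float* scale) {
  float max_abs = 0.0f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  *scale = (max_abs == 0.0f) ? 1.0f : max_abs / 127.0f;
  quantized->resize(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    const int32_t q = static_cast<int32_t>(std::round(values[i] / *scale));
    (*quantized)[i] =
        static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-127, q)));
  }
}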
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 9aae3e571b..eb26a02455 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -86,7 +86,8 @@ constexpr int kOutputTensor = 2;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData;
op_data->kernel_type = kTfLiteLSTMFullKernel;
- context->AddTensors(context, 1, &op_data->scratch_tensor_index);
+ context->AddTensors(context, /*tensors_to_add=*/7,
+ &op_data->scratch_tensor_index);
return op_data;
}
@@ -94,7 +95,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -104,7 +105,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- if (input_to_input_weights) {
+ if (input_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
@@ -124,7 +125,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- if (recurrent_to_input_weights) {
+ if (recurrent_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
n_cell);
@@ -214,7 +215,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- if (projection_weights) {
+ if (projection_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
@@ -222,7 +223,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
- if (projection_bias) {
+ if (projection_bias != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
}
@@ -252,6 +253,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
@@ -296,86 +298,148 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
- // Create a scratch buffer tensor.
+ // Mark state tensors as persistent tensors.
+ output_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // The weights are of consistent type, so it suffices to check one.
+ // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
+
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(7);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
node->temporaries->data[0] = op_data->scratch_tensor_index;
+
+ // Create a scratch buffer tensor.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
if (use_cifg) {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 3;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
} else {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Input, Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 4;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // output_state and cell_state tensors.
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+ TfLiteTensor* output_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ output_state_quantized->type = kTfLiteUInt8;
+ output_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(output_state_quantized->dims,
+ output_state->dims)) {
+ TfLiteIntArray* output_state_quantized_size =
+ TfLiteIntArrayCopy(output_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output_state_quantized,
+ output_state_quantized_size));
+ }
+ node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage which allows to quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, only need to store n_cell values.
+ node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
}
return kTfLiteOk;
}
// The LSTM Op engine.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights =
- GetInput(context, node, kInputToForgetWeightsTensor);
- const TfLiteTensor* input_to_cell_weights =
- GetInput(context, node, kInputToCellWeightsTensor);
- const TfLiteTensor* input_to_output_weights =
- GetInput(context, node, kInputToOutputWeightsTensor);
-
- const TfLiteTensor* recurrent_to_input_weights =
- GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights =
- GetInput(context, node, kRecurrentToForgetWeightsTensor);
- const TfLiteTensor* recurrent_to_cell_weights =
- GetInput(context, node, kRecurrentToCellWeightsTensor);
- const TfLiteTensor* recurrent_to_output_weights =
- GetInput(context, node, kRecurrentToOutputWeightsTensor);
-
- const TfLiteTensor* cell_to_input_weights =
- GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
- const TfLiteTensor* cell_to_forget_weights =
- GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
- const TfLiteTensor* cell_to_output_weights =
- GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
-
- const TfLiteTensor* input_gate_bias =
- GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- const TfLiteTensor* forget_gate_bias =
- GetInput(context, node, kForgetGateBiasTensor);
- const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- const TfLiteTensor* output_gate_bias =
- GetInput(context, node, kOutputGateBiasTensor);
-
- const TfLiteTensor* projection_weights =
- GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- const TfLiteTensor* projection_bias =
- GetOptionalInputTensor(context, node, kProjectionBiasTensor);
-
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
// n_cell and n_output will be the same size when there is no projection.
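The scaling-factor temporaries allocated in Prepare above support the hybrid inner loop: each batch row of the input is quantized once (producing one scale per row), and that scale is multiplied by a weight matrix's per-tensor scale before the integer accumulators are converted back to float. A minimal sketch of that accumulation follows, with hypothetical names; the actual work happens inside kernel_utils::LstmStep.

#include <cstdint>

// Hybrid matrix-vector product: row-major int8 weights (per-tensor scale)
// times an int8-quantized input row (per-row scale). The int32 dot product
// is rescaled by the product of the two scales and accumulated into the
// float output.
void HybridMatVecAccumulate(const int8_t* weights, float weight_scale,
                            int rows, int cols,
                            const int8_t* quantized_input, float input_scale,
                            float* output /* size: rows */) {
  const float prod_scale = weight_scale * input_scale;
  for (int r = 0; r < rows; ++r) {
    int32_t acc = 0;
    for (int c = 0; c < cols; ++c) {
      acc += static_cast<int32_t>(weights[r * cols + c]) *
             static_cast<int32_t>(quantized_input[c]);
    }
    output[r] += static_cast<float>(acc) * prod_scale;
  }
}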
@@ -387,9 +451,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
- // Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
-
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -457,6 +518,259 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+ TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ kernel_utils::LstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr,
+ cell_state_ptr, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(mirkov): add a check that weights are all uint8s or all floats.
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params,
+ scratch_buffer, output_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* output_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params, scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, output_state_quantized, cell_state_quantized,
+ output_state, cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
} // namespace full
// For basic kernel (5-inputs).
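In the hybrid path above, the peephole (cell-to-gate) weights are diagonal, so only n_cell quantized values exist; per the comments in Prepare, they are recovered to float in the recovered_cell_weights temporary before the elementwise cell updates. A minimal sketch of that recovery under the same symmetric scheme, with hypothetical names:

#include <cstdint>

// Recover float diagonal weights from symmetric int8 storage:
// recovered[i] = quantized[i] * scale. Diagonal (peephole) connections are
// elementwise, so an n_cell-length vector is sufficient.
void RecoverDiagonalWeights(const int8_t* quantized, float scale, int n_cell,
                            float* recovered) {
  for (int i = 0; i < n_cell; ++i) {
    recovered[i] = static_cast<float>(quantized[i]) * scale;
  }
}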
@@ -491,7 +805,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
- // Only Float32 is supportted currently.
+ // Only Float32 is supported currently.
// TODO(ycling): Implement quantize uint8 support.
for (int index = 0; index < node->inputs->size; ++index) {
TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];