Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc  472
1 file changed, 405 insertions, 67 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 1c28123a24..0acd705950 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -13,7 +13,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <unistd.h>
 #include <cassert>
 #include <cmath>
 #include <cstdio>
@@ -70,9 +69,21 @@
 constexpr int kOutputStateTensor = 0;
 constexpr int kCellStateTensor = 1;
 constexpr int kOutputTensor = 2;
 
+// Temporary tensors
+enum TemporaryTensor {
+  kScratchBuffer = 0,
+  kInputQuantized = 1,
+  kOutputStateQuantized = 2,
+  kCellStateQuantized = 3,
+  kScalingFactors = 4,
+  kProductScalingFactors = 5,
+  kRecoveredCellWeights = 6,
+  kNumTemporaryTensors = 7
+};
+
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* scratch_tensor_index = new int;
-  context->AddTensors(context, 1, scratch_tensor_index);
+  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
   return scratch_tensor_index;
 }
 
@@ -84,7 +95,7 @@ void Free(TfLiteContext* context, void* buffer) {
 TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                         TfLiteNode* node, int n_input,
                                         int n_output, int n_cell) {
-  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   // Making sure clipping parameters have valid values.
   // == 0 means no clipping
@@ -242,6 +253,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Inferring batch size, number of outputs and sequence length and
   // number of cells from the input tensors.
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
   TF_LITE_ENSURE(context, input->dims->size > 1);
   const int max_time = input->dims->data[0];
   const int n_batch = input->dims->data[1];
@@ -261,7 +273,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const int n_output = recurrent_to_output_weights->dims->data[1];
 
   // Check that input tensor dimensions matches with each other.
-  CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
+  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+                                                        n_output, n_cell));
 
   // Get the pointer to output, output_state and cell_state buffer tensors.
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -288,86 +301,156 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, cell_state, cell_size));
 
-  // Create a scratch buffer tensor.
+  // Mark state tensors as persistent tensors.
+  output_state->allocation_type = kTfLiteArenaRwPersistent;
+  cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+  // The weights are of consistent type, so it suffices to check one.
+  // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
+  const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+                             input->type == kTfLiteFloat32);
+
   TfLiteIntArrayFree(node->temporaries);
-  node->temporaries = TfLiteIntArrayCreate(1);
+  if (is_hybrid_op) {
+    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+  } else {
+    node->temporaries = TfLiteIntArrayCreate(1);
+  }
   node->temporaries->data[0] = *scratch_tensor_index;
-  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+  // Create a scratch buffer tensor.
+  TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
   scratch_buffer->type = input->type;
   scratch_buffer->allocation_type = kTfLiteArenaRw;
 
-  // Mark state tensors as persistent tensors.
-  output_state->allocation_type = kTfLiteArenaRwPersistent;
-  cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
   const TfLiteTensor* input_to_input_weights =
       GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
   const bool use_cifg = (input_to_input_weights == nullptr);
+  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+  scratch_buffer_size->data[0] = n_batch;
   if (use_cifg) {
-    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
-    scratch_buffer_size->data[0] = n_batch;
     // Reserving space for Cell, Forget, Output gates
     scratch_buffer_size->data[1] = n_cell * 3;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
-                                                     scratch_buffer_size));
   } else {
-    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
-    scratch_buffer_size->data[0] = n_batch;
     // Reserving space for Input, Cell, Forget, Output gates
     scratch_buffer_size->data[1] = n_cell * 4;
-    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
-                                                     scratch_buffer_size));
+  }
+  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+                                                   scratch_buffer_size));
+
+  if (is_hybrid_op) {
+    // Allocate temporary tensors to store quantized values of input,
+    // output_state and cell_state tensors.
+    node->temporaries->data[kInputQuantized] =
+        *scratch_tensor_index + kInputQuantized;
+    TfLiteTensor* input_quantized =
+        GetTemporary(context, node, kInputQuantized);
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, input_quantized_size));
+    }
+    node->temporaries->data[kOutputStateQuantized] =
+        *scratch_tensor_index + kOutputStateQuantized;
+    TfLiteTensor* output_state_quantized =
+        GetTemporary(context, node, kOutputStateQuantized);
+    output_state_quantized->type = kTfLiteUInt8;
+    output_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(output_state_quantized->dims,
+                             output_state->dims)) {
+      TfLiteIntArray* output_state_quantized_size =
+          TfLiteIntArrayCopy(output_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, output_state_quantized,
+                                              output_state_quantized_size));
+    }
+    node->temporaries->data[kCellStateQuantized] =
+        *scratch_tensor_index + kCellStateQuantized;
+    TfLiteTensor* cell_state_quantized =
+        GetTemporary(context, node, kCellStateQuantized);
+    cell_state_quantized->type = kTfLiteUInt8;
+    cell_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+      TfLiteIntArray* cell_state_quantized_size =
+          TfLiteIntArrayCopy(cell_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, cell_state_quantized,
+                                              cell_state_quantized_size));
+    }
+
+    // Allocate temporary tensors to store scaling factors and product scaling
+    // factors. The latter is a convenience storage which allows to quantize
+    // a vector once (which produces the scaling factors) and multiply it with
+    // different matrices (which requires multiplying the scaling factors with
+    // the scaling factor of the matrix).
+    node->temporaries->data[kScalingFactors] =
+        *scratch_tensor_index + kScalingFactors;
+    TfLiteTensor* scaling_factors =
+        GetTemporary(context, node, kScalingFactors);
+    scaling_factors->type = kTfLiteFloat32;
+    scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+    scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+                                                       scaling_factors_size));
+    }
+    node->temporaries->data[kProductScalingFactors] =
+        *scratch_tensor_index + kProductScalingFactors;
+    TfLiteTensor* prod_scaling_factors =
+        GetTemporary(context, node, kProductScalingFactors);
+    prod_scaling_factors->type = kTfLiteFloat32;
+    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+    prod_scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+                             prod_scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, prod_scaling_factors,
+                                              prod_scaling_factors_size));
+    }
+
+    // Allocate a temporary tensor to store the recovered cell weights. Since
+    // this is used for diagonal matrices, only need to store n_cell values.
+    node->temporaries->data[kRecoveredCellWeights] =
+        *scratch_tensor_index + kRecoveredCellWeights;
+    TfLiteTensor* recovered_cell_weights =
+        GetTemporary(context, node, kRecoveredCellWeights);
+    recovered_cell_weights->type = kTfLiteFloat32;
+    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+    recovered_cell_weights_size->data[0] = n_cell;
+    if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+                             recovered_cell_weights_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, recovered_cell_weights,
+                                              recovered_cell_weights_size));
+    }
   }
   return kTfLiteOk;
 }
 
 // The LSTM Op engine.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-
-  const TfLiteTensor* input_to_input_weights =
-      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
-  const TfLiteTensor* input_to_forget_weights =
-      GetInput(context, node, kInputToForgetWeightsTensor);
-  const TfLiteTensor* input_to_cell_weights =
-      GetInput(context, node, kInputToCellWeightsTensor);
-  const TfLiteTensor* input_to_output_weights =
-      GetInput(context, node, kInputToOutputWeightsTensor);
-
-  const TfLiteTensor* recurrent_to_input_weights =
-      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
-  const TfLiteTensor* recurrent_to_forget_weights =
-      GetInput(context, node, kRecurrentToForgetWeightsTensor);
-  const TfLiteTensor* recurrent_to_cell_weights =
-      GetInput(context, node, kRecurrentToCellWeightsTensor);
-  const TfLiteTensor* recurrent_to_output_weights =
-      GetInput(context, node, kRecurrentToOutputWeightsTensor);
-
-  const TfLiteTensor* cell_to_input_weights =
-      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
-  const TfLiteTensor* cell_to_forget_weights =
-      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
-  const TfLiteTensor* cell_to_output_weights =
-      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
-
-  const TfLiteTensor* input_gate_bias =
-      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
-  const TfLiteTensor* forget_gate_bias =
-      GetInput(context, node, kForgetGateBiasTensor);
-  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
-  const TfLiteTensor* output_gate_bias =
-      GetInput(context, node, kOutputGateBiasTensor);
-
-  const TfLiteTensor* projection_weights =
-      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
-  const TfLiteTensor* projection_bias =
-      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
-
-  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
-  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
   const int max_time = input->dims->data[0];
   const int n_batch = input->dims->data[1];
   const int n_input = input->dims->data[2];
@@ -380,8 +463,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const bool use_cifg = (input_to_input_weights == nullptr);
   const bool use_peephole = (cell_to_output_weights != nullptr);
 
-  // Index the scratch buffers pointers to the global scratch buffer.
-  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
   float* input_gate_scratch = nullptr;
   float* cell_scratch = nullptr;
   float* forget_gate_scratch = nullptr;
@@ -432,6 +513,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   float* output_state_ptr = output_state->data.f;
   float* cell_state_ptr = cell_state->data.f;
 
+  // Feed the sequence into the LSTM step-by-step.
   for (int t = 0; t < max_time; t++) {
     const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
     float* output_ptr_batch = output->data.f + t * n_batch * n_output;
@@ -452,6 +534,262 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
+  const int max_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+  }
+
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+  float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* output_state_ptr = output_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_output_state_ptr =
+      reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+  // Feed the sequence into the LSTM step-by-step.
+  for (int t = 0; t < max_time; t++) {
+    const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
+    float* output_ptr_batch = output->data.f + t * n_batch * n_output;
+
+    kernel_utils::LstmStep(
+        input_ptr_batch, input_to_input_weights_ptr,
+        input_to_input_weights_scale, input_to_forget_weights_ptr,
+        input_to_forget_weights_scale, input_to_cell_weights_ptr,
+        input_to_cell_weights_scale, input_to_output_weights_ptr,
+        input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+        n_input, n_output, input_gate_scratch, forget_gate_scratch,
+        cell_scratch, output_gate_scratch, scaling_factors_ptr,
+        prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+        quantized_input_ptr, quantized_output_state_ptr,
+        quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+        output_ptr_batch);
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+  const TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights =
+      GetInput(context, node, kInputToForgetWeightsTensor);
+  const TfLiteTensor* input_to_cell_weights =
+      GetInput(context, node, kInputToCellWeightsTensor);
+  const TfLiteTensor* input_to_output_weights =
+      GetInput(context, node, kInputToOutputWeightsTensor);
+
+  const TfLiteTensor* recurrent_to_input_weights =
+      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights =
+      GetInput(context, node, kRecurrentToForgetWeightsTensor);
+  const TfLiteTensor* recurrent_to_cell_weights =
+      GetInput(context, node, kRecurrentToCellWeightsTensor);
+  const TfLiteTensor* recurrent_to_output_weights =
+      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+  const TfLiteTensor* cell_to_input_weights =
+      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+  const TfLiteTensor* cell_to_forget_weights =
+      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+  const TfLiteTensor* cell_to_output_weights =
+      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+  const TfLiteTensor* input_gate_bias =
+      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+  const TfLiteTensor* forget_gate_bias =
+      GetInput(context, node, kForgetGateBiasTensor);
+  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+  const TfLiteTensor* output_gate_bias =
+      GetInput(context, node, kOutputGateBiasTensor);
+
+  const TfLiteTensor* projection_weights =
+      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+  const TfLiteTensor* projection_bias =
+      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+  // Index the scratch buffers pointers to the global scratch buffer.
+  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input_to_output_weights->type) {
+    case kTfLiteFloat32: {
+      return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+                       input_to_cell_weights, input_to_output_weights,
+                       recurrent_to_input_weights, recurrent_to_forget_weights,
+                       recurrent_to_cell_weights, recurrent_to_output_weights,
+                       cell_to_input_weights, cell_to_forget_weights,
+                       cell_to_output_weights, input_gate_bias,
+                       forget_gate_bias, cell_bias, output_gate_bias,
+                       projection_weights, projection_bias, params,
+                       scratch_buffer, output_state, cell_state, output);
+    }
+    case kTfLiteUInt8: {
+      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+      TfLiteTensor* output_state_quantized =
+          GetTemporary(context, node, /*index=*/2);
+      TfLiteTensor* cell_state_quantized =
+          GetTemporary(context, node, /*index=*/3);
+      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+      TfLiteTensor* prod_scaling_factors =
+          GetTemporary(context, node, /*index=*/5);
+      TfLiteTensor* recovered_cell_weights =
+          GetTemporary(context, node, /*index=*/6);
+      return EvalHybrid(
+          input, input_to_input_weights, input_to_forget_weights,
+          input_to_cell_weights, input_to_output_weights,
+          recurrent_to_input_weights, recurrent_to_forget_weights,
+          recurrent_to_cell_weights, recurrent_to_output_weights,
+          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+          input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
+          projection_weights, projection_bias, params, scratch_buffer,
+          scaling_factors, prod_scaling_factors, recovered_cell_weights,
+          input_quantized, output_state_quantized, cell_state_quantized,
+          output_state, cell_state, output);
+    }
+    default:
+      context->ReportError(context, "Type %d is not currently supported.",
+                           input_to_output_weights->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
 }  // namespace unidirectional_sequence_lstm
 
 TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
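Note on the kScalingFactors / kProductScalingFactors temporaries allocated in Prepare() above: in the hybrid path each float input vector is quantized symmetrically once per step, producing one scaling factor, and that factor is then multiplied with each weight matrix's own scale so the int8 accumulations can be recovered as floats. The following standalone sketch only illustrates that idea; the helpers QuantizeSymmetric and HybridMatVec are invented names for this note and are not the tensor_utils routines that kernel_utils::LstmStep actually calls.

// Standalone sketch (not the TFLite kernel): symmetric quantization of one
// activation vector plus a hybrid int8-weight * int8-input product.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Quantize a float vector to int8 with a single per-vector scaling factor.
float QuantizeSymmetric(const std::vector<float>& v, std::vector<int8_t>* q) {
  float max_abs = 0.0f;
  for (float x : v) max_abs = std::max(max_abs, std::fabs(x));
  const float scale = max_abs / 127.0f;  // Maps the value range to [-127, 127].
  q->resize(v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    (*q)[i] = static_cast<int8_t>(
        std::round(v[i] / (scale > 0.0f ? scale : 1.0f)));
  }
  return scale;
}

// Hybrid mat*vec: int8 weights (with weight_scale) times the quantized input;
// the product scaling factor input_scale * weight_scale recovers float output.
void HybridMatVec(const std::vector<int8_t>& weights, int rows, int cols,
                  const std::vector<int8_t>& input, float input_scale,
                  float weight_scale, std::vector<float>* out) {
  const float prod_scaling_factor = input_scale * weight_scale;
  out->assign(rows, 0.0f);
  for (int r = 0; r < rows; ++r) {
    int32_t acc = 0;
    for (int c = 0; c < cols; ++c) acc += weights[r * cols + c] * input[c];
    (*out)[r] = acc * prod_scaling_factor;
  }
}

int main() {
  const std::vector<float> input = {0.5f, -1.25f, 2.0f};
  std::vector<int8_t> quantized_input;
  const float input_scale = QuantizeSymmetric(input, &quantized_input);

  // A 2x3 weight matrix assumed to be quantized offline with scale 0.02.
  const std::vector<int8_t> weights = {10, -20, 30, -5, 15, 25};
  std::vector<float> output;
  HybridMatVec(weights, /*rows=*/2, /*cols=*/3, quantized_input, input_scale,
               /*weight_scale=*/0.02f, &output);
  std::printf("output = [%f, %f]\n", output[0], output[1]);
  return 0;
}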
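Similarly, the gate-scratch pointers set up at the top of EvalFloat and EvalHybrid carve one [n_batch x n_gates * n_cell] scratch tensor into per-gate regions; with CIFG the input gate is derived from the forget gate, so only three regions are reserved (the n_cell * 3 columns sized in Prepare()). A small self-contained sketch of that layout, with arbitrary sizes chosen here and not taken from the kernel:

// Standalone sketch of the scratch-buffer layout; the pointer arithmetic
// mirrors the kernel code above, the sizes are arbitrary.
#include <cstdio>
#include <vector>

int main() {
  const int n_batch = 2;
  const int n_cell = 4;
  const bool use_cifg = false;           // True when input-to-input weights are absent.
  const int n_gates = use_cifg ? 3 : 4;  // Cell/Forget/Output (+ Input).

  // One arena tensor of shape [n_batch, n_gates * n_cell], as in Prepare().
  std::vector<float> scratch(n_batch * n_cell * n_gates, 0.0f);
  float* base = scratch.data();

  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = base;
    forget_gate_scratch = base + n_cell * n_batch;
    output_gate_scratch = base + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = base;
    cell_scratch = base + n_cell * n_batch;
    forget_gate_scratch = base + 2 * n_cell * n_batch;
    output_gate_scratch = base + 3 * n_cell * n_batch;
  }

  std::printf("cell offset=%d forget offset=%d output offset=%d\n",
              static_cast<int>(cell_scratch - base),
              static_cast<int>(forget_gate_scratch - base),
              static_cast<int>(output_gate_scratch - base));
  (void)input_gate_scratch;  // Not printed; kept to mirror the kernel's variables.
  return 0;
}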