author    | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-04 17:04:02 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 17:08:20 -0700
commit    | 964c1dfcc9e55fbaf9e31efd310385b6fe2563d7 (patch)
tree      | c97516bc5142fe2a5700d4fb7bcc83ffb9c34d50 /tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
parent    | a2e3dcdb4f439f05592b3e4698cb25a28d85a3b7 (diff)
Add support for quantized (hybrid) bidirectional sequence LSTM Op.
PiperOrigin-RevId: 211552101
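Background on the scheme this commit adds: "hybrid" means the LSTM weights are stored as 8-bit quantized values with a float scale, while activations stay float. At runtime each batch of float inputs is symmetrically quantized, the matrix products run in integer arithmetic, and the accumulator is rescaled to float by the product of the input and weight scaling factors. The sketch below is only an illustration of that idea under those assumptions; the helper names are invented for the example and are not the kernel's API (the real work happens inside `kernel_utils::LstmStep`).

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Symmetric quantization: map [-max_abs, max_abs] onto [-127, 127] and
// return the scaling factor needed to recover float values.
float SymmetricQuantize(const std::vector<float>& values,
                        std::vector<int8_t>* quantized) {
  float max_abs = 0.f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  const float scale = (max_abs == 0.f) ? 1.f : max_abs / 127.f;
  quantized->resize(values.size());
  for (size_t i = 0; i < values.size(); ++i) {
    (*quantized)[i] = static_cast<int8_t>(std::round(values[i] / scale));
  }
  return scale;
}

// Hybrid dot product: integer multiply-accumulate, then a single float
// rescale by the product of the two scaling factors.
float HybridDot(const std::vector<int8_t>& a, float a_scale,
                const std::vector<int8_t>& b, float b_scale) {
  int32_t acc = 0;
  for (size_t i = 0; i < a.size(); ++i) acc += a[i] * b[i];
  return acc * a_scale * b_scale;
}
```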
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 699
1 file changed, 546 insertions, 153 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index af47b33922..cde4f55a16 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -108,9 +108,26 @@ constexpr int kBwInputCellStateTensor = 38;
 constexpr int kFwOutputTensor = 0;
 constexpr int kBwOutputTensor = 1;
 
+// Temporary tensors.
+enum TemporaryTensor {
+  // Scratch buffers for input, forget, etc. gates
+  kFwScratchBuffer = 0,
+  kBwScratchBuffer = 1,
+  // Quantized tensors needed for the hybrid kernel.
+  kInputQuantized = 2,
+  kFwActivationStateQuantized = 3,
+  kBwActivationStateQuantized = 4,
+  kFwCellStateQuantized = 5,
+  kBwCellStateQuantized = 6,
+  kScalingFactors = 7,
+  kProductScalingFactors = 8,
+  kRecoveredCellWeights = 9,
+  kNumTemporaryTensors = 10
+};
+
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* scratch_tensor_index = new int;
-  context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
   return scratch_tensor_index;
 }
 
@@ -131,7 +148,7 @@ TfLiteStatus CheckLstmTensorDimensions(
     int input_gate_bias_tensor, int forget_gate_bias_tensor,
     int cell_gate_bias_tensor, int output_gate_bias_tensor,
     int projection_weights_tensor, int projection_bias_tensor) {
-  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   // Making sure clipping parameters have valid values.
   // == 0 means no clipping
@@ -324,7 +341,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Inferring batch size, number of outputs and sequence length and
   // number of cells from the input tensors.
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input->dims->size > 1);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
   const int max_time = input->dims->data[0];
   const int n_batch = input->dims->data[1];
   const int n_input = input->dims->data[2];
@@ -370,11 +388,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, fw_output, fw_output_size));
 
-  // Create a scratch buffer tensor.
+  // The weights are of consistent type, so it suffices to check one.
+  const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8);
+
   TfLiteIntArrayFree(node->temporaries);
-  node->temporaries = TfLiteIntArrayCreate(2);
-  node->temporaries->data[0] = *scratch_tensor_index;
-  TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0);
+  if (is_hybrid_op) {
+    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+  } else {
+    node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
+  }
+  // Create a scratch buffer tensor.
+  node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
+  TfLiteTensor* fw_scratch_buffer =
+      GetTemporary(context, node, kFwScratchBuffer);
   fw_scratch_buffer->type = input->type;
   fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -435,8 +461,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
 
   // Create a scratch buffer tensor.
-  node->temporaries->data[1] = *(scratch_tensor_index) + 1;
-  TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1);
+  node->temporaries->data[kBwScratchBuffer] =
+      *(scratch_tensor_index) + kBwScratchBuffer;
+  TfLiteTensor* bw_scratch_buffer =
+      GetTemporary(context, node, kBwScratchBuffer);
   bw_scratch_buffer->type = input->type;
   bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -454,18 +482,441 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                    bw_scratch_buffer_size));
+  if (is_hybrid_op) {
+    // Allocate temporary tensors to store quantized values of input,
+    // output_state and cell_state tensors.
+    node->temporaries->data[kInputQuantized] =
+        *scratch_tensor_index + kInputQuantized;
+    TfLiteTensor* input_quantized =
+        GetTemporary(context, node, kInputQuantized);
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+
+    node->temporaries->data[kFwActivationStateQuantized] =
+        *scratch_tensor_index + kFwActivationStateQuantized;
+    TfLiteTensor* fw_activation_state_quantized =
+        GetTemporary(context, node, kFwActivationStateQuantized);
+    fw_activation_state_quantized->type = kTfLiteUInt8;
+    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
+                             fw_activation_state->dims)) {
+      TfLiteIntArray* fw_activation_state_quantized_size =
+          TfLiteIntArrayCopy(fw_activation_state->dims);
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, fw_activation_state_quantized,
+                                         fw_activation_state_quantized_size));
+    }
+    node->temporaries->data[kBwActivationStateQuantized] =
+        *scratch_tensor_index + kBwActivationStateQuantized;
+    TfLiteTensor* bw_activation_state_quantized =
+        GetTemporary(context, node, kBwActivationStateQuantized);
+    bw_activation_state_quantized->type = kTfLiteUInt8;
+    bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
+                             bw_activation_state->dims)) {
+      TfLiteIntArray* bw_activation_state_quantized_size =
+          TfLiteIntArrayCopy(bw_activation_state->dims);
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, bw_activation_state_quantized,
+                                         bw_activation_state_quantized_size));
+    }
+    node->temporaries->data[kFwCellStateQuantized] =
+        *scratch_tensor_index + kFwCellStateQuantized;
+    TfLiteTensor* fw_cell_state_quantized =
+        GetTemporary(context, node, kFwCellStateQuantized);
+    fw_cell_state_quantized->type = kTfLiteUInt8;
+    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
+                             fw_cell_state->dims)) {
+      TfLiteIntArray* fw_cell_state_quantized_size =
+          TfLiteIntArrayCopy(fw_cell_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, fw_cell_state_quantized,
+                                              fw_cell_state_quantized_size));
+    }
+    node->temporaries->data[kBwCellStateQuantized] =
+        *scratch_tensor_index + kBwCellStateQuantized;
+    TfLiteTensor* bw_cell_state_quantized =
+        GetTemporary(context, node, kBwCellStateQuantized);
+    bw_cell_state_quantized->type = kTfLiteUInt8;
+    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
+                             bw_cell_state->dims)) {
+      TfLiteIntArray* bw_cell_state_quantized_size =
+          TfLiteIntArrayCopy(bw_cell_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, bw_cell_state_quantized,
+                                              bw_cell_state_quantized_size));
+    }
+
+    // Allocate temporary tensors to store scaling factors and product scaling
+    // factors. The latter is a convenience storage which allows to quantize
+    // a vector once (which produces the scaling factors) and multiply it with
+    // different matrices (which requires multiplying the scaling factors with
+    // the scaling factor of the matrix).
+    node->temporaries->data[kScalingFactors] =
+        *scratch_tensor_index + kScalingFactors;
+    TfLiteTensor* scaling_factors =
+        GetTemporary(context, node, kScalingFactors);
+    scaling_factors->type = kTfLiteFloat32;
+    scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+    scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+                                                       scaling_factors_size));
+    }
+    node->temporaries->data[kProductScalingFactors] =
+        *scratch_tensor_index + kProductScalingFactors;
+    TfLiteTensor* prod_scaling_factors =
+        GetTemporary(context, node, kProductScalingFactors);
+    prod_scaling_factors->type = kTfLiteFloat32;
+    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+    prod_scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+                             prod_scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, prod_scaling_factors,
+                                              prod_scaling_factors_size));
+    }
+
+    // Allocate a temporary tensor to store the recovered cell weights. Since
+    // this is used for diagonal matrices, only need to store n_cell values.
+    node->temporaries->data[kRecoveredCellWeights] =
+        *scratch_tensor_index + kRecoveredCellWeights;
+    TfLiteTensor* recovered_cell_weights =
+        GetTemporary(context, node, kRecoveredCellWeights);
+    recovered_cell_weights->type = kTfLiteFloat32;
+    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+    recovered_cell_weights_size->data[0] = n_fw_cell;
+    if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+                             recovered_cell_weights_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, recovered_cell_weights,
+                                              recovered_cell_weights_size));
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  const int max_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existense of only one to the get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Index the scratch buffers pointers to the global scratch buffer.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  const float* input_to_input_weights_ptr =
+      (use_cifg) ? nullptr : input_to_input_weights->data.f;
+  const float* recurrent_to_input_weights_ptr =
+      (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+  const float* input_gate_bias_ptr =
+      (use_cifg) ? nullptr : input_gate_bias->data.f;
+  const float* cell_to_input_weights_ptr =
+      (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+  const float* cell_to_forget_weights_ptr =
+      (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+  const float* cell_to_output_weights_ptr =
+      (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+  const float* projection_weights_ptr =
+      (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Loop through the sequence.
+  if (forward_sequence) {
+    for (int t = 0; t < max_time; t++) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStep(
+          input_ptr, input_to_input_weights_ptr,
+          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+          input_to_output_weights->data.f, recurrent_to_input_weights_ptr,
+          recurrent_to_forget_weights->data.f,
+          recurrent_to_cell_weights->data.f,
+          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, n_output, activation_state->data.f,
+          cell_state->data.f, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, output_ptr_time);
+    }
+  } else {
+    // Loop through the sequence backwards.
+    for (int t = max_time - 1; t >= 0; t--) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStep(
+          input_ptr, input_to_input_weights_ptr,
+          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+          input_to_output_weights->data.f, recurrent_to_input_weights_ptr,
+          recurrent_to_forget_weights->data.f,
+          recurrent_to_cell_weights->data.f,
+          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, n_output, activation_state->data.f,
+          cell_state->data.f, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, output_ptr_time);
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* output_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  const int max_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+  }
+
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+  const float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* output_state_ptr = output_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_output_state_ptr =
+      reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+  if (forward_sequence) {
+    // Feed the sequence into the LSTM step-by-step.
+    for (int t = 0; t < max_time; t++) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStep(
+          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+          input_to_forget_weights_ptr, input_to_forget_weights_scale,
+          input_to_cell_weights_ptr, input_to_cell_weights_scale,
+          input_to_output_weights_ptr, input_to_output_weights_scale,
+          recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+          recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+          recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+          recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+          cell_to_input_weights_ptr, cell_to_input_weights_scale,
+          cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+          cell_to_output_weights_ptr, cell_to_output_weights_scale,
+          input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+          output_gate_bias_ptr, projection_weights_ptr,
+          projection_weights_scale, projection_bias_ptr, params, n_batch,
+          n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, scaling_factors_ptr,
+          prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+          quantized_input_ptr, quantized_output_state_ptr,
+          quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+          output_ptr);
+    }
+  } else {
+    // Loop through the sequence backwards.
+    for (int t = max_time - 1; t >= 0; t--) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStep(
+          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+          input_to_forget_weights_ptr, input_to_forget_weights_scale,
+          input_to_cell_weights_ptr, input_to_cell_weights_scale,
+          input_to_output_weights_ptr, input_to_output_weights_scale,
+          recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+          recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+          recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+          recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+          cell_to_input_weights_ptr, cell_to_input_weights_scale,
+          cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+          cell_to_output_weights_ptr, cell_to_output_weights_scale,
+          input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+          output_gate_bias_ptr, projection_weights_ptr,
+          projection_weights_scale, projection_bias_ptr, params, n_batch,
+          n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch,
+          cell_scratch, output_gate_scratch, scaling_factors_ptr,
+          prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+          quantized_input_ptr, quantized_output_state_ptr,
+          quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+          output_ptr);
+    }
+  }
+  return kTfLiteOk;
 }
 
 // The LSTM Op engine.
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   // Input tensor.
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const int max_time = input->dims->data[0];
-  const int n_batch = input->dims->data[1];
-  const int n_input = input->dims->data[2];
 
   // Tensors for the forward cell.
   const TfLiteTensor* fw_input_to_input_weights =
@@ -559,149 +1010,91 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       GetVariableInput(context, node, kBwInputCellStateTensor);
   TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
 
-  // n_cell and n_output will be the same size when there is no projection.
-  const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
-  const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
-
-  // Since we have already checked that weights are all there or none, we can
-  // check the existense of only one to the get the condition.
-  const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
-  const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr);
-
-  // Index the scratch buffers pointers to the global scratch buffer.
   TfLiteTensor* fw_scratch_buffer =
-      &context->tensors[node->temporaries->data[0]];
-  float* fw_input_gate_scratch = nullptr;
-  float* fw_cell_scratch = nullptr;
-  float* fw_forget_gate_scratch = nullptr;
-  float* fw_output_gate_scratch = nullptr;
-  if (fw_use_cifg) {
-    fw_cell_scratch = fw_scratch_buffer->data.f;
-    fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
-    fw_output_gate_scratch =
-        fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
-  } else {
-    fw_input_gate_scratch = fw_scratch_buffer->data.f;
-    fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
-    fw_forget_gate_scratch =
-        fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
-    fw_output_gate_scratch =
-        fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch;
-  }
-
-  // Check optional tensors, the respective pointers can be null.
-  const float* fw_input_to_input_weights_ptr =
-      (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f;
-  const float* fw_recurrent_to_input_weights_ptr =
-      (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f;
-  const float* fw_input_gate_bias_ptr =
-      (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f;
-  const float* fw_cell_to_input_weights_ptr =
-      (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f
-                                        : nullptr;
-  const float* fw_cell_to_forget_weights_ptr =
-      (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr;
-  const float* fw_cell_to_output_weights_ptr =
-      (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr;
-  const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr)
-                                               ? nullptr
-                                               : fw_projection_weights->data.f;
-  const float* fw_projection_bias_ptr =
-      (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f;
-
-  // Loop through the sequence.
-  for (int t = 0; t < max_time; t++) {
-    const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
-    float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output;
-
-    kernel_utils::LstmStep(
-        input_ptr_batch, fw_input_to_input_weights_ptr,
-        fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f,
-        fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr,
-        fw_recurrent_to_forget_weights->data.f,
-        fw_recurrent_to_cell_weights->data.f,
-        fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr,
-        fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr,
-        fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
-        fw_cell_bias->data.f, fw_output_gate_bias->data.f,
-        fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
-        n_fw_cell, n_input, n_fw_output, fw_activation_state->data.f,
-        fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
-        fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
-  }
-
-  // n_cell and n_output will be the same size when there is no projection.
-  const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
-  const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
-
-  // Since we have already checked that weights are all there or none, we can
-  // check the existense of only one to the get the condition.
-  const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
-  const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr);
-
-  // Index the scratch buffers pointers to the global scratch buffer.
+      GetTemporary(context, node, kFwScratchBuffer);
   TfLiteTensor* bw_scratch_buffer =
-      &context->tensors[node->temporaries->data[1]];
-  float* bw_input_gate_scratch = nullptr;
-  float* bw_cell_scratch = nullptr;
-  float* bw_forget_gate_scratch = nullptr;
-  float* bw_output_gate_scratch = nullptr;
-  if (bw_use_cifg) {
-    bw_cell_scratch = bw_scratch_buffer->data.f;
-    bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
-    bw_output_gate_scratch =
-        bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
-  } else {
-    bw_input_gate_scratch = bw_scratch_buffer->data.f;
-    bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
-    bw_forget_gate_scratch =
-        bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
-    bw_output_gate_scratch =
-        bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch;
+      GetTemporary(context, node, kBwScratchBuffer);
+
+  switch (fw_input_to_output_weights->type) {
+    case kTfLiteFloat32: {
+      TfLiteStatus fw_pass_status = EvalFloat(
+          input, fw_input_to_input_weights, fw_input_to_forget_weights,
+          fw_input_to_cell_weights, fw_input_to_output_weights,
+          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+          fw_cell_to_input_weights, fw_cell_to_forget_weights,
+          fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias,
+          fw_cell_bias, fw_output_gate_bias, fw_projection_weights,
+          fw_projection_bias, params, /*forward_sequence=*/true,
+          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
+      TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+      TfLiteStatus bw_pass_status = EvalFloat(
+          input, bw_input_to_input_weights, bw_input_to_forget_weights,
+          bw_input_to_cell_weights, bw_input_to_output_weights,
+          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+          bw_cell_to_input_weights, bw_cell_to_forget_weights,
+          bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias,
+          bw_cell_bias, bw_output_gate_bias, bw_projection_weights,
+          bw_projection_bias, params, /*forward_sequence=*/false,
+          bw_scratch_buffer, bw_activation_state, bw_cell_state, bw_output);
+      TF_LITE_ENSURE_OK(context, bw_pass_status);
+      return kTfLiteOk;
+    }
+    case kTfLiteUInt8: {
+      TfLiteTensor* input_quantized =
+          GetTemporary(context, node, kInputQuantized);
+      TfLiteTensor* fw_activation_state_quantized =
+          GetTemporary(context, node, kFwActivationStateQuantized);
+      TfLiteTensor* bw_activation_state_quantized =
+          GetTemporary(context, node, kBwActivationStateQuantized);
+      TfLiteTensor* fw_cell_state_quantized =
+          GetTemporary(context, node, kFwCellStateQuantized);
+      TfLiteTensor* bw_cell_state_quantized =
+          GetTemporary(context, node, kBwCellStateQuantized);
+      TfLiteTensor* scaling_factors =
+          GetTemporary(context, node, kScalingFactors);
+      TfLiteTensor* prod_scaling_factors =
+          GetTemporary(context, node, kProductScalingFactors);
+      TfLiteTensor* recovered_cell_weights =
+          GetTemporary(context, node, kRecoveredCellWeights);
+      TfLiteStatus fw_pass_status = EvalHybrid(
+          input, fw_input_to_input_weights, fw_input_to_forget_weights,
+          fw_input_to_cell_weights, fw_input_to_output_weights,
+          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+          fw_cell_to_input_weights, fw_cell_to_forget_weights,
+          fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias,
+          fw_cell_bias, fw_output_gate_bias, fw_projection_weights,
+          fw_projection_bias, params, /*forward_sequence=*/true,
+          fw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized,
+          fw_activation_state_quantized, fw_cell_state_quantized,
+          fw_activation_state, fw_cell_state, fw_output);
+      TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+      TfLiteStatus bw_pass_status = EvalHybrid(
+          input, bw_input_to_input_weights, bw_input_to_forget_weights,
+          bw_input_to_cell_weights, bw_input_to_output_weights,
+          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+          bw_cell_to_input_weights, bw_cell_to_forget_weights,
+          bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias,
+          bw_cell_bias, bw_output_gate_bias, bw_projection_weights,
+          bw_projection_bias, params, /*forward_sequence=*/false,
+          bw_scratch_buffer, scaling_factors, prod_scaling_factors,
+          recovered_cell_weights, input_quantized,
+          bw_activation_state_quantized, bw_cell_state_quantized,
+          bw_activation_state, bw_cell_state, bw_output);
+      TF_LITE_ENSURE_OK(context, bw_pass_status);
+      return kTfLiteOk;
+    }
+    default:
+      context->ReportError(context, "Type %d is not currently supported.",
+                           fw_input_to_output_weights->type);
+      return kTfLiteError;
   }
-
-  // Check optional tensors, the respective pointers can be null.
-  const float* bw_input_to_input_weights_ptr =
-      (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f;
-  const float* bw_recurrent_to_input_weights_ptr =
-      (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f;
-  const float* bw_input_gate_bias_ptr =
-      (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f;
-  const float* bw_cell_to_input_weights_ptr =
-      (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f
-                                        : nullptr;
-  const float* bw_cell_to_forget_weights_ptr =
-      (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr;
-  const float* bw_cell_to_output_weights_ptr =
-      (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr;
-  const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr)
-                                               ? nullptr
-                                               : bw_projection_weights->data.f;
-  const float* bw_projection_bias_ptr =
-      (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f;
-
-  // Loop through the sequence backwards.
-  for (int t = max_time - 1; t >= 0; t--) {
-    const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
-    float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output;
-
-    kernel_utils::LstmStep(
-        input_ptr_batch, bw_input_to_input_weights_ptr,
-        bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f,
-        bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr,
-        bw_recurrent_to_forget_weights->data.f,
-        bw_recurrent_to_cell_weights->data.f,
-        bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr,
-        bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr,
-        bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
-        bw_cell_bias->data.f, bw_output_gate_bias->data.f,
-        bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
-        n_bw_cell, n_input, n_bw_output, bw_activation_state->data.f,
-        bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
-        bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
-  }
-
-  // Backward step.
   return kTfLiteOk;
 }
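A note on the kProductScalingFactors temporary allocated in Prepare: as the diff's comment explains, each batch row of the input is quantized once (producing per-batch scaling factors), and reusing the quantized vector against several weight matrices only requires multiplying those factors by each matrix's own scale. A minimal sketch of that idea follows, with invented names and a naive loop; the actual computation is performed inside kernel_utils::LstmStep.

```cpp
#include <cstdint>

// Accumulate result += (weights * quantized_input), rescaled to float.
// weights:  row-major [n_output x n_input], int8 with one float scale.
// quantized_input: [n_batch x n_input], quantized once per batch row with
// per-batch input_scaling_factors.
void AccumulateHybridMatmul(const int8_t* weights, float weight_scale,
                            const int8_t* quantized_input,
                            const float* input_scaling_factors, int n_batch,
                            int n_input, int n_output,
                            float* prod_scaling_factors, float* result) {
  // One multiply per batch row instead of re-quantizing the input for every
  // weight matrix it meets.
  for (int b = 0; b < n_batch; ++b) {
    prod_scaling_factors[b] = input_scaling_factors[b] * weight_scale;
  }
  for (int b = 0; b < n_batch; ++b) {
    for (int o = 0; o < n_output; ++o) {
      int32_t acc = 0;
      for (int i = 0; i < n_input; ++i) {
        acc += weights[o * n_input + i] * quantized_input[b * n_input + i];
      }
      result[b * n_output + o] += acc * prod_scaling_factors[b];
    }
  }
}
```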