diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-05 01:22:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 01:29:25 -0700 |
commit | 3b94d75a9e10ef8ef33760d0ef6aad326e1353ba (patch) | |
tree | 402934b406e63ccd9cff0faec8a83aba6d58abf3 /tensorflow/contrib/lite/kernels/lstm.cc | |
parent | 57d31aa599c83014397a22bbb8f1a27a33b0ade3 (diff) |
Merge the different LSTM EvalFloat/EvalHybrid calls into a single file.
PiperOrigin-RevId: 215870962
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/lstm.cc | 300 |
1 files changed, 28 insertions, 272 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 5b996d00bc..16d67a1a93 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/lstm_eval.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -424,263 +425,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// The LSTM Op engine. -TfLiteStatus EvalFloat( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* activation_state, TfLiteTensor* cell_state, - TfLiteTensor* output) { - const int n_batch = input->dims->data[0]; - const int n_input = input->dims->data[1]; - // n_cell and n_output will be the same size when there is no projection. - const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - const float* input_to_input_weights_ptr = - (use_cifg) ? nullptr : input_to_input_weights->data.f; - const float* recurrent_to_input_weights_ptr = - (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; - const float* input_gate_bias_ptr = - (use_cifg) ? nullptr : input_gate_bias->data.f; - const float* cell_to_input_weights_ptr = - (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; - const float* cell_to_forget_weights_ptr = - (use_peephole) ? cell_to_forget_weights->data.f : nullptr; - const float* cell_to_output_weights_ptr = - (use_peephole) ? cell_to_output_weights->data.f : nullptr; - const float* projection_weights_ptr = - (projection_weights == nullptr) ? nullptr : projection_weights->data.f; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - - // Required tensors, pointers are non-null. - const float* input_ptr_batch = input->data.f; - const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; - const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; - const float* input_to_output_weights_ptr = input_to_output_weights->data.f; - const float* recurrent_to_forget_weights_ptr = - recurrent_to_forget_weights->data.f; - const float* recurrent_to_cell_weights_ptr = - recurrent_to_cell_weights->data.f; - const float* recurrent_to_output_weights_ptr = - recurrent_to_output_weights->data.f; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - float* output_ptr_batch = output->data.f; - - kernel_utils::LstmStep( - input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, - input_to_cell_weights_ptr, input_to_output_weights_ptr, - recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, - recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, - cell_to_input_weights_ptr, cell_to_forget_weights_ptr, - cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, - cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - activation_state_ptr, cell_state_ptr, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); - - return kTfLiteOk; -} - -TfLiteStatus EvalHybrid( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, - TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, - TfLiteTensor* activation_state_quantized, - TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { - const int n_batch = input->dims->data[0]; - const int n_input = input->dims->data[1]; - // n_cell and n_output will be the same size when there is no projection. - const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - int8_t* input_to_input_weights_ptr = nullptr; - float input_to_input_weights_scale = 1.0f; - int8_t* recurrent_to_input_weights_ptr = nullptr; - float recurrent_to_input_weights_scale = 1.0f; - float* input_gate_bias_ptr = nullptr; - if (!use_cifg) { - input_to_input_weights_ptr = - reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8); - recurrent_to_input_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8); - input_gate_bias_ptr = input_gate_bias->data.f; - input_to_input_weights_scale = input_to_input_weights->params.scale; - recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; - } - - int8_t* cell_to_input_weights_ptr = nullptr; - int8_t* cell_to_forget_weights_ptr = nullptr; - int8_t* cell_to_output_weights_ptr = nullptr; - float cell_to_input_weights_scale = 1.0f; - float cell_to_forget_weights_scale = 1.0f; - float cell_to_output_weights_scale = 1.0f; - if (use_peephole) { - if (!use_cifg) { - cell_to_input_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8); - cell_to_input_weights_scale = cell_to_input_weights->params.scale; - } - cell_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8); - cell_to_output_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8); - cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; - cell_to_output_weights_scale = cell_to_output_weights->params.scale; - } - - const int8_t* projection_weights_ptr = - (projection_weights == nullptr) - ? nullptr - : reinterpret_cast<int8_t*>(projection_weights->data.uint8); - const float projection_weights_scale = - (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - - // Required tensors, pointers are non-null. - const float* input_ptr_batch = input->data.f; - const int8_t* input_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8); - const float input_to_forget_weights_scale = - input_to_forget_weights->params.scale; - const int8_t* input_to_cell_weights_ptr = - reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8); - const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; - const int8_t* input_to_output_weights_ptr = - reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8); - const float input_to_output_weights_scale = - input_to_output_weights->params.scale; - const int8_t* recurrent_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8); - const float recurrent_to_forget_weights_scale = - recurrent_to_forget_weights->params.scale; - const int8_t* recurrent_to_cell_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8); - const float recurrent_to_cell_weights_scale = - recurrent_to_cell_weights->params.scale; - const int8_t* recurrent_to_output_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8); - const float recurrent_to_output_weights_scale = - recurrent_to_output_weights->params.scale; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - float* output_ptr_batch = output->data.f; - - // Temporary storage for quantized values and scaling factors. - int8_t* quantized_input_ptr = - reinterpret_cast<int8_t*>(input_quantized->data.uint8); - int8_t* quantized_activation_state_ptr = - reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8); - int8_t* quantized_cell_state_ptr = - reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8); - float* scaling_factors_ptr = scaling_factors->data.f; - float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; - float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; - - kernel_utils::LstmStep( - input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, - recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, - recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, - recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, - cell_to_input_weights_ptr, cell_to_input_weights_scale, - cell_to_forget_weights_ptr, cell_to_forget_weights_scale, - cell_to_output_weights_ptr, cell_to_output_weights_scale, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, - recovered_cell_weights_ptr, quantized_input_ptr, - quantized_activation_state_ptr, quantized_cell_state_ptr, - activation_state_ptr, cell_state_ptr, output_ptr_batch); - - return kTfLiteOk; -} - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); OpData* op_data = reinterpret_cast<OpData*>(node->user_data); @@ -738,15 +482,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): add a check that weights are all uint8s or all floats. switch (input_to_output_weights->type) { case kTfLiteFloat32: { - return EvalFloat(input, input_to_input_weights, input_to_forget_weights, - input_to_cell_weights, input_to_output_weights, - recurrent_to_input_weights, recurrent_to_forget_weights, - recurrent_to_cell_weights, recurrent_to_output_weights, - cell_to_input_weights, cell_to_forget_weights, - cell_to_output_weights, input_gate_bias, - forget_gate_bias, cell_bias, output_gate_bias, - projection_weights, projection_bias, params, - scratch_buffer, activation_state, cell_state, output); + return lstm_eval::EvalFloat( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, projection_weights, + projection_bias, params, /*forward_sequence=*/true, + /*output_offset=*/0, scratch_buffer, activation_state, cell_state, + output); } case kTfLiteUInt8: { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); @@ -759,17 +509,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/5); TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, /*index=*/6); - return EvalHybrid( + return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, - input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, - projection_weights, projection_bias, params, scratch_buffer, - scaling_factors, prod_scaling_factors, recovered_cell_weights, - input_quantized, activation_state_quantized, cell_state_quantized, - activation_state, cell_state, output); + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, projection_weights, + projection_bias, params, /*forward_sequence=*/true, + /*output_offset=*/0, scratch_buffer, scaling_factors, + prod_scaling_factors, recovered_cell_weights, input_quantized, + /*aux_input_quantized=*/nullptr, activation_state_quantized, + cell_state_quantized, activation_state, cell_state, output); } default: context->ReportError(context, "Type %d is not currently supported.", |