diff options
author | 2018-02-28 13:02:07 -0800 | |
---|---|---|
committer | 2018-02-28 13:08:35 -0800 | |
commit | 69f674b473470b44c6a1ca1bbb3bcc6a8c53074b (patch) | |
tree | 793e26362b2bf184ed07b8c654b3047bf6c6be95 /tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | |
parent | 757a71e886fb9328b19b0ba15658e49cfa7cc323 (diff) |
Factor out the LstmBatchStep for the various LSTM Ops.
PiperOrigin-RevId: 187370622
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 183 |
1 files changed, 11 insertions, 172 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 8d70df5e21..a64ac42bc4 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -443,166 +444,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// Performs an LSTM batch inference step for input specified by input_ptr_batch. -// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and -// biases (*_bias_ptr), and buffers (*_scratch), along with additional -// parameters: -// - params: various LSTM params including activation, clipping, etc., -// - use_cifg: use coupled input forget gates, -// - use_peephole: whether to use peephole connection or not, -// - n_batch: size of batch, -// - n_cell: number of cells (or units), -// - n_input: the input size, -// - n_output: the output size. -// -// The pointers to the hidden state and the output are updated as a result. -// -// The pointers with the suffix "_batch" point to data aligned in batch_major -// order, and each step processes batch_size many inputs from input_ptr_batch, -// and updates batch_size many outputs and hidden states. -void LstmBatchStep( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - bool use_cifg, bool use_peephole, int n_batch, int n_cell, int n_input, - int n_output, float* output_state_ptr, float* cell_state_ptr, - float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, - float* output_gate_scratch, float* output_ptr_time) { - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, - /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, forget_gate_scratch, - /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, output_gate_scratch, - /*result_stride=*/1); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, - n_batch * n_cell, cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, - params->cell_clip, cell_state_ptr); - } - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_time); - } else { - tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, - output_ptr_time, /*result_stride=*/1); - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_time, n_batch * n_output, - params->proj_clip, output_ptr_time); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_time); - } - tensor_utils::CopyVector(output_ptr_time, n_batch * n_output, - output_state_ptr); -} - // The LSTM Op engine. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); @@ -756,7 +597,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const float* input_ptr_batch = input->data.f + t * n_batch * n_input; float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output; - LstmBatchStep( + kernel_utils::LstmStep( input_ptr_batch, fw_input_to_input_weights_ptr, fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f, fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr, @@ -766,11 +607,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr, fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f, fw_cell_bias->data.f, fw_output_gate_bias->data.f, - fw_projection_weights_ptr, fw_projection_bias_ptr, params, fw_use_cifg, - fw_use_peephole, n_batch, n_fw_cell, n_input, n_fw_output, - fw_output_state->data.f, fw_cell_state->data.f, fw_input_gate_scratch, - fw_forget_gate_scratch, fw_cell_scratch, fw_output_gate_scratch, - output_ptr_time); + fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch, + n_fw_cell, n_input, n_fw_output, fw_output_state->data.f, + fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch, + fw_cell_scratch, fw_output_gate_scratch, output_ptr_time); } // n_cell and n_output will be the same size when there is no projection. @@ -828,7 +668,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const float* input_ptr_batch = input->data.f + t * n_batch * n_input; float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output; - LstmBatchStep( + kernel_utils::LstmStep( input_ptr_batch, bw_input_to_input_weights_ptr, bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f, bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr, @@ -838,11 +678,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr, bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f, bw_cell_bias->data.f, bw_output_gate_bias->data.f, - bw_projection_weights_ptr, bw_projection_bias_ptr, params, bw_use_cifg, - bw_use_peephole, n_batch, n_bw_cell, n_input, n_bw_output, - bw_output_state->data.f, bw_cell_state->data.f, bw_input_gate_scratch, - bw_forget_gate_scratch, bw_cell_scratch, bw_output_gate_scratch, - output_ptr_time); + bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch, + n_bw_cell, n_input, n_bw_output, bw_output_state->data.f, + bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch, + bw_cell_scratch, bw_output_gate_scratch, output_ptr_time); } // Backward step. |