aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-28 13:02:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-28 13:08:35 -0800
commit69f674b473470b44c6a1ca1bbb3bcc6a8c53074b (patch)
tree793e26362b2bf184ed07b8c654b3047bf6c6be95 /tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
parent757a71e886fb9328b19b0ba15658e49cfa7cc323 (diff)
Factor out the LstmBatchStep for the various LSTM Ops.
PiperOrigin-RevId: 187370622
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc183
1 files changed, 11 insertions, 172 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 8d70df5e21..a64ac42bc4 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -443,166 +444,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// Performs an LSTM batch inference step for input specified by input_ptr_batch.
-// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
-// biases (*_bias_ptr), and buffers (*_scratch), along with additional
-// parameters:
-// - params: various LSTM params including activation, clipping, etc.,
-// - use_cifg: use coupled input forget gates,
-// - use_peephole: whether to use peephole connection or not,
-// - n_batch: size of batch,
-// - n_cell: number of cells (or units),
-// - n_input: the input size,
-// - n_output: the output size.
-//
-// The pointers to the hidden state and the output are updated as a result.
-//
-// The pointers with the suffix "_batch" point to data aligned in batch_major
-// order, and each step processes batch_size many inputs from input_ptr_batch,
-// and updates batch_size many outputs and hidden states.
-void LstmBatchStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- bool use_cifg, bool use_peephole, int n_batch, int n_cell, int n_input,
- int n_output, float* output_state_ptr, float* cell_state_ptr,
- float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
- float* output_gate_scratch, float* output_ptr_time) {
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch,
- /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, forget_gate_scratch,
- /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, output_gate_scratch,
- /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_time);
- } else {
- tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
- output_ptr_time, /*result_stride=*/1);
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_time, n_batch * n_output,
- params->proj_clip, output_ptr_time);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_time);
- }
- tensor_utils::CopyVector(output_ptr_time, n_batch * n_output,
- output_state_ptr);
-}
-
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
@@ -756,7 +597,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output;
- LstmBatchStep(
+ kernel_utils::LstmStep(
input_ptr_batch, fw_input_to_input_weights_ptr,
fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f,
fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr,
@@ -766,11 +607,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr,
fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
fw_cell_bias->data.f, fw_output_gate_bias->data.f,
- fw_projection_weights_ptr, fw_projection_bias_ptr, params, fw_use_cifg,
- fw_use_peephole, n_batch, n_fw_cell, n_input, n_fw_output,
- fw_output_state->data.f, fw_cell_state->data.f, fw_input_gate_scratch,
- fw_forget_gate_scratch, fw_cell_scratch, fw_output_gate_scratch,
- output_ptr_time);
+ fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
+ n_fw_cell, n_input, n_fw_output, fw_output_state->data.f,
+ fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
+ fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
}
// n_cell and n_output will be the same size when there is no projection.
@@ -828,7 +668,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output;
- LstmBatchStep(
+ kernel_utils::LstmStep(
input_ptr_batch, bw_input_to_input_weights_ptr,
bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f,
bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr,
@@ -838,11 +678,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr,
bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
bw_cell_bias->data.f, bw_output_gate_bias->data.f,
- bw_projection_weights_ptr, bw_projection_bias_ptr, params, bw_use_cifg,
- bw_use_peephole, n_batch, n_bw_cell, n_input, n_bw_output,
- bw_output_state->data.f, bw_cell_state->data.f, bw_input_gate_scratch,
- bw_forget_gate_scratch, bw_cell_scratch, bw_output_gate_scratch,
- output_ptr_time);
+ bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
+ n_bw_cell, n_input, n_bw_output, bw_output_state->data.f,
+ bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
+ bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
}
// Backward step.