path: root/tensorflow/contrib/lite/kernels/lstm.cc
author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-02-28 13:02:07 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-02-28 13:08:35 -0800
commit     69f674b473470b44c6a1ca1bbb3bcc6a8c53074b (patch)
tree       793e26362b2bf184ed07b8c654b3047bf6c6be95 /tensorflow/contrib/lite/kernels/lstm.cc
parent     757a71e886fb9328b19b0ba15658e49cfa7cc323 (diff)
Factor out the LSTM batch step (kernel_utils::LstmStep) so it can be shared by the various LSTM ops.
PiperOrigin-RevId: 187370622
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc  | 170
1 file changed, 49 insertions(+), 121 deletions(-)
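
The hunk below replaces the inline per-gate math with a single call into the new kernel_utils helper. The prototype itself is not part of this hunk (it lives in the newly included kernel_utils.h), but the call site implies a declaration roughly like the sketch below; the parameter types (const float* for weights, biases and activations, plain ints for the dimensions, TfLiteLSTMParams* for the fused options) are inferred from how lstm.cc builds its arguments, not copied from the header.

    // Sketch only: inferred from the call site in the diff below. The
    // authoritative declaration is in
    // tensorflow/contrib/lite/kernels/internal/kernel_utils.h.
    #include "tensorflow/contrib/lite/builtin_op_data.h"  // TfLiteLSTMParams

    namespace tflite {
    namespace kernel_utils {

    // Performs one LSTM batch step: gate matmuls, activations, cell-state
    // update, optional peephole/CIFG/projection handling, and writes the new
    // output_state, cell_state and output. Optional weight/bias pointers may
    // be null (CIFG, peephole and projection variants).
    void LstmStep(
        const float* input_ptr_batch, const float* input_to_input_weights_ptr,
        const float* input_to_forget_weights_ptr,
        const float* input_to_cell_weights_ptr,
        const float* input_to_output_weights_ptr,
        const float* recurrent_to_input_weights_ptr,
        const float* recurrent_to_forget_weights_ptr,
        const float* recurrent_to_cell_weights_ptr,
        const float* recurrent_to_output_weights_ptr,
        const float* cell_to_input_weights_ptr,
        const float* cell_to_forget_weights_ptr,
        const float* cell_to_output_weights_ptr,
        const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
        const float* cell_bias_ptr, const float* output_gate_bias_ptr,
        const float* projection_weights_ptr, const float* projection_bias_ptr,
        TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
        int n_output, float* output_state_ptr, float* cell_state_ptr,
        float* input_gate_scratch, float* forget_gate_scratch,
        float* cell_scratch, float* output_gate_scratch,
        float* output_ptr_batch);

    }  // namespace kernel_utils
    }  // namespace tflite
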
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 6c06264d84..b9255b23a5 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -377,127 +378,54 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
}
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
- n_batch, input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
- n_batch, forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
- n_batch, output_gate_scratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights->data.f, n_cell, n_output,
- output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights->data.f, n_cell, n_output,
- output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights->data.f, n_cell, n_output,
- output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
- cell_state->data.f, n_batch * n_cell,
- cell_state->data.f);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell,
- cell_state->data.f);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
- params->cell_clip, cell_state->data.f);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights != nullptr);
- const bool use_projection_bias = (projection_bias != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
- n_batch, output->data.f);
- } else {
- tensor_utils::ZeroVector(output->data.f, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights->data.f, n_output, n_cell, output_gate_scratch,
- n_batch, output->data.f, /*result_stride=*/1);
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output->data.f, n_batch * n_output,
- params->proj_clip, output->data.f);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output->data.f);
- }
- tensor_utils::CopyVector(output->data.f, n_batch * n_output,
- output_state->data.f);
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
+ const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
+ const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
+ const float* recurrent_to_forget_weights_ptr =
+ recurrent_to_forget_weights->data.f;
+ const float* recurrent_to_cell_weights_ptr =
+ recurrent_to_cell_weights->data.f;
+ const float* recurrent_to_output_weights_ptr =
+ recurrent_to_output_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ kernel_utils::LstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+ input_to_cell_weights_ptr, input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+ recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+ cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+ cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
+ output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, output_ptr_batch);
return kTfLiteOk;
}
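
The commit message targets "the various LSTM ops"; sequence-level variants would presumably drive the same helper once per time step, advancing the shared output_state/cell_state buffers. A hypothetical driver loop, not taken from this commit (max_time, input_step and output_step are illustrative names):

    // Hypothetical sketch: step a [max_time, n_batch, n_input] input through
    // the shared helper, reusing the pointers prepared as in the hunk above.
    for (int t = 0; t < max_time; ++t) {
      const float* input_step = input_ptr_batch + t * n_batch * n_input;
      float* output_step = output_ptr_batch + t * n_batch * n_output;
      kernel_utils::LstmStep(
          input_step, input_to_input_weights_ptr, input_to_forget_weights_ptr,
          input_to_cell_weights_ptr, input_to_output_weights_ptr,
          recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
          recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
          cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
          cell_to_output_weights_ptr, input_gate_bias_ptr,
          forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
          projection_weights_ptr, projection_bias_ptr, params, n_batch,
          n_cell, n_input, n_output, output_state_ptr, cell_state_ptr,
          input_gate_scratch, forget_gate_scratch, cell_scratch,
          output_gate_scratch, output_step);
    }
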