diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-28 13:02:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-28 13:08:35 -0800 |
commit | 69f674b473470b44c6a1ca1bbb3bcc6a8c53074b (patch) | |
tree | 793e26362b2bf184ed07b8c654b3047bf6c6be95 /tensorflow/contrib/lite/kernels/lstm.cc | |
parent | 757a71e886fb9328b19b0ba15658e49cfa7cc323 (diff) |
Factor out the LstmBatchStep for the various LSTM Ops.
PiperOrigin-RevId: 187370622
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/lstm.cc | 170 |
1 file changed, 49 insertions, 121 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 6c06264d84..b9255b23a5 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -377,127 +378,54 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; } - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, - n_batch, input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, - n_batch, output_gate_scratch); - - // For each batch and cell: compute input_weight * input. 
- if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch, - forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch, - output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights->data.f, n_cell, n_output, - output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. 
- if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, - cell_state->data.f, n_batch * n_cell, - cell_state->data.f); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, - cell_state->data.f); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, - params->cell_clip, cell_state->data.f); - } - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. 
- const bool use_projection_weight = (projection_weights != nullptr); - const bool use_projection_bias = (projection_bias != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, - n_batch, output->data.f); - } else { - tensor_utils::ZeroVector(output->data.f, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights->data.f, n_output, n_cell, output_gate_scratch, - n_batch, output->data.f, /*result_stride=*/1); - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output->data.f, n_batch * n_output, - params->proj_clip, output->data.f); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output->data.f); - } - tensor_utils::CopyVector(output->data.f, n_batch * n_output, - output_state->data.f); + // Check optional tensors, the respective pointers can be null. + const float* input_to_input_weights_ptr = + (use_cifg) ? nullptr : input_to_input_weights->data.f; + const float* recurrent_to_input_weights_ptr = + (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; + const float* input_gate_bias_ptr = + (use_cifg) ? nullptr : input_gate_bias->data.f; + const float* cell_to_input_weights_ptr = + (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; + const float* cell_to_forget_weights_ptr = + (use_peephole) ? cell_to_forget_weights->data.f : nullptr; + const float* cell_to_output_weights_ptr = + (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* projection_weights_ptr = + (projection_weights == nullptr) ? nullptr : projection_weights->data.f; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. 
+ const float* input_ptr_batch = input->data.f; + const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; + const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; + const float* input_to_output_weights_ptr = input_to_output_weights->data.f; + const float* recurrent_to_forget_weights_ptr = + recurrent_to_forget_weights->data.f; + const float* recurrent_to_cell_weights_ptr = + recurrent_to_cell_weights->data.f; + const float* recurrent_to_output_weights_ptr = + recurrent_to_output_weights->data.f; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, + input_to_cell_weights_ptr, input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, + recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, + cell_to_input_weights_ptr, cell_to_forget_weights_ptr, + cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, output_ptr_batch); return kTfLiteOk; } |