aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/basic_rnn.cc
diff options
context:
space:
mode:
author A. Unique TensorFlower <gardener@tensorflow.org> 2018-02-02 19:40:30 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-02-02 19:44:35 -0800
commit 39010bef7f72709a87a275060878baac815744c2 (patch)
tree 5c38e0552002368e4942a986b847ab638078bcd5 /tensorflow/contrib/lite/kernels/basic_rnn.cc
parent 0fab6e888c5f90de3e878566123c1906261ce27e (diff)
A more efficient implementation of the Op using batch operations.
PiperOrigin-RevId: 184367562
Diffstat (limited to 'tensorflow/contrib/lite/kernels/basic_rnn.cc')
-rw-r--r-- tensorflow/contrib/lite/kernels/basic_rnn.cc 57
1 files changed, 14 insertions, 43 deletions
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index a0391e030f..2c5074eca3 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -101,50 +102,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const int batch_size = input->dims->data[0];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[1];
- const int input_weights_stride = input_weights->dims->data[1];
- const int recurrent_weights_stride = recurrent_weights->dims->data[1];
-
- // For each batch
- for (int b = 0; b < batch_size; b++) {
- // Initialize the pointer to input, output and bias.
- const float* input_ptr_batch = input->data.f + b * input_size;
- float* output_ptr_batch = output->data.f + b * num_units;
- float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
-
- // Initialize input_weights and recurrent_weights.
- const float* input_weights_ptr = input_weights->data.f;
- const float* recurrent_weights_ptr = recurrent_weights->data.f;
-
- // Output = bias
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = bias_ptr[o];
- }
-
- // Output += input * input_weights
- for (int o = 0; o < num_units; o++) {
- for (int i = 0; i < input_size; i++) {
- output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
- }
- input_weights_ptr += input_weights_stride;
- }
-
- // Output += recurrent_weights * hidden_state
- for (int o = 0; o < num_units; o++) {
- for (int h = 0; h < num_units; h++) {
- output_ptr_batch[o] +=
- hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
- }
- recurrent_weights_ptr += recurrent_weights_stride;
- }
-
- // Output = activation(Output) and update hidden_state
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] =
- (ActivationFunctor(params->activation))(output_ptr_batch[o]);
- hidden_state_ptr_batch[o] = output_ptr_batch[o];
- }
- }
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f;
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch = input->data.f;
+ float* output_ptr_batch = output->data.f;
+ // Initialize input_weights and recurrent_weights.
+ const float* input_weights_ptr = input_weights->data.f;
+ const float* recurrent_weights_ptr = recurrent_weights->data.f;
+
+ kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
+ recurrent_weights_ptr, bias_ptr, input_size,
+ num_units, batch_size, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
return kTfLiteOk;
}