aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 09:38:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 09:42:23 -0700
commit319da67052b067231d01f46692ce429da7a06f97 (patch)
tree5fea0d5cdaad4a7bcd259337667e3aab77ee3a40 /tensorflow/contrib/lite/kernels
parentfa1ecc082519922827bad10f07df438c9453fedb (diff)
Simplify the logic for running through a sequence forwards and backwards.
PiperOrigin-RevId: 214618170
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc169
1 files changed, 58 insertions, 111 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 541f320138..66b947771c 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -770,51 +770,29 @@ TfLiteStatus EvalFloat(
}
// Loop through the sequence.
- if (forward_sequence) {
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr,
- input_to_forget_weights->data.f, input_to_cell_weights->data.f,
- input_to_output_weights->data.f, aux_input_ptr,
- aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
- aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
- recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
- } else {
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr,
- input_to_forget_weights->data.f, input_to_cell_weights->data.f,
- input_to_output_weights->data.f, aux_input_ptr,
- aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
- aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
- recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * n_output;
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr_time = output->data.f + t_rel * output_step;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
+ input_to_cell_weights->data.f, input_to_output_weights->data.f,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
+ aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
+ recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
}
return kTfLiteOk;
}
@@ -991,72 +969,41 @@ TfLiteStatus EvalHybrid(
aux_input_to_output_weights_scale =
aux_input_to_output_weights->params.scale;
}
- if (forward_sequence) {
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr,
- forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
- projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr);
- }
- } else {
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr,
- forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
- projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr);
- }
+
+ // Feed the sequence into the LSTM step-by-step.
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * n_output;
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr = output->data.f + t_rel * output_step;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+ n_input, aux_input_size, n_output, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
}
return kTfLiteOk;