diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-26 09:38:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 09:42:23 -0700 |
commit | 319da67052b067231d01f46692ce429da7a06f97 (patch) | |
tree | 5fea0d5cdaad4a7bcd259337667e3aab77ee3a40 /tensorflow/contrib/lite/kernels | |
parent | fa1ecc082519922827bad10f07df438c9453fedb (diff) |
Simplify the logic for running through a sequence forwards and backwards.
PiperOrigin-RevId: 214618170
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 169 |
1 files changed, 58 insertions, 111 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 541f320138..66b947771c 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -770,51 +770,29 @@ TfLiteStatus EvalFloat( } // Loop through the sequence. - if (forward_sequence) { - for (int t = 0; t < max_time; t++) { - const float* input_ptr = input->data.f + t * n_batch * n_input; - float* output_ptr_time = output->data.f + t * n_batch * n_output; - - kernel_utils::LstmStepWithAuxInput( - input_ptr, input_to_input_weights_ptr, - input_to_forget_weights->data.f, input_to_cell_weights->data.f, - input_to_output_weights->data.f, aux_input_ptr, - aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, - aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, - recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f, - recurrent_to_cell_weights->data.f, - recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, - cell_to_forget_weights_ptr, cell_to_output_weights_ptr, - input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, - output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, - params, n_batch, n_cell, n_input, aux_input_size, n_output, - activation_state->data.f, cell_state->data.f, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, - output_ptr_time); - } - } else { - // Loop through the sequence backwards. - for (int t = max_time - 1; t >= 0; t--) { - const float* input_ptr = input->data.f + t * n_batch * n_input; - float* output_ptr_time = output->data.f + t * n_batch * n_output; - - kernel_utils::LstmStepWithAuxInput( - input_ptr, input_to_input_weights_ptr, - input_to_forget_weights->data.f, input_to_cell_weights->data.f, - input_to_output_weights->data.f, aux_input_ptr, - aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, - aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, - recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f, - recurrent_to_cell_weights->data.f, - recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, - cell_to_forget_weights_ptr, cell_to_output_weights_ptr, - input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, - output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, - params, n_batch, n_cell, n_input, aux_input_size, n_output, - activation_state->data.f, cell_state->data.f, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, - output_ptr_time); - } + const int input_step = n_batch * n_input; + const int output_step = n_batch * n_output; + for (int t = 0; t < max_time; t++) { + // If this is the forward_sequence, step forward, otherwise step backwards. + const int t_rel = forward_sequence ? t : max_time - t - 1; + const float* input_ptr = input->data.f + t_rel * input_step; + float* output_ptr_time = output->data.f + t_rel * output_step; + + kernel_utils::LstmStepWithAuxInput( + input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f, + input_to_cell_weights->data.f, input_to_output_weights->data.f, + aux_input_ptr, aux_input_to_input_weights_ptr, + aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr, + aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr, + recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f, + recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, + cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, + output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, + params, n_batch, n_cell, n_input, aux_input_size, n_output, + activation_state->data.f, cell_state->data.f, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, + output_ptr_time); } return kTfLiteOk; } @@ -991,72 +969,41 @@ TfLiteStatus EvalHybrid( aux_input_to_output_weights_scale = aux_input_to_output_weights->params.scale; } - if (forward_sequence) { - // Feed the sequence into the LSTM step-by-step. - for (int t = 0; t < max_time; t++) { - const float* input_ptr = input->data.f + t * n_batch * n_input; - float* output_ptr = output->data.f + t * n_batch * n_output; - - kernel_utils::LstmStepWithAuxInput( - input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - aux_input_ptr, aux_input_to_input_weights_ptr, - aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr, - aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr, - aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr, - aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr, - recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, - recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, - recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, - recurrent_to_output_weights_scale, cell_to_input_weights_ptr, - cell_to_input_weights_scale, cell_to_forget_weights_ptr, - cell_to_forget_weights_scale, cell_to_output_weights_ptr, - cell_to_output_weights_scale, input_gate_bias_ptr, - forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, - projection_weights_ptr, projection_weights_scale, projection_bias_ptr, - params, n_batch, n_cell, n_input, aux_input_size, n_output, - input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, - recovered_cell_weights_ptr, quantized_input_ptr, - quantized_aux_input_ptr, quantized_output_state_ptr, - quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, - output_ptr); - } - } else { - // Loop through the sequence backwards. - for (int t = max_time - 1; t >= 0; t--) { - const float* input_ptr = input->data.f + t * n_batch * n_input; - float* output_ptr = output->data.f + t * n_batch * n_output; - - kernel_utils::LstmStepWithAuxInput( - input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - aux_input_ptr, aux_input_to_input_weights_ptr, - aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr, - aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr, - aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr, - aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr, - recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, - recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, - recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, - recurrent_to_output_weights_scale, cell_to_input_weights_ptr, - cell_to_input_weights_scale, cell_to_forget_weights_ptr, - cell_to_forget_weights_scale, cell_to_output_weights_ptr, - cell_to_output_weights_scale, input_gate_bias_ptr, - forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, - projection_weights_ptr, projection_weights_scale, projection_bias_ptr, - params, n_batch, n_cell, n_input, aux_input_size, n_output, - input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, - recovered_cell_weights_ptr, quantized_input_ptr, - quantized_aux_input_ptr, quantized_output_state_ptr, - quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, - output_ptr); - } + + // Feed the sequence into the LSTM step-by-step. + const int input_step = n_batch * n_input; + const int output_step = n_batch * n_output; + for (int t = 0; t < max_time; t++) { + // If this is the forward_sequence, step forward, otherwise step backwards. + const int t_rel = forward_sequence ? t : max_time - t - 1; + const float* input_ptr = input->data.f + t_rel * input_step; + float* output_ptr = output->data.f + t_rel * output_step; + + kernel_utils::LstmStepWithAuxInput( + input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + aux_input_ptr, aux_input_to_input_weights_ptr, + aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr, + aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr, + aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr, + aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, cell_to_input_weights_ptr, + cell_to_input_weights_scale, cell_to_forget_weights_ptr, + cell_to_forget_weights_scale, cell_to_output_weights_ptr, + cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell, + n_input, aux_input_size, n_output, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, + scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_aux_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr); } return kTfLiteOk; |