diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc | 251 |
1 files changed, 177 insertions, 74 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index c22a457a71..f544dd5ffa 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -114,8 +114,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, input->dims->size, 3); - const int batch_size = input->dims->data[0]; - const int max_time = input->dims->data[1]; + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; const int fw_num_units = fw_input_weights->dims->data[0]; const int bw_num_units = bw_input_weights->dims->data[0]; TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]); @@ -237,8 +240,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Resize outputs. TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3); - fw_output_size_array->data[0] = batch_size; - fw_output_size_array->data[1] = max_time; + fw_output_size_array->data[0] = (time_major) ? max_time : batch_size; + fw_output_size_array->data[1] = (time_major) ? batch_size : max_time; fw_output_size_array->data[2] = params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; TF_LITE_ENSURE_OK( @@ -266,8 +269,11 @@ TfLiteStatus EvalFloat( const TfLiteBidirectionalSequenceRNNParams* params, TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { - const int batch_size = input->dims->data[0]; - const int max_time = input->dims->data[1]; + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; const int input_size = input->dims->data[2]; const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0; @@ -292,48 +298,91 @@ TfLiteStatus EvalFloat( params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; const int bw_output_step = params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; - for (int b = 0; b < batch_size; b++) { + if (time_major) { + // TODO(mirkov): add merge_outputs support for time_major inputs. + TF_LITE_ASSERT_EQ(params->merge_outputs, false); + // Forward cell. - float* fw_hidden_state_ptr_batch = - fw_hidden_state->data.f + b * fw_num_units; - float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time; + float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f; for (int s = 0; s < max_time; s++) { const float* input_ptr_batch = - input->data.f + b * input_size * max_time + s * input_size; + input->data.f + s * input_size * batch_size; const float* aux_input_ptr_batch = (aux_input != nullptr) - ? aux_input->data.f + b * input_size * max_time + s * input_size + ? aux_input->data.f + s * input_size * batch_size : nullptr; - float* output_ptr_batch = fw_output_offset + s * fw_output_step; + float* output_ptr_batch = + fw_output->data.f + s * fw_num_units * batch_size; kernel_utils::RnnBatchStep( input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch, fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr, - input_size, aux_input_size, fw_num_units, /*batch_size=*/1, + input_size, aux_input_size, fw_num_units, batch_size, params->activation, fw_hidden_state_ptr_batch, output_ptr_batch); } // Backward cell. - float* bw_hidden_state_ptr_batch = - bw_hidden_state->data.f + b * bw_num_units; - float* bw_output_offset = - params->merge_outputs - ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units - : bw_output->data.f + b * bw_output_step * max_time; + float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f; for (int s = max_time - 1; s >= 0; s--) { const float* input_ptr_batch = - input->data.f + b * input_size * max_time + s * input_size; + input->data.f + s * input_size * batch_size; const float* aux_input_ptr_batch = (aux_input != nullptr) - ? aux_input->data.f + b * input_size * max_time + s * input_size + ? aux_input->data.f + s * input_size * batch_size : nullptr; - float* output_ptr_batch = bw_output_offset + s * bw_output_step; + float* output_ptr_batch = + bw_output->data.f + s * bw_num_units * batch_size; kernel_utils::RnnBatchStep( input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch, bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr, - input_size, aux_input_size, bw_num_units, /*batch_size=*/1, + input_size, aux_input_size, bw_num_units, batch_size, params->activation, bw_hidden_state_ptr_batch, output_ptr_batch); } + } else { + for (int b = 0; b < batch_size; b++) { + // Forward cell. + float* fw_hidden_state_ptr_batch = + fw_hidden_state->data.f + b * fw_num_units; + float* fw_output_offset = + fw_output->data.f + b * fw_output_step * max_time; + for (int s = 0; s < max_time; s++) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + const float* aux_input_ptr_batch = + (aux_input != nullptr) + ? aux_input->data.f + b * input_size * max_time + s * input_size + : nullptr; + float* output_ptr_batch = fw_output_offset + s * fw_output_step; + + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch, + fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr, + input_size, aux_input_size, fw_num_units, /*batch_size=*/1, + params->activation, fw_hidden_state_ptr_batch, output_ptr_batch); + } + // Backward cell. + float* bw_hidden_state_ptr_batch = + bw_hidden_state->data.f + b * bw_num_units; + float* bw_output_offset = + params->merge_outputs + ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units + : bw_output->data.f + b * bw_output_step * max_time; + for (int s = max_time - 1; s >= 0; s--) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + const float* aux_input_ptr_batch = + (aux_input != nullptr) + ? aux_input->data.f + b * input_size * max_time + s * input_size + : nullptr; + float* output_ptr_batch = bw_output_offset + s * bw_output_step; + + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch, + bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr, + input_size, aux_input_size, bw_num_units, /*batch_size=*/1, + params->activation, bw_hidden_state_ptr_batch, output_ptr_batch); + } + } } return kTfLiteOk; } @@ -351,8 +400,11 @@ TfLiteStatus EvalHybrid( TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { - const int batch_size = input->dims->data[0]; - const int max_time = input->dims->data[1]; + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; const int input_size = input->dims->data[2]; const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0; @@ -403,55 +455,106 @@ TfLiteStatus EvalHybrid( params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; const int bw_output_step = params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; - for (int b = 0; b < batch_size; b++) { - // Forward cell. - float* fw_hidden_state_ptr_batch = - fw_hidden_state->data.f + b * fw_num_units; - float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time; - for (int s = 0; s < max_time; s++) { - const float* input_ptr_batch = - input->data.f + b * input_size * max_time + s * input_size; - const float* aux_input_ptr_batch = - (aux_input != nullptr) - ? aux_input->data.f + b * input_size * max_time + s * input_size - : nullptr; - float* output_ptr_batch = fw_output_offset + s * fw_output_step; - - kernel_utils::RnnBatchStep( - input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale, - aux_input_ptr_batch, aux_fw_input_weights_ptr, - aux_fw_input_weights_scale, fw_recurrent_weights_ptr, - fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size, - fw_num_units, /*batch_size=*/1, params->activation, - quantized_input_ptr, aux_quantized_input_ptr, - fw_quantized_hidden_state_ptr, scaling_factors_ptr, - fw_hidden_state_ptr_batch, output_ptr_batch); + if (time_major) { + for (int t = 0; t < max_time; t++) { + // TODO(mirkov): add merge_outputs support for time_major inputs. + TF_LITE_ASSERT_EQ(params->merge_outputs, false); + + // Forward cell. + float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f; + for (int s = 0; s < max_time; s++) { + const float* input_ptr_batch = + input->data.f + s * input_size * batch_size; + const float* aux_input_ptr_batch = + (aux_input != nullptr) + ? aux_input->data.f + s * input_size * batch_size + : nullptr; + float* output_ptr_batch = + fw_output->data.f + s * fw_num_units * batch_size; + + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale, + aux_input_ptr_batch, aux_fw_input_weights_ptr, + aux_fw_input_weights_scale, fw_recurrent_weights_ptr, + fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size, + fw_num_units, batch_size, params->activation, quantized_input_ptr, + aux_quantized_input_ptr, fw_quantized_hidden_state_ptr, + scaling_factors_ptr, fw_hidden_state_ptr_batch, output_ptr_batch); + } + // Backward cell. + float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f; + for (int s = max_time - 1; s >= 0; s--) { + const float* input_ptr_batch = + input->data.f + s * input_size * batch_size; + const float* aux_input_ptr_batch = + (aux_input != nullptr) + ? aux_input->data.f + s * input_size * batch_size + : nullptr; + float* output_ptr_batch = + bw_output->data.f + s * bw_num_units * batch_size; + + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale, + aux_input_ptr_batch, aux_bw_input_weights_ptr, + aux_bw_input_weights_scale, bw_recurrent_weights_ptr, + bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size, + bw_num_units, batch_size, params->activation, quantized_input_ptr, + aux_quantized_input_ptr, bw_quantized_hidden_state_ptr, + scaling_factors_ptr, bw_hidden_state_ptr_batch, output_ptr_batch); + } } - // Backward cell. - float* bw_hidden_state_ptr_batch = - bw_hidden_state->data.f + b * bw_num_units; - float* bw_output_offset = - params->merge_outputs - ? fw_output->data.f + b * bw_output_step * max_time - : bw_output->data.f + b * bw_output_step * max_time; - for (int s = max_time - 1; s >= 0; s--) { - const float* input_ptr_batch = - input->data.f + b * input_size * max_time + s * input_size; - const float* aux_input_ptr_batch = - (aux_input != nullptr) - ? aux_input->data.f + b * input_size * max_time + s * input_size - : nullptr; - float* output_ptr_batch = bw_output_offset + s * bw_output_step; - - kernel_utils::RnnBatchStep( - input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale, - aux_input_ptr_batch, aux_bw_input_weights_ptr, - aux_bw_input_weights_scale, bw_recurrent_weights_ptr, - bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size, - bw_num_units, /*batch_size=*/1, params->activation, - quantized_input_ptr, aux_quantized_input_ptr, - bw_quantized_hidden_state_ptr, scaling_factors_ptr, - bw_hidden_state_ptr_batch, output_ptr_batch); + } else { + for (int b = 0; b < batch_size; b++) { + // Forward cell. + float* fw_hidden_state_ptr_batch = + fw_hidden_state->data.f + b * fw_num_units; + float* fw_output_offset = + fw_output->data.f + b * fw_output_step * max_time; + for (int s = 0; s < max_time; s++) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + const float* aux_input_ptr_batch = + (aux_input != nullptr) + ? aux_input->data.f + b * input_size * max_time + s * input_size + : nullptr; + float* output_ptr_batch = fw_output_offset + s * fw_output_step; + + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale, + aux_input_ptr_batch, aux_fw_input_weights_ptr, + aux_fw_input_weights_scale, fw_recurrent_weights_ptr, + fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size, + fw_num_units, /*batch_size=*/1, params->activation, + quantized_input_ptr, aux_quantized_input_ptr, + fw_quantized_hidden_state_ptr, scaling_factors_ptr, + fw_hidden_state_ptr_batch, output_ptr_batch); + } + // Backward cell. + float* bw_hidden_state_ptr_batch = + bw_hidden_state->data.f + b * bw_num_units; + float* bw_output_offset = + params->merge_outputs + ? fw_output->data.f + b * bw_output_step * max_time + : bw_output->data.f + b * bw_output_step * max_time; + for (int s = max_time - 1; s >= 0; s--) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + const float* aux_input_ptr_batch = + (aux_input != nullptr) + ? aux_input->data.f + b * input_size * max_time + s * input_size + : nullptr; + float* output_ptr_batch = bw_output_offset + s * bw_output_step; + + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale, + aux_input_ptr_batch, aux_bw_input_weights_ptr, + aux_bw_input_weights_scale, bw_recurrent_weights_ptr, + bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size, + bw_num_units, /*batch_size=*/1, params->activation, + quantized_input_ptr, aux_quantized_input_ptr, + bw_quantized_hidden_state_ptr, scaling_factors_ptr, + bw_hidden_state_ptr_batch, output_ptr_batch); + } } } return kTfLiteOk; |