diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-03 13:25:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 13:32:42 -0700 |
commit | c2c8cfe22492cf7fab804d32283b623632270035 (patch) | |
tree | 6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc | |
parent | 7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff) |
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful if the output of both directions will be passed to the next layer as a single output, as it avoids adding a concatenation op, which can be expensive on mobile devices where memory movement is relatively expensive.
PiperOrigin-RevId: 215616140
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc | 85 |
1 files changed, 54 insertions, 31 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index 2f896c5289..9f62ac3f2c 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -47,7 +47,7 @@ constexpr int kFwAuxWeightsTensor = 10; // Optional. constexpr int kBwAuxWeightsTensor = 11; // Optional. // Output tensors. constexpr int kFwOutputTensor = 0; -constexpr int kBwOutputTensor = 1; +constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false. // Temporary tensors. enum TemporaryTensor { @@ -70,9 +70,13 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>( + node->builtin_data); + // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 12); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, + params->merge_outputs ? 1 : 2); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* fw_input_weights = @@ -142,9 +146,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bw_aux_input_weights->dims->data[1]); } - TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); - TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); - const bool is_hybrid_op = (fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32); @@ -233,18 +234,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // Resize outputs. + TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3); fw_output_size_array->data[0] = batch_size; fw_output_size_array->data[1] = max_time; - fw_output_size_array->data[2] = fw_num_units; + fw_output_size_array->data[2] = + params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, fw_output, fw_output_size_array)); - TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3); - bw_output_size_array->data[0] = batch_size; - bw_output_size_array->data[1] = max_time; - bw_output_size_array->data[2] = bw_num_units; - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, bw_output, bw_output_size_array)); + if (!params->merge_outputs) { + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3); + bw_output_size_array->data[0] = batch_size; + bw_output_size_array->data[1] = max_time; + bw_output_size_array->data[2] = bw_num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output, + bw_output_size_array)); + } return kTfLiteOk; } @@ -256,9 +262,9 @@ TfLiteStatus EvalFloat( const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias, const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights, const TfLiteTensor* bw_aux_input_weights, - const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state, - TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state, - TfLiteTensor* bw_output) { + const TfLiteBidirectionalSequenceRNNParams* params, + TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, + TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { const int batch_size = input->dims->data[0]; const int max_time = input->dims->data[1]; const int input_size = input->dims->data[2]; @@ -281,10 +287,15 @@ TfLiteStatus EvalFloat( ? bw_aux_input_weights->data.f : nullptr; + const int fw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; + const int bw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; for (int b = 0; b < batch_size; b++) { // Forward cell. float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f + b * fw_num_units; + float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time; for (int s = 0; s < max_time; s++) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -292,8 +303,7 @@ TfLiteStatus EvalFloat( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; + float* output_ptr_batch = fw_output_offset + s * fw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch, @@ -304,6 +314,10 @@ TfLiteStatus EvalFloat( // Backward cell. float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f + b * bw_num_units; + float* bw_output_offset = + params->merge_outputs + ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units + : bw_output->data.f + b * bw_output_step * max_time; for (int s = max_time - 1; s >= 0; s--) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -311,8 +325,7 @@ TfLiteStatus EvalFloat( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; + float* output_ptr_batch = bw_output_offset + s * bw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch, @@ -331,11 +344,12 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias, const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights, const TfLiteTensor* aux_bw_input_weights, - const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors, - TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, - TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state, - TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized, - TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { + const TfLiteBidirectionalSequenceRNNParams* params, + TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized, + TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized, + TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, + TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state, + TfLiteTensor* bw_output) { const int batch_size = input->dims->data[0]; const int max_time = input->dims->data[1]; const int input_size = input->dims->data[2]; @@ -384,10 +398,15 @@ TfLiteStatus EvalHybrid( reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8); float* scaling_factors_ptr = scaling_factors->data.f; + const int fw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; + const int bw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; for (int b = 0; b < batch_size; b++) { // Forward cell. float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f + b * fw_num_units; + float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time; for (int s = 0; s < max_time; s++) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -395,8 +414,7 @@ TfLiteStatus EvalHybrid( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; + float* output_ptr_batch = fw_output_offset + s * fw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale, @@ -411,6 +429,10 @@ TfLiteStatus EvalHybrid( // Backward cell. float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f + b * bw_num_units; + float* bw_output_offset = + params->merge_outputs + ? fw_output->data.f + b * bw_output_step * max_time + : bw_output->data.f + b * bw_output_step * max_time; for (int s = max_time - 1; s >= 0; s--) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -418,8 +440,7 @@ TfLiteStatus EvalHybrid( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; + float* output_ptr_batch = bw_output_offset + s * bw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale, @@ -436,8 +457,8 @@ TfLiteStatus EvalHybrid( } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const auto* params = - reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); + const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>( + node->builtin_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* fw_input_weights = @@ -465,7 +486,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetVariableInput(context, node, kBwHiddenStateTensor); TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); - TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteTensor* bw_output = params->merge_outputs + ? nullptr + : GetOutput(context, node, kBwOutputTensor); switch (fw_input_weights->type) { case kTfLiteFloat32: |