aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-03 13:25:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 13:32:42 -0700
commitc2c8cfe22492cf7fab804d32283b623632270035 (patch)
tree6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
parent7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff)
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful if the output of both directions will be passed to the next layer as a single output, since it avoids adding a concatenation op, which can be costly on mobile devices where memory movement is relatively expensive.

PiperOrigin-RevId: 215616140
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc85
1 file changed, 54 insertions, 31 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 2f896c5289..9f62ac3f2c 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -47,7 +47,7 @@ constexpr int kFwAuxWeightsTensor = 10; // Optional.
constexpr int kBwAuxWeightsTensor = 11; // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
-constexpr int kBwOutputTensor = 1;
+constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false.
// Temporary tensors.
enum TemporaryTensor {
@@ -70,9 +70,13 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -142,9 +146,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_weights->dims->data[1]);
}
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
-
const bool is_hybrid_op =
(fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
@@ -233,18 +234,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Resize outputs.
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
fw_output_size_array->data[0] = batch_size;
fw_output_size_array->data[1] = max_time;
- fw_output_size_array->data[2] = fw_num_units;
+ fw_output_size_array->data[2] =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_output, fw_output_size_array));
- TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
- bw_output_size_array->data[0] = batch_size;
- bw_output_size_array->data[1] = max_time;
- bw_output_size_array->data[2] = bw_num_units;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_output, bw_output_size_array));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
+ bw_output_size_array->data[0] = batch_size;
+ bw_output_size_array->data[1] = max_time;
+ bw_output_size_array->data[2] = bw_num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
+ bw_output_size_array));
+ }
return kTfLiteOk;
}
@@ -256,9 +262,9 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
const TfLiteTensor* bw_aux_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state,
- TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -281,10 +287,15 @@ TfLiteStatus EvalFloat(
? bw_aux_input_weights->data.f
: nullptr;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -292,8 +303,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
@@ -304,6 +314,10 @@ TfLiteStatus EvalFloat(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -311,8 +325,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
@@ -331,11 +344,12 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
const TfLiteTensor* aux_bw_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors,
- TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
- TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized,
- TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
+ TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -384,10 +398,15 @@ TfLiteStatus EvalHybrid(
reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -395,8 +414,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
@@ -411,6 +429,10 @@ TfLiteStatus EvalHybrid(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -418,8 +440,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
@@ -436,8 +457,8 @@ TfLiteStatus EvalHybrid(
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params =
- reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -465,7 +486,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwHiddenStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
switch (fw_input_weights->type) {
case kTfLiteFloat32: