about · summary · refs · log · tree · commit · diff · homepage
path: root/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc  251
1 files changed, 177 insertions, 74 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index c22a457a71..f544dd5ffa 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -114,8 +114,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];
const int bw_num_units = bw_input_weights->dims->data[0];
TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]);
@@ -237,8 +240,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize outputs.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
- fw_output_size_array->data[0] = batch_size;
- fw_output_size_array->data[1] = max_time;
+ fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
+ fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
fw_output_size_array->data[2] =
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
TF_LITE_ENSURE_OK(
@@ -266,8 +269,11 @@ TfLiteStatus EvalFloat(
const TfLiteBidirectionalSequenceRNNParams* params,
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int input_size = input->dims->data[2];
const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
@@ -292,48 +298,91 @@ TfLiteStatus EvalFloat(
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
const int bw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
- for (int b = 0; b < batch_size; b++) {
+ if (time_major) {
+ // TODO(mirkov): add merge_outputs support for time_major inputs.
+ TF_LITE_ASSERT_EQ(params->merge_outputs, false);
+
// Forward cell.
- float* fw_hidden_state_ptr_batch =
- fw_hidden_state->data.f + b * fw_num_units;
- float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
+ float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
+ input->data.f + s * input_size * batch_size;
const float* aux_input_ptr_batch =
(aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
+ ? aux_input->data.f + s * input_size * batch_size
: nullptr;
- float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+ float* output_ptr_batch =
+ fw_output->data.f + s * fw_num_units * batch_size;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
- input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
+ input_size, aux_input_size, fw_num_units, batch_size,
params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
- float* bw_hidden_state_ptr_batch =
- bw_hidden_state->data.f + b * bw_num_units;
- float* bw_output_offset =
- params->merge_outputs
- ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
- : bw_output->data.f + b * bw_output_step * max_time;
+ float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
+ input->data.f + s * input_size * batch_size;
const float* aux_input_ptr_batch =
(aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
+ ? aux_input->data.f + s * input_size * batch_size
: nullptr;
- float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+ float* output_ptr_batch =
+ bw_output->data.f + s * bw_num_units * batch_size;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
- input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
+ input_size, aux_input_size, bw_num_units, batch_size,
params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
}
+ } else {
+ for (int b = 0; b < batch_size; b++) {
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch =
+ fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset =
+ fw_output->data.f + b * fw_output_step * max_time;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
+ fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
+ input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
+ params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch =
+ bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
+ bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
+ input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
+ params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ }
}
return kTfLiteOk;
}
@@ -351,8 +400,11 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
TfLiteTensor* bw_output) {
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int input_size = input->dims->data[2];
const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
@@ -403,55 +455,106 @@ TfLiteStatus EvalHybrid(
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
const int bw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
- for (int b = 0; b < batch_size; b++) {
- // Forward cell.
- float* fw_hidden_state_ptr_batch =
- fw_hidden_state->data.f + b * fw_num_units;
- float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
- for (int s = 0; s < max_time; s++) {
- const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
- const float* aux_input_ptr_batch =
- (aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
- : nullptr;
- float* output_ptr_batch = fw_output_offset + s * fw_output_step;
-
- kernel_utils::RnnBatchStep(
- input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
- aux_input_ptr_batch, aux_fw_input_weights_ptr,
- aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
- fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
- fw_num_units, /*batch_size=*/1, params->activation,
- quantized_input_ptr, aux_quantized_input_ptr,
- fw_quantized_hidden_state_ptr, scaling_factors_ptr,
- fw_hidden_state_ptr_batch, output_ptr_batch);
+ if (time_major) {
+ for (int t = 0; t < max_time; t++) {
+ // TODO(mirkov): add merge_outputs support for time_major inputs.
+ TF_LITE_ASSERT_EQ(params->merge_outputs, false);
+
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + s * input_size * batch_size
+ : nullptr;
+ float* output_ptr_batch =
+ fw_output->data.f + s * fw_num_units * batch_size;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, batch_size, params->activation, quantized_input_ptr,
+ aux_quantized_input_ptr, fw_quantized_hidden_state_ptr,
+ scaling_factors_ptr, fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + s * input_size * batch_size
+ : nullptr;
+ float* output_ptr_batch =
+ bw_output->data.f + s * bw_num_units * batch_size;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, batch_size, params->activation, quantized_input_ptr,
+ aux_quantized_input_ptr, bw_quantized_hidden_state_ptr,
+ scaling_factors_ptr, bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
}
- // Backward cell.
- float* bw_hidden_state_ptr_batch =
- bw_hidden_state->data.f + b * bw_num_units;
- float* bw_output_offset =
- params->merge_outputs
- ? fw_output->data.f + b * bw_output_step * max_time
- : bw_output->data.f + b * bw_output_step * max_time;
- for (int s = max_time - 1; s >= 0; s--) {
- const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
- const float* aux_input_ptr_batch =
- (aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
- : nullptr;
- float* output_ptr_batch = bw_output_offset + s * bw_output_step;
-
- kernel_utils::RnnBatchStep(
- input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
- aux_input_ptr_batch, aux_bw_input_weights_ptr,
- aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
- bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
- bw_num_units, /*batch_size=*/1, params->activation,
- quantized_input_ptr, aux_quantized_input_ptr,
- bw_quantized_hidden_state_ptr, scaling_factors_ptr,
- bw_hidden_state_ptr_batch, output_ptr_batch);
+ } else {
+ for (int b = 0; b < batch_size; b++) {
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch =
+ fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset =
+ fw_output->data.f + b * fw_output_step * max_time;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ fw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch =
+ bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time
+ : bw_output->data.f + b * bw_output_step * max_time;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ bw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
}
}
return kTfLiteOk;