aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-03 13:25:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 13:32:42 -0700
commitc2c8cfe22492cf7fab804d32283b623632270035 (patch)
tree6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite/kernels
parent7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff)
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful if the output of both directions will be passed to the next layer as a single output, as it avoids adding a concatenation op, which can be expensive on mobile devices where memory movement is relatively expensive. PiperOrigin-RevId: 215616140
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc116
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc186
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc85
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc56
4 files changed, 345 insertions, 98 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 66b947771c..0532528f52 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -119,7 +119,7 @@ constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional
// Output tensors.
constexpr int kFwOutputTensor = 0;
-constexpr int kBwOutputTensor = 1;
+constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set.
// Temporary tensors.
enum TemporaryTensor {
@@ -162,7 +162,8 @@ TfLiteStatus CheckLstmTensorDimensions(
int input_gate_bias_tensor, int forget_gate_bias_tensor,
int cell_gate_bias_tensor, int output_gate_bias_tensor,
int projection_weights_tensor, int projection_bias_tensor) {
- const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -347,10 +348,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -368,6 +372,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
n_input);
+ const TfLiteTensor* bw_input_to_output_weights =
+ GetInput(context, node, kBwInputToOutputWeightsTensor);
+ const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
+ n_input);
+
const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
@@ -375,6 +386,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_fw_cell);
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
+ const TfLiteTensor* bw_recurrent_to_output_weights =
+ GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
+ n_bw_cell);
+ const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
+
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(
context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
@@ -440,7 +458,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
- fw_output_size->data[2] = n_fw_output;
+ fw_output_size->data[2] =
+ params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
@@ -479,39 +498,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
fw_scratch_buffer_size));
// Same for the backward cell.
- const TfLiteTensor* bw_input_to_output_weights =
- GetInput(context, node, kBwInputToOutputWeightsTensor);
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
- n_input);
-
- const TfLiteTensor* bw_recurrent_to_output_weights =
- GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
- n_bw_cell);
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(
context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
n_bw_cell));
- // Get the pointer to output, activation_state and cell_state buffer tensors.
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ // Get the pointer to activation_state and cell_state buffer tensors.
TfLiteTensor* bw_activation_state =
GetVariableInput(context, node, kBwInputActivationStateTensor);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
// Resize the output tensors.
- TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
- bw_output_size->data[0] = max_time;
- bw_output_size->data[1] = n_batch;
- bw_output_size->data[2] = n_bw_output;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, bw_output, bw_output_size));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
+ bw_output_size->data[0] = max_time;
+ bw_output_size->data[1] = n_batch;
+ bw_output_size->data[2] = n_bw_output;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_output, bw_output_size));
+ }
// Check the shape of input state tensors.
// These tensor may be 1D or 2D. It's fine as long as the total size is
@@ -705,7 +713,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
const int max_time = input->dims->data[0];
@@ -771,12 +779,13 @@ TfLiteStatus EvalFloat(
// Loop through the sequence.
const int input_step = n_batch * n_input;
- const int output_step = n_batch * n_output;
+ const int output_step = n_batch * output->dims->data[2];
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr_time = output->data.f + t_rel * output_step;
+ float* output_ptr_time =
+ output->data.f + t_rel * output_step + output_offset;
kernel_utils::LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
@@ -816,7 +825,7 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
@@ -972,12 +981,12 @@ TfLiteStatus EvalHybrid(
// Feed the sequence into the LSTM step-by-step.
const int input_step = n_batch * n_input;
- const int output_step = n_batch * n_output;
+ const int output_step = n_batch * output->dims->data[2];
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr = output->data.f + t_rel * output_step;
+ float* output_ptr = output->data.f + t_rel * output_step + output_offset;
kernel_utils::LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
@@ -1011,7 +1020,8 @@ TfLiteStatus EvalHybrid(
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -1107,7 +1117,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwInputActivationStateTensor);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
// Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
@@ -1135,6 +1147,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_aux_input_to_output_weights =
GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+ // Populate a TfLiteLSTMParams struct for the evaluation functions.
+ TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
+ params->proj_clip, kTfLiteLSTMFullKernel};
+
+ const int bw_output_offset =
+ params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
+ const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
+
switch (fw_input_to_output_weights->type) {
case kTfLiteFloat32: {
TfLiteStatus fw_pass_status = EvalFloat(
@@ -1147,9 +1167,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
- fw_projection_weights, fw_projection_bias, params,
- /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state,
- fw_cell_state, fw_output);
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ fw_activation_state, fw_cell_state, fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
TfLiteStatus bw_pass_status = EvalFloat(
@@ -1162,9 +1182,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
bw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
- bw_projection_weights, bw_projection_bias, params,
- /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state,
- bw_cell_state, bw_output);
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ bw_activation_state, bw_cell_state, actual_bw_output);
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}
@@ -1198,10 +1218,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
- fw_projection_weights, fw_projection_bias, params,
- /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors,
- prod_scaling_factors, recovered_cell_weights, input_quantized,
- aux_input_quantized, fw_activation_state_quantized,
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, fw_activation_state_quantized,
fw_cell_state_quantized, fw_activation_state, fw_cell_state,
fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
@@ -1216,12 +1236,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
- bw_projection_weights, bw_projection_bias, params,
- /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors,
- prod_scaling_factors, recovered_cell_weights, input_quantized,
- aux_input_quantized, bw_activation_state_quantized,
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, bw_activation_state_quantized,
bw_cell_state_quantized, bw_activation_state, bw_cell_state,
- bw_output);
+ actual_bw_output);
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index 74ba8021c2..9cc04907e1 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -35,8 +35,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
int sequence_length, bool use_cifg,
bool use_peephole, bool use_projection_weights,
- bool use_projection_bias, float cell_clip,
- float proj_clip,
+ bool use_projection_bias, bool merge_outputs,
+ float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes)
: n_batch_(n_batch),
n_input_(n_input),
@@ -175,7 +175,9 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_output_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
aux_input_ = AddNullInput();
fw_aux_input_to_input_weights_ = AddNullInput();
@@ -188,9 +190,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_aux_input_to_output_weights_ = AddNullInput();
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
+ BuiltinOptions_BidirectionalSequenceLSTMOptions,
+ CreateBidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip,
+ proj_clip, merge_outputs)
.Union());
BuildInterpreter(input_shapes);
}
@@ -380,7 +383,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -526,6 +530,162 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, yet with a single merged output tensor.
+TEST(LSTMOpTest, BlackBoxTestMergedOutput) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*merge_outputs=*/true, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
+
+ float* batch0_start = lstm_input;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int k = 0; k < lstm.sequence_length(); k++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_fw_golden_output + k * lstm.num_fw_outputs(),
+ lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs());
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_bw_golden_output + k * lstm.num_bw_outputs(),
+ lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs());
+ }
+ EXPECT_THAT(lstm.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
const int n_batch = 1;
const int n_input = 2;
@@ -537,7 +697,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -696,7 +857,8 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
/*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -845,7 +1007,8 @@ TEST(LSTMOpTest,
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
/*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -994,7 +1157,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/true, /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 2f896c5289..9f62ac3f2c 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -47,7 +47,7 @@ constexpr int kFwAuxWeightsTensor = 10; // Optional.
constexpr int kBwAuxWeightsTensor = 11; // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
-constexpr int kBwOutputTensor = 1;
+constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false.
// Temporary tensors.
enum TemporaryTensor {
@@ -70,9 +70,13 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -142,9 +146,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_weights->dims->data[1]);
}
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
-
const bool is_hybrid_op =
(fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
@@ -233,18 +234,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Resize outputs.
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
fw_output_size_array->data[0] = batch_size;
fw_output_size_array->data[1] = max_time;
- fw_output_size_array->data[2] = fw_num_units;
+ fw_output_size_array->data[2] =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_output, fw_output_size_array));
- TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
- bw_output_size_array->data[0] = batch_size;
- bw_output_size_array->data[1] = max_time;
- bw_output_size_array->data[2] = bw_num_units;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_output, bw_output_size_array));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
+ bw_output_size_array->data[0] = batch_size;
+ bw_output_size_array->data[1] = max_time;
+ bw_output_size_array->data[2] = bw_num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
+ bw_output_size_array));
+ }
return kTfLiteOk;
}
@@ -256,9 +262,9 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
const TfLiteTensor* bw_aux_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state,
- TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -281,10 +287,15 @@ TfLiteStatus EvalFloat(
? bw_aux_input_weights->data.f
: nullptr;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -292,8 +303,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
@@ -304,6 +314,10 @@ TfLiteStatus EvalFloat(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -311,8 +325,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
@@ -331,11 +344,12 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
const TfLiteTensor* aux_bw_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors,
- TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
- TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized,
- TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
+ TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -384,10 +398,15 @@ TfLiteStatus EvalHybrid(
reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -395,8 +414,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
@@ -411,6 +429,10 @@ TfLiteStatus EvalHybrid(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -418,8 +440,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
@@ -436,8 +457,8 @@ TfLiteStatus EvalHybrid(
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params =
- reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -465,7 +486,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwHiddenStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
switch (fw_input_weights->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 3e34ba6196..f555c472f5 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,7 @@ const std::initializer_list<float> recurrent_weights = {
class BidirectionalRNNOpModel : public SingleOpModel {
public:
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
- int bw_units, int input_size)
+ int bw_units, int input_size, bool merge_outputs)
: batches_(batches),
sequence_len_(sequence_len),
fw_units_(fw_units),
@@ -675,12 +675,15 @@ class BidirectionalRNNOpModel : public SingleOpModel {
aux_bw_weights_ = AddNullInput();
fw_output_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOptions_SequenceRNNOptions,
- CreateSequenceRNNOptions(builder_, /*time_major=*/false,
- ActivationFunctionType_RELU)
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ CreateBidirectionalSequenceRNNOptions(
+ builder_, /*time_major=*/false,
+ ActivationFunctionType_RELU, merge_outputs)
.Union());
BuildInterpreter({
{batches_, sequence_len_, input_size_}, // input
@@ -767,7 +770,7 @@ class BidirectionalRNNOpModel : public SingleOpModel {
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -800,12 +803,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, yet with merged outputs.
+TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
+ BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*fw_units=*/16, /*bw_units=*/16,
+ /*input_size=*/8, /*merge_outputs=*/true);
+ rnn.SetFwWeights(weights);
+ rnn.SetBwWeights(weights);
+ rnn.SetFwBias(biases);
+ rnn.SetBwBias(biases);
+ rnn.SetFwRecurrentWeights(recurrent_weights);
+ rnn.SetBwRecurrentWeights(recurrent_weights);
+
+ const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ float* batch_start = rnn_input;
+ float* batch_end = batch_start + input_sequence_size;
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(input_sequence_size, batch_start, batch_end);
+
+ rnn.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int bid = 0; bid < rnn.num_batches(); bid++) {
+ for (int step = 0; step < rnn.sequence_len(); step++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_fw_output + rnn.num_fw_units() * step,
+ rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_bw_output + rnn.num_bw_units() * step,
+ rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
+ }
+ }
+ EXPECT_THAT(rnn.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
// Check that if the input sequence is reversed the outputs are the same just
// forward and backward are swapped (and reversed).
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -851,7 +891,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
TEST(BidirectionalRNNOpTest, EndToEndTest) {
BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
const int output_size = 4;
float dnn_weights[] = {
-0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,