diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-03 13:48:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-03 13:56:35 -0700 |
commit | 89ff1a3d75a93578f633a88e2fe2b2a34b023e52 (patch) | |
tree | 5e9b982e9c1da84cf113ac6ea8e53f361b90e795 /tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc | |
parent | 72cebb88b57396bf74f84bec131c5049974617e7 (diff) |
Update bidirectional sequential LSTM to support state API.
PiperOrigin-RevId: 211378028
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc | 360 |
1 file changed, 292 insertions, 68 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc index a18e1bce34..d058fab529 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -102,10 +102,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel { fw_projection_bias_ = AddNullInput(); } - fw_output_state_ = AddOutput(TensorType_FLOAT32); - fw_cell_state_ = AddOutput(TensorType_FLOAT32); - fw_output_ = AddOutput(TensorType_FLOAT32); - if (use_cifg) { bw_input_to_input_weights_ = AddNullInput(); } else { @@ -161,8 +157,24 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bw_projection_bias_ = AddNullInput(); } - bw_output_state_ = AddOutput(TensorType_FLOAT32); - bw_cell_state_ = AddOutput(TensorType_FLOAT32); + // Adding the 2 input state tensors. + fw_input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}}, + /*is_variable=*/true); + fw_input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}}, + /*is_variable=*/true); + + // Adding the 2 input state tensors. 
+ bw_input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}}, + /*is_variable=*/true); + bw_input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}}, + /*is_variable=*/true); + + fw_output_ = AddOutput(TensorType_FLOAT32); + bw_output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, @@ -259,26 +271,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel { PopulateTensor(bw_projection_bias_, f); } - void ResetFwOutputAndCellStates() { - const int zero_buffer_size = n_fw_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(fw_output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - PopulateTensor(fw_cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetBwOutputAndCellStates() { - const int zero_buffer_size = n_bw_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(bw_output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - PopulateTensor(bw_cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, float* begin, float* end) { PopulateTensor(input_, offset, begin, end); } @@ -340,13 +332,13 @@ class BidirectionalLSTMOpModel : public SingleOpModel { int bw_projection_weights_; int bw_projection_bias_; - int fw_output_; - int fw_output_state_; - int fw_cell_state_; + int fw_input_activation_state_; + int fw_input_cell_state_; + int bw_input_activation_state_; + int bw_input_cell_state_; + int fw_output_; int bw_output_; - int bw_output_state_; - int bw_cell_state_; int n_batch_; int n_input_; @@ -417,6 +409,12 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0, 0}, // 
projection_weight tensor {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor }); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, @@ -474,10 +472,6 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { -0.0332076, 0.123838, 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124}; - // Resetting cell_state and output_state - lstm.ResetFwOutputAndCellStates(); - lstm.ResetBwOutputAndCellStates(); - float* batch0_start = lstm_input; float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); @@ -500,34 +494,151 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); EXPECT_THAT(lstm.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); +} + +TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. 
+ const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + BidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/false, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + // Forward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + // Backward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor 
+ + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + // Input should have n_input * sequence_length many values. // Check reversed inputs. 
static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.}; + static float lstm_fw_golden_output[] = { + -0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}; + static float lstm_bw_golden_output[] = { + -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838, + 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124}; - // Resetting cell_state and output_state - lstm.ResetFwOutputAndCellStates(); - lstm.ResetBwOutputAndCellStates(); - - batch0_start = lstm_input_reversed; - batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + float* batch0_start = lstm_input_reversed; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); lstm.SetInput(0, batch0_start, batch0_end); lstm.Invoke(); - fw_expected.clear(); + std::vector<float> fw_expected; for (int s = 0; s < lstm.sequence_length(); s++) { - fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); - fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); + float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); + float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end); } EXPECT_THAT(lstm.GetBwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); - bw_expected.clear(); + std::vector<float> bw_expected; for (int s = 0; s < lstm.sequence_length(); s++) { - bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); - bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); + float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); + float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end); } EXPECT_THAT(lstm.GetFwOutput(), @@ -592,6 +703,12 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0, 0}, // 
projection_weight tensor {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor }); lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, @@ -642,10 +759,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577, 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578}; - // Resetting cell_state and output_state - lstm.ResetFwOutputAndCellStates(); - lstm.ResetBwOutputAndCellStates(); - float* batch0_start = lstm_input; float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); @@ -668,34 +781,143 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); EXPECT_THAT(lstm.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); +} - // Check reversed inputs. - static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.}; +TEST(LSTMOpTest, + BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. 
+ const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; - // Resetting cell_state and output_state - lstm.ResetFwOutputAndCellStates(); - lstm.ResetBwOutputAndCellStates(); + BidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, + /*use_peephole=*/true, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor - batch0_start = lstm_input_reversed; - batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // 
forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.}; + static float lstm_fw_golden_output[] = { + -0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, 
-0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}; + static float lstm_bw_golden_output[] = { + -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577, + 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578}; + + float* batch0_start = lstm_input_reversed; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); lstm.SetInput(0, batch0_start, batch0_end); lstm.Invoke(); - fw_expected.clear(); + std::vector<float> fw_expected; for (int s = 0; s < lstm.sequence_length(); s++) { - fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); - fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); + float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); + float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end); } EXPECT_THAT(lstm.GetBwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); - bw_expected.clear(); + std::vector<float> bw_expected; for (int s = 0; s < lstm.sequence_length(); s++) { - bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); - bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); + float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); + float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end); } EXPECT_THAT(lstm.GetFwOutput(), @@ -759,6 +981,12 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {n_output, n_cell}, // projection_weight tensor {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor }); lstm.SetInputToInputWeights( @@ -1343,10 +1571,6 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { 0.065133, 0.024321, 0.038473, 
0.062438 }}; - // Resetting cell_state and output_state - lstm.ResetFwOutputAndCellStates(); - lstm.ResetBwOutputAndCellStates(); - for (int i = 0; i < lstm.sequence_length(); i++) { float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); float* batch0_end = batch0_start + lstm.num_inputs(); |