diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/lstm_test.cc | 49 |
1 files changed, 4 insertions, 45 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 0266f5fe57..e7ddfceb45 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel { input_cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); - output_state_ = AddOutput(TensorType_FLOAT32); - cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, proj_clip) .Union()); + BuildInterpreter(input_shapes); } @@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel { PopulateTensor(projection_bias_, f); } - void ResetOutputState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetCellState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, const float* begin, const float* end) { PopulateTensor(input_, offset, const_cast<float*>(begin), const_cast<float*>(end)); @@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); } @@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -698,14 +669,10 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { +class NoCifgPeepholeProjectionNoClippingLstmTest : public BaseLstmTest { void SetUp() override { input_to_input_weights_ = { 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, @@ -1304,7 +1271,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { } }; -TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { +TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1362,14 +1329,10 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { +TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } |