diff options
author | Alan Chiao <alanchiao@google.com> | 2018-08-23 13:42:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-23 13:45:36 -0700 |
commit | 9a774e4d2d31443ea694938bec41237b4d6bcf02 (patch) | |
tree | eba93b9e5a8c96c5c40b2aff448e01690eb8a0db /tensorflow/contrib/lite/kernels/lstm_test.cc | |
parent | 6fe361f80d4277ea879b3182e1d7148a65a8ca21 (diff) |
Remove 18-input/3-output LSTM in favor of 20-input/1-output LSTM that supports
state API.
PiperOrigin-RevId: 209991722
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/lstm_test.cc | 43 |
1 files changed, 1 insertions, 42 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 0266f5fe57..dc128105d3 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel { input_cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); - output_state_ = AddOutput(TensorType_FLOAT32); - cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, proj_clip) .Union()); + BuildInterpreter(input_shapes); } @@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel { PopulateTensor(projection_bias_, f); } - void ResetOutputState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetCellState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, const float* begin, const float* end) { PopulateTensor(input_, offset, const_cast<float*>(begin), const_cast<float*>(end)); @@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); } @@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -698,10 +669,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } @@ -1362,10 +1329,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } |