aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/lstm_test.cc
diff options
context:
space:
mode:
authorGravatar Alan Chiao <alanchiao@google.com>2018-08-23 13:42:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 13:45:36 -0700
commit9a774e4d2d31443ea694938bec41237b4d6bcf02 (patch)
treeeba93b9e5a8c96c5c40b2aff448e01690eb8a0db /tensorflow/contrib/lite/kernels/lstm_test.cc
parent6fe361f80d4277ea879b3182e1d7148a65a8ca21 (diff)
Remove 18-input/3-output LSTM in favor of 20-input/1-output LSTM that supports
state API. PiperOrigin-RevId: 209991722
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc43
1 files changed, 1 insertions, 42 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index 0266f5fe57..dc128105d3 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
cell_clip, proj_clip)
.Union());
+
BuildInterpreter(input_shapes);
}
@@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -698,10 +669,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
@@ -1362,10 +1329,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}