aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-03 13:52:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-03 14:01:30 -0700
commitec357afc00c8fb0e9c99754afd78d47fc0ab2116 (patch)
tree1ad0ad549bd031ac5ad16f4a894db53bf80d3d9a /tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
parent89ff1a3d75a93578f633a88e2fe2b2a34b023e52 (diff)
Update unidirectional sequential LSTM to support state API.
PiperOrigin-RevId: 211378182
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc73
1 files changed, 29 insertions, 44 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index de38bdef6f..cd3aac0532 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -100,8 +100,14 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}},
+ /*is_variable=*/true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@@ -180,22 +186,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -233,9 +223,10 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
+
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -458,6 +449,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -475,10 +469,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -519,6 +509,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -536,10 +529,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -629,6 +618,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -646,10 +638,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -691,6 +679,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -708,10 +699,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
@@ -1351,6 +1338,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1374,10 +1364,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -1418,6 +1404,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1441,10 +1430,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}