diff options
author | 2018-08-29 07:41:42 -0700 | |
---|---|---|
committer | 2018-08-29 07:45:47 -0700 | |
commit | d5d02f078ff8d5f4c5541c9281e1a0e027ce9f0c (patch) | |
tree | b1637242007789afdb393dbd97edf117219709cc /tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | |
parent | 8dfb7532f8278e53a86a847ba6aa9c441f7b021b (diff) |
Update bidirectional RNN to support state API.
PiperOrigin-RevId: 210719446
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 26 |
1 files changed, 5 insertions, 21 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc index 911b108eaa..03236dbcdc 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -664,12 +664,12 @@ class BidirectionalRNNOpModel : public SingleOpModel { fw_weights_ = AddInput(TensorType_FLOAT32); fw_recurrent_weights_ = AddInput(TensorType_FLOAT32); fw_bias_ = AddInput(TensorType_FLOAT32); - fw_hidden_state_ = AddOutput(TensorType_FLOAT32); + fw_hidden_state_ = AddInput(TensorType_FLOAT32, true); fw_output_ = AddOutput(TensorType_FLOAT32); bw_weights_ = AddInput(TensorType_FLOAT32); bw_recurrent_weights_ = AddInput(TensorType_FLOAT32); bw_bias_ = AddInput(TensorType_FLOAT32); - bw_hidden_state_ = AddOutput(TensorType_FLOAT32); + bw_hidden_state_ = AddInput(TensorType_FLOAT32, true); bw_output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_SequenceRNNOptions, @@ -681,9 +681,11 @@ class BidirectionalRNNOpModel : public SingleOpModel { {fw_units_, input_size_}, // fw_weights {fw_units_, fw_units_}, // fw_recurrent_weights {fw_units_}, // fw_bias + {batches_, fw_units_}, // fw_hidden_state {bw_units_, input_size_}, // bw_weights {bw_units_, bw_units_}, // bw_recurrent_weights - {bw_units_} // bw_bias + {bw_units_}, // bw_bias + {batches_, bw_units_} // bw_hidden_state }); } @@ -719,19 +721,6 @@ class BidirectionalRNNOpModel : public SingleOpModel { PopulateTensor(input_, offset, begin, end); } - void ResetHiddenStates() { - const int fw_zero_buffer_size = fw_units_ * batches_; - std::unique_ptr<float[]> fw_zero_buffer(new float[fw_zero_buffer_size]); - memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float)); - PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(), - fw_zero_buffer.get() + fw_zero_buffer_size); - const int bw_zero_buffer_size = bw_units_ * batches_; - std::unique_ptr<float[]> bw_zero_buffer(new float[bw_zero_buffer_size]); - memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float)); - PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(), - bw_zero_buffer.get() + bw_zero_buffer_size); - } - std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); } std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); } @@ -774,7 +763,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { rnn.SetFwRecurrentWeights(recurrent_weights); rnn.SetBwRecurrentWeights(recurrent_weights); - rnn.ResetHiddenStates(); const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); float* batch_start = rnn_input; float* batch_end = batch_start + input_sequence_size; @@ -813,8 +801,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { rnn.SetFwRecurrentWeights(recurrent_weights); rnn.SetBwRecurrentWeights(recurrent_weights); - rnn.ResetHiddenStates(); - // Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the // following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1]. for (int i = 0; i < rnn.sequence_len(); i++) { @@ -880,8 +866,6 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) { rnn.SetFwRecurrentWeights(recurrent_weights); rnn.SetBwRecurrentWeights(recurrent_weights); - rnn.ResetHiddenStates(); - const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); const int output_sequence_size = output_size * rnn.sequence_len(); const int num_examples = 64; |