diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-29 10:35:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-29 10:42:47 -0700 |
commit | 287e82b5624c4ad0f1c81c16127e68ed03b1140e (patch) | |
tree | 9b6349257bdb9d09745c68fb5b0467e6db62e9f2 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | |
parent | 150dee25d82589ca109957cc996efbd2a236e044 (diff) |
Update unidirectional sequential RNN to support state API.
PiperOrigin-RevId: 210746360
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | 20 |
1 files changed, 5 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc index 0adab837b0..6b48e3fff7 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -183,7 +183,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel { weights_ = AddInput(weights); recurrent_weights_ = AddInput(recurrent_weights); bias_ = AddInput(TensorType_FLOAT32); - hidden_state_ = AddOutput(TensorType_FLOAT32); + hidden_state_ = AddInput(TensorType_FLOAT32, true); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_SequenceRNNOptions, @@ -194,12 +194,14 @@ class UnidirectionalRNNOpModel : public SingleOpModel { BuildInterpreter({{sequence_len_, batches_, input_size_}, {units_, input_size_}, {units_, units_}, - {units_}}); + {units_}, + {batches_, units}}); } else { BuildInterpreter({{batches_, sequence_len_, input_size_}, {units_, input_size_}, {units_, units_}, - {units_}}); + {units_}, + {batches_, units_}}); } } @@ -221,14 +223,6 @@ class UnidirectionalRNNOpModel : public SingleOpModel { PopulateTensor(input_, offset, begin, end); } - void ResetHiddenState() { - const int zero_buffer_size = units_ * batches_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(hidden_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - std::vector<float> GetOutput() { return ExtractVector<float>(output_); } int input_size() { return input_size_; } @@ -273,7 +267,6 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) { rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); float* batch_start = rnn_input; @@ -299,7 +292,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) { rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); float* batch_start = rnn_input; @@ -326,7 +318,6 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) { rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); for (int i = 0; i < rnn.sequence_len(); i++) { float* batch_start = rnn_input + i * rnn.input_size(); @@ -356,7 +347,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) { rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); for (int i = 0; i < rnn.sequence_len(); i++) { float* batch_start = rnn_input + i * rnn.input_size(); |