diff options
author | Alan Chiao <alanchiao@google.com> | 2018-08-27 17:01:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 17:05:17 -0700 |
commit | 34870d54c0e1b505212da6ace722c6d9f9e8891b (patch) | |
tree | 81fb31e1774c50cd9f114d4be93465c8e0d470df /tensorflow/contrib/lite/kernels/basic_rnn_test.cc | |
parent | f481d2bed293d8791069711cd08084be3b079222 (diff) |
Update RNN to support state API.
PiperOrigin-RevId: 210457365
Diffstat (limited to 'tensorflow/contrib/lite/kernels/basic_rnn_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/basic_rnn_test.cc | 21 |
1 files changed, 6 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index 96465fcaf0..d179735404 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -181,15 +181,16 @@ class RNNOpModel : public SingleOpModel { weights_ = AddInput(weights); recurrent_weights_ = AddInput(recurrent_weights); bias_ = AddInput(TensorType_FLOAT32); - hidden_state_ = AddOutput(TensorType_FLOAT32); + hidden_state_ = AddInput(TensorType_FLOAT32, true); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp( BuiltinOperator_RNN, BuiltinOptions_RNNOptions, CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); - BuildInterpreter({{batches_, input_size_}, - {units_, input_size_}, - {units_, units_}, - {units_}}); + BuildInterpreter({{batches_, input_size_}, // input tensor + {units_, input_size_}, // weights tensor + {units_, units_}, // recurrent weights tensor + {units_}, // bias tensor + {batches_, units_}}); // hidden state tensor } void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); } @@ -210,14 +211,6 @@ class RNNOpModel : public SingleOpModel { PopulateTensor(input_, offset, begin, end); } - void ResetHiddenState() { - const int zero_buffer_size = units_ * batches_; - std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(hidden_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - std::vector<float> GetOutput() { return ExtractVector<float>(output_); } int input_size() { return input_size_; } @@ -258,7 +251,6 @@ TEST(RnnOpTest, BlackBoxTest) { rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches()); @@ -286,7 +278,6 @@ TEST(HybridRnnOpTest, BlackBoxTest) { rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches()); |