aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-29 10:35:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 10:42:47 -0700
commit287e82b5624c4ad0f1c81c16127e68ed03b1140e (patch)
tree9b6349257bdb9d09745c68fb5b0467e6db62e9f2 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
parent150dee25d82589ca109957cc996efbd2a236e044 (diff)
Update unidirectional sequential RNN to support state API.
PiperOrigin-RevId: 210746360
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc20
1 files changed, 5 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index 0adab837b0..6b48e3fff7 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -183,7 +183,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
@@ -194,12 +194,14 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
BuildInterpreter({{sequence_len_, batches_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units}});
} else {
BuildInterpreter({{batches_, sequence_len_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units_}});
}
}
@@ -221,14 +223,6 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -273,7 +267,6 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -299,7 +292,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -326,7 +318,6 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
@@ -356,7 +347,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();