diff options
 tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc      | 60 ++++++++++++------------------
 tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 26 +++++---------
 2 files changed, 30 insertions(+), 56 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index 4162d9bb88..c65bc33d08 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -36,14 +36,14 @@ constexpr int kInputTensor = 0; constexpr int kFwWeightsTensor = 1; constexpr int kFwRecurrentWeightsTensor = 2; constexpr int kFwBiasTensor = 3; -constexpr int kBwWeightsTensor = 4; -constexpr int kBwRecurrentWeightsTensor = 5; -constexpr int kBwBiasTensor = 6; -// State and output tensors. -constexpr int kFwHiddenStateTensor = 0; -constexpr int kFwOutputTensor = 1; -constexpr int kBwHiddenStateTensor = 2; -constexpr int kBwOutputTensor = 3; +constexpr int kFwHiddenStateTensor = 4; +constexpr int kBwWeightsTensor = 5; +constexpr int kBwRecurrentWeightsTensor = 6; +constexpr int kBwBiasTensor = 7; +constexpr int kBwHiddenStateTensor = 8; +// Output tensors. +constexpr int kFwOutputTensor = 0; +constexpr int kBwOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; @@ -57,8 +57,8 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. 
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 7); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 9); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* fw_input_weights = @@ -66,11 +66,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* fw_recurrent_weights = GetInput(context, node, kFwRecurrentWeightsTensor); const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor); + const TfLiteTensor* fw_hidden_state = + GetInput(context, node, kFwHiddenStateTensor); const TfLiteTensor* bw_input_weights = GetInput(context, node, kBwWeightsTensor); const TfLiteTensor* bw_recurrent_weights = GetInput(context, node, kBwRecurrentWeightsTensor); const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor); + const TfLiteTensor* bw_hidden_state = + GetInput(context, node, kBwHiddenStateTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. @@ -88,31 +92,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { fw_bias->dims->data[0]); TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1], bw_bias->dims->data[0]); + TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2); + TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units); + TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2); + TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units); TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); - // Resize hidden states. 
- TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2); - fw_hidden_state_size_array->data[0] = batch_size; - fw_hidden_state_size_array->data[1] = fw_num_units; - TfLiteTensor* fw_hidden_state = - GetOutput(context, node, kFwHiddenStateTensor); - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state, - fw_hidden_state_size_array)); - - TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2); - bw_hidden_state_size_array->data[0] = batch_size; - bw_hidden_state_size_array->data[1] = fw_num_units; - TfLiteTensor* bw_hidden_state = - GetOutput(context, node, kBwHiddenStateTensor); - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state, - bw_hidden_state_size_array)); - - // Mark hidden states as a persistent tensor. - fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; - bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; - const bool is_hybrid_op = (fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32); @@ -326,12 +315,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetInput(context, node, kBwRecurrentWeightsTensor); const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor); - TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); TfLiteTensor* fw_hidden_state = - GetOutput(context, node, kFwHiddenStateTensor); - TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + const_cast<TfLiteTensor*>(GetInput(context, node, kFwHiddenStateTensor)); TfLiteTensor* bw_hidden_state = - GetOutput(context, node, kBwHiddenStateTensor); + const_cast<TfLiteTensor*>(GetInput(context, node, kBwHiddenStateTensor)); + + TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); switch (fw_input_weights->type) { case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc 
b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc index 911b108eaa..03236dbcdc 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -664,12 +664,12 @@ class BidirectionalRNNOpModel : public SingleOpModel { fw_weights_ = AddInput(TensorType_FLOAT32); fw_recurrent_weights_ = AddInput(TensorType_FLOAT32); fw_bias_ = AddInput(TensorType_FLOAT32); - fw_hidden_state_ = AddOutput(TensorType_FLOAT32); + fw_hidden_state_ = AddInput(TensorType_FLOAT32, true); fw_output_ = AddOutput(TensorType_FLOAT32); bw_weights_ = AddInput(TensorType_FLOAT32); bw_recurrent_weights_ = AddInput(TensorType_FLOAT32); bw_bias_ = AddInput(TensorType_FLOAT32); - bw_hidden_state_ = AddOutput(TensorType_FLOAT32); + bw_hidden_state_ = AddInput(TensorType_FLOAT32, true); bw_output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_SequenceRNNOptions, @@ -681,9 +681,11 @@ class BidirectionalRNNOpModel : public SingleOpModel { {fw_units_, input_size_}, // fw_weights {fw_units_, fw_units_}, // fw_recurrent_weights {fw_units_}, // fw_bias + {batches_, fw_units_}, // fw_hidden_state {bw_units_, input_size_}, // bw_weights {bw_units_, bw_units_}, // bw_recurrent_weights - {bw_units_} // bw_bias + {bw_units_}, // bw_bias + {batches_, bw_units_} // bw_hidden_state }); } @@ -719,19 +721,6 @@ class BidirectionalRNNOpModel : public SingleOpModel { PopulateTensor(input_, offset, begin, end); } - void ResetHiddenStates() { - const int fw_zero_buffer_size = fw_units_ * batches_; - std::unique_ptr<float[]> fw_zero_buffer(new float[fw_zero_buffer_size]); - memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float)); - PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(), - fw_zero_buffer.get() + fw_zero_buffer_size); - const int bw_zero_buffer_size = bw_units_ * batches_; - std::unique_ptr<float[]> bw_zero_buffer(new 
float[bw_zero_buffer_size]); - memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float)); - PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(), - bw_zero_buffer.get() + bw_zero_buffer_size); - } - std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); } std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); } @@ -774,7 +763,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { rnn.SetFwRecurrentWeights(recurrent_weights); rnn.SetBwRecurrentWeights(recurrent_weights); - rnn.ResetHiddenStates(); const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); float* batch_start = rnn_input; float* batch_end = batch_start + input_sequence_size; @@ -813,8 +801,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { rnn.SetFwRecurrentWeights(recurrent_weights); rnn.SetBwRecurrentWeights(recurrent_weights); - rnn.ResetHiddenStates(); - // Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the // following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1]. for (int i = 0; i < rnn.sequence_len(); i++) { @@ -880,8 +866,6 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) { rnn.SetFwRecurrentWeights(recurrent_weights); rnn.SetBwRecurrentWeights(recurrent_weights); - rnn.ResetHiddenStates(); - const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); const int output_sequence_size = output_size * rnn.sequence_len(); const int num_examples = 64; |