about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-29 07:41:42 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-29 07:45:47 -0700
commit d5d02f078ff8d5f4c5541c9281e1a0e027ce9f0c (patch)
tree b1637242007789afdb393dbd97edf117219709cc /tensorflow/contrib/lite/kernels
parent 8dfb7532f8278e53a86a847ba6aa9c441f7b021b (diff)
Update bidirectional RNN to support state API.
PiperOrigin-RevId: 210719446
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc      | 60
-rw-r--r-- tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 26
2 files changed, 30 insertions(+), 56 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 4162d9bb88..c65bc33d08 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -36,14 +36,14 @@ constexpr int kInputTensor = 0;
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
-constexpr int kBwWeightsTensor = 4;
-constexpr int kBwRecurrentWeightsTensor = 5;
-constexpr int kBwBiasTensor = 6;
-// State and output tensors.
-constexpr int kFwHiddenStateTensor = 0;
-constexpr int kFwOutputTensor = 1;
-constexpr int kBwHiddenStateTensor = 2;
-constexpr int kBwOutputTensor = 3;
+constexpr int kFwHiddenStateTensor = 4;
+constexpr int kBwWeightsTensor = 5;
+constexpr int kBwRecurrentWeightsTensor = 6;
+constexpr int kBwBiasTensor = 7;
+constexpr int kBwHiddenStateTensor = 8;
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
@@ -57,8 +57,8 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 7);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 9);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -66,11 +66,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* fw_recurrent_weights =
GetInput(context, node, kFwRecurrentWeightsTensor);
const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* fw_hidden_state =
+ GetInput(context, node, kFwHiddenStateTensor);
const TfLiteTensor* bw_input_weights =
GetInput(context, node, kBwWeightsTensor);
const TfLiteTensor* bw_recurrent_weights =
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+ const TfLiteTensor* bw_hidden_state =
+ GetInput(context, node, kBwHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -88,31 +92,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_bias->dims->data[0]);
TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1],
bw_bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- // Resize hidden states.
- TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- fw_hidden_state_size_array->data[0] = batch_size;
- fw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* fw_hidden_state =
- GetOutput(context, node, kFwHiddenStateTensor);
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state,
- fw_hidden_state_size_array));
-
- TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- bw_hidden_state_size_array->data[0] = batch_size;
- bw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* bw_hidden_state =
- GetOutput(context, node, kBwHiddenStateTensor);
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state,
- bw_hidden_state_size_array));
-
- // Mark hidden states as a persistent tensor.
- fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
const bool is_hybrid_op =
(fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
@@ -326,12 +315,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kBwRecurrentWeightsTensor);
const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteTensor* fw_hidden_state =
- GetOutput(context, node, kFwHiddenStateTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ const_cast<TfLiteTensor*>(GetInput(context, node, kFwHiddenStateTensor));
TfLiteTensor* bw_hidden_state =
- GetOutput(context, node, kBwHiddenStateTensor);
+ const_cast<TfLiteTensor*>(GetInput(context, node, kBwHiddenStateTensor));
+
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
switch (fw_input_weights->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 911b108eaa..03236dbcdc 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -664,12 +664,12 @@ class BidirectionalRNNOpModel : public SingleOpModel {
fw_weights_ = AddInput(TensorType_FLOAT32);
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
fw_bias_ = AddInput(TensorType_FLOAT32);
- fw_hidden_state_ = AddOutput(TensorType_FLOAT32);
+ fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
fw_output_ = AddOutput(TensorType_FLOAT32);
bw_weights_ = AddInput(TensorType_FLOAT32);
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
bw_bias_ = AddInput(TensorType_FLOAT32);
- bw_hidden_state_ = AddOutput(TensorType_FLOAT32);
+ bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
bw_output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
@@ -681,9 +681,11 @@ class BidirectionalRNNOpModel : public SingleOpModel {
{fw_units_, input_size_}, // fw_weights
{fw_units_, fw_units_}, // fw_recurrent_weights
{fw_units_}, // fw_bias
+ {batches_, fw_units_}, // fw_hidden_state
{bw_units_, input_size_}, // bw_weights
{bw_units_, bw_units_}, // bw_recurrent_weights
- {bw_units_} // bw_bias
+ {bw_units_}, // bw_bias
+ {batches_, bw_units_} // bw_hidden_state
});
}
@@ -719,19 +721,6 @@ class BidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenStates() {
- const int fw_zero_buffer_size = fw_units_ * batches_;
- std::unique_ptr<float[]> fw_zero_buffer(new float[fw_zero_buffer_size]);
- memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float));
- PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(),
- fw_zero_buffer.get() + fw_zero_buffer_size);
- const int bw_zero_buffer_size = bw_units_ * batches_;
- std::unique_ptr<float[]> bw_zero_buffer(new float[bw_zero_buffer_size]);
- memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float));
- PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(),
- bw_zero_buffer.get() + bw_zero_buffer_size);
- }
-
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
@@ -774,7 +763,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
float* batch_end = batch_start + input_sequence_size;
@@ -813,8 +801,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
// Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the
// following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1].
for (int i = 0; i < rnn.sequence_len(); i++) {
@@ -880,8 +866,6 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
const int output_sequence_size = output_size * rnn.sequence_len();
const int num_examples = 64;