diff options
author | Alan Chiao <alanchiao@google.com> | 2018-08-27 17:01:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 17:05:17 -0700 |
commit | 34870d54c0e1b505212da6ace722c6d9f9e8891b (patch) | |
tree | 81fb31e1774c50cd9f114d4be93465c8e0d470df /tensorflow/contrib/lite/kernels/basic_rnn.cc | |
parent | f481d2bed293d8791069711cd08084be3b079222 (diff) |
Update RNN to support state API.
PiperOrigin-RevId: 210457365
Diffstat (limited to 'tensorflow/contrib/lite/kernels/basic_rnn.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/basic_rnn.cc | 29 |
1 files changed, 13 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index c09b15b3d2..c5a5c0182f 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -31,8 +31,10 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kRecurrentWeightsTensor = 2; constexpr int kBiasTensor = 3; -constexpr int kHiddenStateTensor = 0; -constexpr int kOutputTensor = 1; +constexpr int kHiddenStateTensor = 4; + +// Output tensor. +constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; @@ -46,14 +48,16 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* recurrent_weights = GetInput(context, node, kRecurrentWeightsTensor); const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); + const TfLiteTensor* hidden_state = + GetInput(context, node, kHiddenStateTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. @@ -65,20 +69,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type); + TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2); + TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units); - TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - // Resize state. - TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); - hidden_state_size_array->data[0] = batch_size; - hidden_state_size_array->data[1] = num_units; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, - hidden_state_size_array)); - - // Mark hidden state as a persistent tensor. - hidden_state->allocation_type = kTfLiteArenaRwPersistent; - // Resize output. TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); output_size_array->data[0] = batch_size; @@ -205,7 +201,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* recurrent_weights = GetInput(context, node, kRecurrentWeightsTensor); const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); - TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); + TfLiteTensor* hidden_state = + &context->tensors[node->inputs->data[kHiddenStateTensor]]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // We already checked that weight types are consistent, so branch on one. |