aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/basic_rnn.cc
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2017-11-10 10:35:35 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:42 -0800
commit0b15439f8f0f2d4755587f4096c3ea04cb199d23 (patch)
tree9aa4fc8162bf9b4ee50112a7b85703f70ca4df08 /tensorflow/contrib/lite/kernels/basic_rnn.cc
parent7ac140a5845553275427162aabd9d54987144b4a (diff)
Internal Change.
PiperOrigin-RevId: 175307445
Diffstat (limited to 'tensorflow/contrib/lite/kernels/basic_rnn.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc161
1 file changed, 161 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
new file mode 100644
index 0000000000..3cee43c68b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -0,0 +1,161 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstdio>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace rnn {
+
// Indices of this op's input tensors within node->inputs.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
// Indices of this op's output tensors within node->outputs.
// NOTE(review): "KHiddenStateTensor" breaks the kCamelCase convention used by
// every other constant here — should be kHiddenStateTensor.
constexpr int KHiddenStateTensor = 0;
constexpr int kOutputTensor = 1;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Check we have all the inputs and outputs we need.
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* input_weights =
+ &context->tensors[node->inputs->data[kWeightsTensor]];
+ TfLiteTensor* recurrent_weights =
+ &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
+ TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+
+ // Check all the parameters of tensor match within themselves and match the
+ // input configuration.
+ const int batch_size = input->dims->data[0];
+ const int num_units = input_weights->dims->data[0];
+ TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]);
+ TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
+ TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
+ TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
+
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+
+ // Resize state.
+ TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
+ hidden_state_size_array->data[0] = batch_size;
+ hidden_state_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
+ hidden_state_size_array));
+
+ // Mark hidden state as a persistent tensor.
+ hidden_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // Resize output.
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+ output_size_array->data[0] = batch_size;
+ output_size_array->data[1] = num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output,
+ output_size_array));
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
+
+ TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ TfLiteTensor* input_weights =
+ &context->tensors[node->inputs->data[kWeightsTensor]];
+ TfLiteTensor* recurrent_weights =
+ &context->tensors[node->inputs->data[kRecurrentWeightsTensor]];
+ TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]];
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->outputs->data[KHiddenStateTensor]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]];
+
+ // Initialize the pointer bias.
+ const float* bias_ptr = bias->data.f;
+
+ const int batch_size = input->dims->data[0];
+ const int num_units = input_weights->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int input_weights_stride = input_weights->dims->data[1];
+ const int recurrent_weights_stride = recurrent_weights->dims->data[1];
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Initialize the pointer to input, output and bias.
+ const float* input_ptr_batch = input->data.f + b * input_size;
+ float* output_ptr_batch = output->data.f + b * num_units;
+ float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
+
+ // Initialize input_weights and recurrent_weights.
+ const float* input_weights_ptr = input_weights->data.f;
+ const float* recurrent_weights_ptr = recurrent_weights->data.f;
+
+ // Output = bias
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] = bias_ptr[o];
+ }
+
+ // Output += input * input_weights
+ for (int o = 0; o < num_units; o++) {
+ for (int i = 0; i < input_size; i++) {
+ output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
+ }
+ input_weights_ptr += input_weights_stride;
+ }
+
+ // Output += recurrent_weights * hidden_state
+ for (int o = 0; o < num_units; o++) {
+ for (int h = 0; h < num_units; h++) {
+ output_ptr_batch[o] +=
+ hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
+ }
+ recurrent_weights_ptr += recurrent_weights_stride;
+ }
+
+ // Output = activation(Output) and update hidden_state
+ for (int o = 0; o < num_units; o++) {
+ output_ptr_batch[o] =
+ (ActivationFunctor(params->activation))(output_ptr_batch[o]);
+ hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace rnn
+
// Returns the registration for the builtin basic-RNN op. The registration is
// a function-local static, so a single shared instance is handed to every
// caller; init/free are null because this op keeps no per-node user data.
TfLiteRegistration* Register_RNN() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 rnn::Prepare, rnn::Eval};
  return &r;
}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite