aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-02 19:40:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-02 19:44:35 -0800
commit39010bef7f72709a87a275060878baac815744c2 (patch)
tree5c38e0552002368e4942a986b847ab638078bcd5
parent0fab6e888c5f90de3e878566123c1906261ce27e (diff)
A more efficient implementation of the Op using batch operations.
PiperOrigin-RevId: 184367562
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD1
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc57
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc62
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD10
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc44
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h40
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc74
7 files changed, 135 insertions, 153 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 8c40adfae5..a8ef0daede 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -156,6 +156,7 @@ cc_library(
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:gemm_support",
+ "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index a0391e030f..2c5074eca3 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -101,50 +102,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const int batch_size = input->dims->data[0];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[1];
- const int input_weights_stride = input_weights->dims->data[1];
- const int recurrent_weights_stride = recurrent_weights->dims->data[1];
-
- // For each batch
- for (int b = 0; b < batch_size; b++) {
- // Initialize the pointer to input, output and bias.
- const float* input_ptr_batch = input->data.f + b * input_size;
- float* output_ptr_batch = output->data.f + b * num_units;
- float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
-
- // Initialize input_weights and recurrent_weights.
- const float* input_weights_ptr = input_weights->data.f;
- const float* recurrent_weights_ptr = recurrent_weights->data.f;
-
- // Output = bias
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = bias_ptr[o];
- }
-
- // Output += input * input_weights
- for (int o = 0; o < num_units; o++) {
- for (int i = 0; i < input_size; i++) {
- output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
- }
- input_weights_ptr += input_weights_stride;
- }
-
- // Output += recurrent_weights * hidden_state
- for (int o = 0; o < num_units; o++) {
- for (int h = 0; h < num_units; h++) {
- output_ptr_batch[o] +=
- hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
- }
- recurrent_weights_ptr += recurrent_weights_stride;
- }
-
- // Output = activation(Output) and update hidden_state
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] =
- (ActivationFunctor(params->activation))(output_ptr_batch[o]);
- hidden_state_ptr_batch[o] = output_ptr_batch[o];
- }
- }
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f;
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch = input->data.f;
+ float* output_ptr_batch = output->data.f;
+ // Initialize input_weights and recurrent_weights.
+ const float* input_weights_ptr = input_weights->data.f;
+ const float* recurrent_weights_ptr = recurrent_weights->data.f;
+
+ kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
+ recurrent_weights_ptr, bias_ptr, input_size,
+ num_units, batch_size, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index f540816235..aa24c1f34c 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -119,47 +120,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-namespace {
-// Performs one RNN computation step for the input specified by input_ptr_batch.
-// The RNN cell is specified by the pointers to its weights and biases, along
-// with the input size, number of units, strides, activation.
-// The pointers to the hidden state and the output are updated as a result.
-// TODO(mirkov): factor out this function to a shared library.
-void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr,
- const float* recurrent_weights_ptr, const float* bias_ptr,
- int input_size, int num_units, int input_weights_stride,
- int recurrent_weights_stride, TfLiteFusedActivation activation,
- float* hidden_state_ptr_batch, float* output_ptr_batch) {
- // Output = bias
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = bias_ptr[o];
- }
-
- // Output += input * input_weights
- for (int o = 0; o < num_units; o++) {
- for (int i = 0; i < input_size; i++) {
- output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
- }
- input_weights_ptr += input_weights_stride;
- }
-
- // Output += recurrent_weights * hidden_state
- for (int o = 0; o < num_units; o++) {
- for (int h = 0; h < num_units; h++) {
- output_ptr_batch[o] +=
- hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
- }
- recurrent_weights_ptr += recurrent_weights_stride;
- }
-
- // Output = activation(Output) and update hidden_state
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]);
- hidden_state_ptr_batch[o] = output_ptr_batch[o];
- }
-}
-} // namespace
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
@@ -189,15 +149,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const int input_size = input->dims->data[2];
const int fw_num_units = fw_input_weights->dims->data[0];
- const int fw_input_weights_stride = fw_input_weights->dims->data[1];
- const int fw_recurrent_weights_stride = fw_recurrent_weights->dims->data[1];
const float* fw_bias_ptr = fw_bias->data.f;
const float* fw_input_weights_ptr = fw_input_weights->data.f;
const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f;
const int bw_num_units = bw_input_weights->dims->data[0];
- const int bw_input_weights_stride = bw_input_weights->dims->data[1];
- const int bw_recurrent_weights_stride = bw_recurrent_weights->dims->data[1];
const float* bw_bias_ptr = bw_bias->data.f;
const float* bw_input_weights_ptr = bw_input_weights->data.f;
const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;
@@ -212,10 +168,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
float* output_ptr_batch =
fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
- RnnStep(input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
- fw_bias_ptr, input_size, fw_num_units, fw_input_weights_stride,
- fw_recurrent_weights_stride, params->activation,
- fw_hidden_state_ptr_batch, output_ptr_batch);
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
+ fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1,
+ params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
float* bw_hidden_state_ptr_batch =
@@ -226,10 +182,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
float* output_ptr_batch =
bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
- RnnStep(input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
- bw_bias_ptr, input_size, bw_num_units, bw_input_weights_stride,
- bw_recurrent_weights_stride, params->activation,
- bw_hidden_state_ptr_batch, output_ptr_batch);
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
+ bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1,
+ params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
}
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 288f1f8bbc..4691a543e9 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -291,6 +291,16 @@ cc_library(
)
cc_library(
+ name = "kernel_utils",
+ srcs = ["kernel_utils.cc"],
+ hdrs = ["kernel_utils.h"],
+ deps = [
+ ":tensor_utils",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ],
+)
+
+cc_library(
name = "tensor_utils",
srcs = [
"tensor_utils.cc",
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
new file mode 100644
index 0000000000..510395126c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+
+namespace tflite {
+namespace kernel_utils {
+
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int num_units, int batch_size,
+ TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ // Output = bias
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
+ output_ptr_batch);
+ // Output += input * input_weights
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
+ output_ptr_batch, /*result_stride=*/1);
+ // Output += recurrent_weights * hidden_state
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
+ batch_size, output_ptr_batch, /*result_stride=*/1);
+ // Output = activation(Output) and update hidden_state
+ tensor_utils::ApplyActivationToVector(
+ output_ptr_batch, num_units * batch_size, activation, output_ptr_batch);
+ tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size,
+ hidden_state_ptr_batch);
+}
+
+} // namespace kernel_utils
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
new file mode 100644
index 0000000000..9872d4500b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace kernel_utils {
+
+// Performs an RNN batch inference step for inputs specified by input_ptr_batch.
+// The RNN cell is specified by the pointers to its input and recurrent weights,
+// and biases, along with the input size, number of units, activation.
+//
+// The pointers to the hidden state and the output are updated as a result.
+//
+// The pointers with the suffix "_batch" point to data aligned in batch_major
+// order, and each step processes batch_size many inputs from input_ptr_batch,
+// and updates batch_size many outputs and hidden states.
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int num_units, int batch_size,
+ TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
+} // namespace kernel_utils
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 7ce87e4deb..ac00c37b67 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -88,42 +89,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-namespace {
-void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr,
- const float* recurrent_weights_ptr, const float* bias_ptr,
- int input_size, int num_units, int input_weights_stride,
- int recurrent_weights_stride, TfLiteFusedActivation activation,
- float* hidden_state_ptr_batch, float* output_ptr_batch) {
- // Output = bias
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = bias_ptr[o];
- }
-
- // Output += input * input_weights
- for (int o = 0; o < num_units; o++) {
- for (int i = 0; i < input_size; i++) {
- output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
- }
- input_weights_ptr += input_weights_stride;
- }
-
- // Output += recurrent_weights * hidden_state
- for (int o = 0; o < num_units; o++) {
- for (int h = 0; h < num_units; h++) {
- output_ptr_batch[o] +=
- hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
- }
- recurrent_weights_ptr += recurrent_weights_stride;
- }
-
- // Output = activation(Output) and update hidden_state
- for (int o = 0; o < num_units; o++) {
- output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]);
- hidden_state_ptr_batch[o] = output_ptr_batch[o];
- }
-}
-} // namespace
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
@@ -147,30 +112,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
(time_major) ? input->dims->data[0] : input->dims->data[1];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[2];
- const int input_weights_stride = input_weights->dims->data[1];
- const int recurrent_weights_stride = recurrent_weights->dims->data[1];
// Initialize input_weights and recurrent_weights.
const float* input_weights_ptr = input_weights->data.f;
const float* recurrent_weights_ptr = recurrent_weights->data.f;
if (time_major) {
- // Unroll the sequence
+ // Initialize the pointer to hidden state.
+ float* hidden_state_ptr_batch = hidden_state->data.f;
+    // Unroll the sequence and use batch operations for efficiency.
for (int s = 0; s < max_time; s++) {
- for (int b = 0; b < batch_size; b++) {
- // Initialize the pointer to hidden state.
- float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
- // Initialize the pointer to input and output.
- const float* input_ptr_batch =
- input->data.f + s * input_size * batch_size + b * input_size;
- float* output_ptr_batch =
- output->data.f + s * num_units * batch_size + b * num_units;
-
- RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
- bias_ptr, input_size, num_units, input_weights_stride,
- recurrent_weights_stride, params->activation,
- hidden_state_ptr_batch, output_ptr_batch);
- }
+ // Initialize the pointer to input and output.
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size;
+ float* output_ptr_batch = output->data.f + s * num_units * batch_size;
+
+ kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
+ recurrent_weights_ptr, bias_ptr, input_size,
+ num_units, batch_size, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
}
} else {
// For each batch
@@ -184,10 +144,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
float* output_ptr_batch =
output->data.f + b * num_units * max_time + s * num_units;
- RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
- bias_ptr, input_size, num_units, input_weights_stride,
- recurrent_weights_stride, params->activation,
- hidden_state_ptr_batch, output_ptr_batch);
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
+ input_size, num_units, /*batch_size=*/1, params->activation,
+ hidden_state_ptr_batch, output_ptr_batch);
}
}
}