aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/lstm.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-06-01 16:27:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 16:30:28 -0700
commitb31498a054d55ce328a2820fd403af764c482500 (patch)
tree91b8513149a36ae042e2a1b51f9e284701bbdcec /tensorflow/contrib/lite/kernels/lstm.cc
parent73ec24e8b75ba4f73a06756502d8bf86b2a6828b (diff)
Support 5-inputs LSTM kernel in TFLite (float only).
PiperOrigin-RevId: 198943559
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc190
1 file changed, 181 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 990b3da055..9aae3e571b 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -34,6 +36,17 @@ namespace ops {
namespace builtin {
namespace lstm {
+// Per-node state shared by both LSTM kernel variants; allocated in Init()
+// and released in Free().
+struct OpData {
+  // Which kernel type to use. Full kernel (18-inputs) or basic kernel
+  // (5-inputs).
+  TfLiteLSTMKernelType kernel_type;
+  // Index of the scratch tensor registered via context->AddTensors().
+  // Only used by full kernel.
+  int scratch_tensor_index;
+};
+
+// For full inputs kernel (18-inputs).
+namespace full {
+
// Input Tensors of size {n_batch, n_input}
constexpr int kInputTensor = 0;
@@ -71,13 +84,10 @@ constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  auto* scratch_tensor_index = new int;
-  context->AddTensors(context, 1, scratch_tensor_index);
-  return scratch_tensor_index;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
-  delete reinterpret_cast<int*>(buffer);
+  // Allocate per-node state; ownership passes to the runtime, which
+  // releases it through lstm::Free().
+  auto* op_data = new OpData;
+  op_data->kernel_type = kTfLiteLSTMFullKernel;
+  // Reserve one scratch tensor; the runtime writes its index into
+  // scratch_tensor_index, which Prepare() places in node->temporaries.
+  context->AddTensors(context, 1, &op_data->scratch_tensor_index);
+  return op_data;
 }
// Check that input tensor dimensions matches with each other.
@@ -233,7 +243,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
@@ -289,7 +299,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Create a scratch buffer tensor.
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[0] = *scratch_tensor_index;
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -447,6 +457,168 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+} // namespace full
+
+// For basic kernel (5-inputs).
+namespace basic {
+
+// Tensor indices for the basic (5-input) LSTM kernel.
+enum InputTensor {
+  kInputData = 0,
+  kInputPrevActivation = 1,  // Persistent: previous step's activation.
+  kInputWeights = 2,
+  kInputBiases = 3,
+  kInputPrevState = 4,       // Persistent: previous step's cell state.
+  kInputNum = 5,             // Number of input tensors.
+};
+
+// Output tensor indices. The last two are temporaries passed to
+// optimized_ops::LstmCell rather than real outputs.
+enum OutputTensor {
+  kOutputActivation = 0,
+  kOutputState = 1,
+  kOutputConcatTemp = 2,
+  kOutputActivationTemp = 3,
+  kOutputNum = 4,            // Number of output tensors.
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+ op_data->kernel_type = kTfLiteLSTMBasicKernel;
+ // `scratch_tensor_index` is unused in this kernel.
+ op_data->scratch_tensor_index = -1;
+ return op_data;
+}
+
+// Validates the 5 inputs / 4 outputs of the basic kernel, resizes the
+// output and temporary tensors, and marks the prev-activation/prev-state
+// inputs persistent so they survive between invocations.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
+  TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
+
+  // Only Float32 is supported currently.
+  // TODO(ycling): Implement quantize uint8 support.
+  for (int index = 0; index < node->inputs->size; ++index) {
+    TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+    TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32);
+  }
+
+  const TfLiteTensor* input = GetInput(context, node, kInputData);
+  const TfLiteTensor* prev_activation =
+      GetInput(context, node, kInputPrevActivation);
+  const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+  const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+  const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+  // All 2-D tensors must agree on the batch dimension (dims->data[0]).
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
+  const int num_batches = input->dims->data[0];
+
+  TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
+
+  TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
+
+  TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
+
+  TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+  TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+  TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+  TfLiteTensor* activation_temp =
+      GetOutput(context, node, kOutputActivationTemp);
+
+  // Outputs mirror the shapes of the corresponding persistent inputs.
+  // Note: ResizeTensor takes ownership of the TfLiteIntArray passed in,
+  // which is why none of the arrays created below are freed here.
+  TF_LITE_ENSURE_OK(context, context->ResizeTensor(
+                                 context, activation_out,
+                                 TfLiteIntArrayCopy(prev_activation->dims)));
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(context, state_out,
+                                     TfLiteIntArrayCopy(prev_state->dims)));
+  // Temporaries are sized from the weights: {batches, weights columns} for
+  // the concatenated input, {batches, weights rows} for the gate outputs.
+  TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
+  concat_temp_size->data[0] = num_batches;
+  concat_temp_size->data[1] = weights->dims->data[1];
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(context, concat_temp, concat_temp_size));
+  TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
+  activation_temp_size->data[0] = num_batches;
+  activation_temp_size->data[1] = weights->dims->data[0];
+  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
+                                                   activation_temp_size));
+
+  // Set the state tensors as persistent, so Eval() can write the new
+  // state back into them between invocations.
+  for (auto index : {kInputPrevActivation, kInputPrevState}) {
+    TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+    tensor->allocation_type = kTfLiteArenaRwPersistent;
+  }
+  return kTfLiteOk;
+}
+
+// Runs one step of the basic LSTM cell via optimized_ops::LstmCell, then
+// copies the freshly computed activation/state back into the persistent
+// "prev_*" input tensors so the next invocation sees the updated state.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputData);
+  const TfLiteTensor* prev_activation =
+      GetInput(context, node, kInputPrevActivation);
+  const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+  const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+  const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+  TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+  TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+  TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+  TfLiteTensor* activation_temp =
+      GetOutput(context, node, kOutputActivationTemp);
+
+  // Float-only path; Prepare() already enforced kTfLiteFloat32 inputs.
+  optimized_ops::LstmCell(
+      // Inputs.
+      GetTensorData<float>(input), GetTensorDims(input),
+      GetTensorData<float>(prev_activation), GetTensorDims(prev_activation),
+      GetTensorData<float>(weights), GetTensorDims(weights),
+      GetTensorData<float>(bias), GetTensorDims(bias),
+      GetTensorData<float>(prev_state), GetTensorDims(prev_state),
+      // Outputs.
+      GetTensorData<float>(state_out), GetTensorDims(state_out),
+      GetTensorData<float>(activation_out), GetTensorDims(activation_out),
+      GetTensorData<float>(concat_temp), GetTensorDims(concat_temp),
+      GetTensorData<float>(activation_temp), GetTensorDims(activation_temp));
+
+  // Write the new state back into the persistent input tensors (marked
+  // kTfLiteArenaRwPersistent in Prepare()).
+  // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs
+  // LSTM kernel.
+  memcpy(prev_activation->data.raw, activation_out->data.raw,
+         activation_out->bytes);
+  memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);
+
+  return kTfLiteOk;
+}
+
+} // namespace basic
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
+ switch (params->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Init(context, buffer, length);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Init(context, buffer, length);
+ }
+}
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Prepare(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Prepare(context, node);
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Eval(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Eval(context, node);
+ }
+}
+
} // namespace lstm
TfLiteRegistration* Register_LSTM() {