path: root/tensorflow/contrib/lite/kernels/lstm.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-04-30 18:05:37 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-04-30 18:08:13 -0700
commit 85d30bfcf412bd1ca06fa33548344bf40eedb4ac (patch)
tree   5201c3fe5d1da4e0dd4379df588fd5216377b11a /tensorflow/contrib/lite/kernels/lstm.cc
parent 40721422bfc9cec546537799e16dd75f443d2db2 (diff)
Internal change.
PiperOrigin-RevId: 194877173
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/lstm.cc | 49
1 file changed, 33 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 8cf1165135..668226e674 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -66,10 +66,19 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
constexpr int kProjectionBiasTensor = 17; // Optional
// Output tensors.
-constexpr int kScratchBufferTensor = 0;
-constexpr int kOutputStateTensor = 1;
-constexpr int kCellStateTensor = 2;
-constexpr int kOutputTensor = 3;
+constexpr int kOutputStateTensor = 0;
+constexpr int kCellStateTensor = 1;
+constexpr int kOutputTensor = 2;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
@@ -220,12 +229,15 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
-// tensors. Also check that the size of the input tensors match each other.
+// Resize the output, state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
@@ -250,15 +262,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
@@ -271,13 +280,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state, output_state_size));
- // Resize the output, state and scratch buffer tensors.
TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
cell_size->data[0] = n_batch;
cell_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
output_state->allocation_type = kTfLiteArenaRwPersistent;
cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -362,7 +378,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_peephole = (cell_to_output_weights != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -433,8 +450,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace lstm
TfLiteRegistration* Register_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- lstm::Prepare, lstm::Eval};
+ static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
+ lstm::Eval};
return &r;
}
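
The change above moves the LSTM kernel's scratch buffer out of the op's outputs (the old kScratchBufferTensor) and into a node temporary: Init reserves one tensor slot with context->AddTensors, Free releases the stored index, Prepare publishes the slot through node->temporaries with allocation_type kTfLiteArenaRw, and Eval looks the buffer up through node->temporaries instead of GetOutput. The sketch below illustrates that same pattern for a minimal hypothetical op; the names my_op and Register_MY_OP and the input-shaped scratch tensor are illustrative assumptions, not part of this commit or of lstm.cc.

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"

namespace my_op {  // Hypothetical op, for illustration only.

// Reserve one extra tensor slot with the interpreter and remember its index.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  context->AddTensors(context, /*tensors_to_add=*/1, scratch_tensor_index);
  return scratch_tensor_index;
}

// Release the index allocated in Init.
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
  const TfLiteTensor* input = GetInput(context, node, /*index=*/0);

  // Publish the reserved slot as this node's only temporary so the arena
  // planner can manage (and reuse) its memory.
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(1);
  node->temporaries->data[0] = *scratch_tensor_index;

  TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  // Give the scratch tensor a concrete shape; for this sketch it simply
  // mirrors the input shape.
  TfLiteIntArray* scratch_size = TfLiteIntArrayCreate(input->dims->size);
  for (int i = 0; i < input->dims->size; ++i) {
    scratch_size->data[i] = input->dims->data[i];
  }
  return context->ResizeTensor(context, scratch_buffer, scratch_size);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // Fetch the scratch buffer through node->temporaries instead of an output.
  TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
  float* scratch = scratch_buffer->data.f;  // Assumes a float32 scratch type.
  (void)scratch;  // ... write intermediate results into the scratch area ...
  return kTfLiteOk;
}

}  // namespace my_op

TfLiteRegistration* Register_MY_OP() {
  static TfLiteRegistration r = {my_op::Init, my_op::Free, my_op::Prepare,
                                 my_op::Eval};
  return &r;
}

Because the scratch buffer is no longer an output, callers now supply 3 output tensors instead of 4, and kTfLiteArenaRw lets the memory planner reclaim the scratch space after the node runs, while output_state and cell_state stay kTfLiteArenaRwPersistent so their contents survive across invocations.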