aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc70
1 files changed, 43 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a64ac42bc4..3ac0210f36 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -96,15 +96,23 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
constexpr int kBwProjectionBiasTensor = 34; // Optional
// Output tensors.
-constexpr int kFwScratchBufferTensor = 0;
-constexpr int kFwOutputStateTensor = 1;
-constexpr int kFwCellStateTensor = 2;
-constexpr int kFwOutputTensor = 3;
+constexpr int kFwOutputStateTensor = 0;
+constexpr int kFwCellStateTensor = 1;
+constexpr int kFwOutputTensor = 2;
+
+constexpr int kBwOutputStateTensor = 3;
+constexpr int kBwCellStateTensor = 4;
+constexpr int kBwOutputTensor = 5;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 2, scratch_tensor_index);
+ return scratch_tensor_index;
+}
-constexpr int kBwScratchBufferTensor = 4;
-constexpr int kBwOutputStateTensor = 5;
-constexpr int kBwCellStateTensor = 6;
-constexpr int kBwOutputTensor = 7;
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckLstmTensorDimensions(
@@ -296,9 +304,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Resize the output, state and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 8);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -330,12 +340,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* fw_output_state =
GetOutput(context, node, kFwOutputStateTensor);
TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
@@ -349,13 +355,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
fw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
fw_cell_size->data[0] = n_batch;
fw_cell_size->data[1] = n_fw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ &context->tensors[node->temporaries->data[0]];
+ fw_scratch_buffer->type = input->type;
+ fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -392,17 +406,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
TfLiteTensor* bw_output_state =
GetOutput(context, node, kBwOutputStateTensor);
TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = max_time;
bw_output_size->data[1] = n_batch;
@@ -416,13 +426,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
bw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
bw_cell_size->data[0] = n_batch;
bw_cell_size->data[1] = n_bw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ // Create a scratch buffer tensor.
+ node->temporaries->data[1] = *(scratch_tensor_index) + 1;
+ TfLiteTensor* bw_scratch_buffer =
+ &context->tensors[node->temporaries->data[1]];
+ bw_scratch_buffer->type = input->type;
+ bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -553,7 +569,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[0]];
float* fw_input_gate_scratch = nullptr;
float* fw_cell_scratch = nullptr;
float* fw_forget_gate_scratch = nullptr;
@@ -624,7 +640,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[1]];
float* bw_input_gate_scratch = nullptr;
float* bw_cell_scratch = nullptr;
float* bw_forget_gate_scratch = nullptr;
@@ -691,9 +707,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace bidirectional_sequence_lstm
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- bidirectional_sequence_lstm::Prepare,
- bidirectional_sequence_lstm::Eval};
+ static TfLiteRegistration r = {
+ bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
+ bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
return &r;
}