diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 70 |
1 file changed, 43 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a64ac42bc4..3ac0210f36 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -96,15 +96,23 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
 constexpr int kBwProjectionBiasTensor = 34; // Optional
 
 // Output tensors.
-constexpr int kFwScratchBufferTensor = 0;
-constexpr int kFwOutputStateTensor = 1;
-constexpr int kFwCellStateTensor = 2;
-constexpr int kFwOutputTensor = 3;
+constexpr int kFwOutputStateTensor = 0;
+constexpr int kFwCellStateTensor = 1;
+constexpr int kFwOutputTensor = 2;
+
+constexpr int kBwOutputStateTensor = 3;
+constexpr int kBwCellStateTensor = 4;
+constexpr int kBwOutputTensor = 5;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  auto* scratch_tensor_index = new int;
+  context->AddTensors(context, 2, scratch_tensor_index);
+  return scratch_tensor_index;
+}
 
-constexpr int kBwScratchBufferTensor = 4;
-constexpr int kBwOutputStateTensor = 5;
-constexpr int kBwCellStateTensor = 6;
-constexpr int kBwOutputTensor = 7;
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<int*>(buffer);
+}
 
 // Check that input tensor dimensions matches with each other.
 TfLiteStatus CheckLstmTensorDimensions(
@@ -296,9 +304,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
 // Resize the output, state and scratch tensors based on the sizes of the input
 // tensors. Also check that the size of the input tensors match each other.
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
   // Check we have all the inputs and outputs we need.
   TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 8);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
 
   // Inferring batch size, number of outputs and sequence length and
   // number of cells from the input tensors.
@@ -330,12 +340,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* fw_output_state =
       GetOutput(context, node, kFwOutputStateTensor);
   TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
-  // TODO(ghodrat): Modify this as soon as we have a finalized method for
-  // scratch buffers.
-  TfLiteTensor* fw_scratch_buffer =
-      GetOutput(context, node, kFwScratchBufferTensor);
 
-  // Resize the output and output_state tensors.
+  // Resize the output, output_state and cell_state tensors.
   TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
   fw_output_size->data[0] = max_time;
   fw_output_size->data[1] = n_batch;
@@ -349,13 +355,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
                                                    fw_output_state_size));
 
-  // Resize the scratch buffer tensor.
   TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
   fw_cell_size->data[0] = n_batch;
   fw_cell_size->data[1] = n_fw_cell;
   TF_LITE_ENSURE_OK(
       context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
 
+  // Create a scratch buffer tensor.
+  TfLiteIntArrayFree(node->temporaries);
+  node->temporaries = TfLiteIntArrayCreate(2);
+  node->temporaries->data[0] = *scratch_tensor_index;
+  TfLiteTensor* fw_scratch_buffer =
+      &context->tensors[node->temporaries->data[0]];
+  fw_scratch_buffer->type = input->type;
+  fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
   // Mark state tensors as persistent tensors.
   fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
   fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -392,17 +406,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Check that input tensor dimensions matches with each other.
   CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
 
-  // Get the pointer to output, state and scratch buffer tensors.
+  // Get the pointer to output, output_state and cell_state buffer tensors.
   TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
   TfLiteTensor* bw_output_state =
       GetOutput(context, node, kBwOutputStateTensor);
   TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
-  // TODO(ghodrat): Modify this as soon as we have a finalized method for
-  // scratch buffers.
-  TfLiteTensor* bw_scratch_buffer =
-      GetOutput(context, node, kBwScratchBufferTensor);
 
-  // Resize the output and output_state tensors.
+  // Resize the output, output_state and cell_state tensors.
   TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
   bw_output_size->data[0] = max_time;
   bw_output_size->data[1] = n_batch;
@@ -416,13 +426,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
                                                    bw_output_state_size));
 
-  // Resize the scratch buffer tensor.
   TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
   bw_cell_size->data[0] = n_batch;
   bw_cell_size->data[1] = n_bw_cell;
   TF_LITE_ENSURE_OK(
       context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
 
+  // Create a scratch buffer tensor.
+  node->temporaries->data[1] = *(scratch_tensor_index) + 1;
+  TfLiteTensor* bw_scratch_buffer =
+      &context->tensors[node->temporaries->data[1]];
+  bw_scratch_buffer->type = input->type;
+  bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
   // Mark state tensors as persistent tensors.
   bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
   bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -553,7 +569,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   // Index the scratch buffers pointers to the global scratch buffer.
   TfLiteTensor* fw_scratch_buffer =
-      GetOutput(context, node, kFwScratchBufferTensor);
+      &context->tensors[node->temporaries->data[0]];
   float* fw_input_gate_scratch = nullptr;
   float* fw_cell_scratch = nullptr;
   float* fw_forget_gate_scratch = nullptr;
@@ -624,7 +640,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   // Index the scratch buffers pointers to the global scratch buffer.
   TfLiteTensor* bw_scratch_buffer =
-      GetOutput(context, node, kBwScratchBufferTensor);
+      &context->tensors[node->temporaries->data[1]];
   float* bw_input_gate_scratch = nullptr;
   float* bw_cell_scratch = nullptr;
   float* bw_forget_gate_scratch = nullptr;
@@ -691,9 +707,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 } // namespace bidirectional_sequence_lstm
 
 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
-  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
-                                 bidirectional_sequence_lstm::Prepare,
-                                 bidirectional_sequence_lstm::Eval};
+  static TfLiteRegistration r = {
+      bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
+      bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
   return &r;
 }
 