diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-30 18:05:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-30 18:08:13 -0700 |
commit | 85d30bfcf412bd1ca06fa33548344bf40eedb4ac (patch) | |
tree | 5201c3fe5d1da4e0dd4379df588fd5216377b11a /tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | |
parent | 40721422bfc9cec546537799e16dd75f443d2db2 (diff) |
Internal change.
PiperOrigin-RevId: 194877173
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 47 |
1 files changed, 32 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 42941a97db..3c1256d3a6 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -66,10 +66,19 @@ constexpr int kProjectionWeightsTensor = 16; // Optional constexpr int kProjectionBiasTensor = 17; // Optional // Output tensors. -constexpr int kScratchBufferTensor = 0; -constexpr int kOutputStateTensor = 1; -constexpr int kCellStateTensor = 2; -constexpr int kOutputTensor = 3; +constexpr int kOutputStateTensor = 0; +constexpr int kCellStateTensor = 1; +constexpr int kOutputTensor = 2; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, 1, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<int*>(buffer); +} // Check that input tensor dimensions matches with each other. TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, @@ -220,12 +229,15 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, return kTfLiteOk; } -// Resize the output, state and scratch tensors based on the sizes of the input -// tensors. Also check that the size of the input tensors match each other. +// Resize the output and state tensors based on the sizes of the input tensors. +// Allocate a temprory scratch tensor. Also check that the sizes of the input +// tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); + // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. @@ -251,15 +263,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check that input tensor dimensions matches with each other. CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); - // Get the pointer to output, state and scratch buffer tensors. + // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - // TODO(ghodrat): Modify this as soon as we have a finalized method for - // scratch buffers. - TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); - // Resize the output and output_state tensors. + // Resize the output, output_state and cell_state tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); output_size->data[0] = max_time; output_size->data[1] = n_batch; @@ -273,13 +282,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, output_state, output_state_size)); - // Resize the scratch buffer tensor. TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); cell_size->data[0] = n_batch; cell_size->data[1] = n_cell; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, cell_state, cell_size)); + // Create a scratch buffer tensor. + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]]; + scratch_buffer->type = input->type; + scratch_buffer->allocation_type = kTfLiteArenaRw; + // Mark state tensors as persistent tensors. output_state->allocation_type = kTfLiteArenaRwPersistent; cell_state->allocation_type = kTfLiteArenaRwPersistent; @@ -365,7 +381,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_peephole = (cell_to_output_weights != nullptr); // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]]; float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -439,7 +455,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace unidirectional_sequence_lstm TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + static TfLiteRegistration r = {unidirectional_sequence_lstm::Init, + unidirectional_sequence_lstm::Free, unidirectional_sequence_lstm::Prepare, unidirectional_sequence_lstm::Eval}; return &r; |