diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-06-01 16:27:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-01 16:30:28 -0700 |
commit | b31498a054d55ce328a2820fd403af764c482500 (patch) | |
tree | 91b8513149a36ae042e2a1b51f9e284701bbdcec /tensorflow/contrib/lite/kernels/lstm.cc | |
parent | 73ec24e8b75ba4f73a06756502d8bf86b2a6828b (diff) |
Support 5-inputs LSTM kernel in TFLite (float only).
PiperOrigin-RevId: 198943559
Diffstat (limited to 'tensorflow/contrib/lite/kernels/lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/lstm.cc | 190 |
1 files changed, 181 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 990b3da055..9aae3e571b 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -34,6 +36,17 @@ namespace ops { namespace builtin { namespace lstm { +struct OpData { + // Which kernel type to use. Full kernel (18-inputs) or basic kernel + // (5-inputs). + TfLiteLSTMKernelType kernel_type; + // Only used by full kernel. + int scratch_tensor_index; +}; + +// For full inputs kernel (18-inputs). +namespace full { + // Input Tensors of size {n_batch, n_input} constexpr int kInputTensor = 0; @@ -71,13 +84,10 @@ constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast<int*>(buffer); + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMFullKernel; + context->AddTensors(context, 1, &op_data->scratch_tensor_index); + return op_data; } // Check that input tensor dimensions matches with each other. @@ -233,7 +243,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); + OpData* op_data = reinterpret_cast<OpData*>(node->user_data); // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); @@ -289,7 +299,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Create a scratch buffer tensor. TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -447,6 +457,168 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +} // namespace full + +// For basic kernel (5-inputs). +namespace basic { + +enum InputTensor { + kInputData = 0, + kInputPrevActivation = 1, + kInputWeights = 2, + kInputBiases = 3, + kInputPrevState = 4, + kInputNum = 5, +}; + +enum OutputTensor { + kOutputActivation = 0, + kOutputState = 1, + kOutputConcatTemp = 2, + kOutputActivationTemp = 3, + kOutputNum = 4, +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMBasicKernel; + // `scratch_tensor_index` is unused in this kernel. + op_data->scratch_tensor_index = -1; + return op_data; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, node->inputs->size == kInputNum); + TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); + + // Only Float32 is supportted currently. + // TODO(ycling): Implement quantize uint8 support. + for (int index = 0; index < node->inputs->size; ++index) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32); + } + + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TF_LITE_ENSURE_EQ(context, input->dims->size, 2); + const int num_batches = input->dims->data[0]; + + TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); + + TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + + TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + TF_LITE_ENSURE_OK(context, context->ResizeTensor( + context, activation_out, + TfLiteIntArrayCopy(prev_activation->dims))); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, state_out, + TfLiteIntArrayCopy(prev_state->dims))); + TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); + concat_temp_size->data[0] = num_batches; + concat_temp_size->data[1] = weights->dims->data[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, concat_temp, concat_temp_size)); + TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); + activation_temp_size->data[0] = num_batches; + activation_temp_size->data[1] = weights->dims->data[0]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, + activation_temp_size)); + + // Set the state tensors as persistent. + for (auto index : {kInputPrevActivation, kInputPrevState}) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + tensor->allocation_type = kTfLiteArenaRwPersistent; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + optimized_ops::LstmCell( + // Inputs. + GetTensorData<float>(input), GetTensorDims(input), + GetTensorData<float>(prev_activation), GetTensorDims(prev_activation), + GetTensorData<float>(weights), GetTensorDims(weights), + GetTensorData<float>(bias), GetTensorDims(bias), + GetTensorData<float>(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData<float>(state_out), GetTensorDims(state_out), + GetTensorData<float>(activation_out), GetTensorDims(activation_out), + GetTensorData<float>(concat_temp), GetTensorDims(concat_temp), + GetTensorData<float>(activation_temp), GetTensorDims(activation_temp)); + + // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs + // LSTM kernel. + memcpy(prev_activation->data.raw, activation_out->data.raw, + activation_out->bytes); + memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes); + + return kTfLiteOk; +} + +} // namespace basic + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer); + switch (params->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Init(context, buffer, length); + case kTfLiteLSTMBasicKernel: + return basic::Init(context, buffer, length); + } +} +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast<const OpData*>(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Prepare(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Prepare(context, node); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast<const OpData*>(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Eval(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Eval(context, node); + } +} + } // namespace lstm TfLiteRegistration* Register_LSTM() { |