From 5fab6df2788937bee1cce3a4e8f5b9d1db7497ec Mon Sep 17 00:00:00 2001
From: Yu-Cheng Ling
Date: Tue, 19 Jun 2018 12:35:44 -0700
Subject: Support Variable Tensor API in LSTM Full kernel.

TFLite LSTM now supports the 5-input, 18-input, and 20-input configurations.

PiperOrigin-RevId: 201222516
---
 tensorflow/contrib/lite/kernels/lstm.cc           | 161 ++++++++++++++-------
 tensorflow/contrib/lite/kernels/lstm_test.cc      |   8 +
 .../contrib/lite/kernels/optional_tensor_test.cc  |   8 +
 tensorflow/contrib/lite/kernels/test_util.cc      |   5 +-
 tensorflow/contrib/lite/kernels/test_util.h       |  11 +-
 tensorflow/contrib/lite/testing/tflite_driver.cc  |   6 +-
 .../identify_lstm_split_inputs.cc                 |  10 +-
 .../lite/toco/graph_transformations/lstm_utils.h  |   6 +-
 tensorflow/contrib/lite/toco/tflite/BUILD         |   1 +
 tensorflow/contrib/lite/toco/tflite/operator.cc   |  17 ++-
 10 files changed, 158 insertions(+), 75 deletions(-)

(limited to 'tensorflow/contrib')

diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index eb26a02455..1dda97c101 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -37,14 +37,17 @@ namespace builtin {
 namespace lstm {

 struct OpData {
-  // Which kernel type to use. Full kernel (18-inputs) or basic kernel
-  // (5-inputs).
+  // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel
+  // (5 inputs).
   TfLiteLSTMKernelType kernel_type;
-  // Only used by full kernel.
+
+  // These fields are only used by full kernel.
+  int activation_state_tensor_index;
+  int cell_state_tensor_index;
   int scratch_tensor_index;
 };

-// For full inputs kernel (18-inputs).
+// For full inputs kernel (18 or 20 inputs).
 namespace full {

 // Input Tensors of size {n_batch, n_input}
@@ -78,7 +81,16 @@ constexpr int kProjectionWeightsTensor = 16;  // Optional
 // Projection bias tensor of size {n_output}
 constexpr int kProjectionBiasTensor = 17;  // Optional

+// If the node has 20 inputs, the following 2 tensors are used as state tensors.
+// These are defined as variable tensors, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 18;
+constexpr int kInputCellStateTensor = 19;
+
 // Output tensors.
+// * If the node has 18 inputs, these 2 tensors are used as state tensors.
+// * If the node has 20 inputs, these 2 tensors are ignored.
+// TODO(ycling): Make the 2 output state tensors optional, and propagate the
+// state to output tensors when the 2 tensors are present.
 constexpr int kOutputStateTensor = 0;
 constexpr int kCellStateTensor = 1;
 constexpr int kOutputTensor = 2;
@@ -246,10 +258,31 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

-  // Check we have all the inputs and outputs we need.
-  TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);

+  // True if the node is using input variable state tensors. It means:
+  // * The state tensors are defined as inputs. In this case they would be the
+  //   19th and 20th input tensors.
+  // * Otherwise, the output tensors are used to store states.
+  bool use_input_variable_states;
+  if (node->inputs->size == 20) {
+    use_input_variable_states = true;
+    op_data->activation_state_tensor_index =
+        node->inputs->data[kInputActivationStateTensor];
+    op_data->cell_state_tensor_index =
+        node->inputs->data[kInputCellStateTensor];
+  } else if (node->inputs->size == 18) {
+    use_input_variable_states = false;
+    op_data->activation_state_tensor_index =
+        node->outputs->data[kOutputStateTensor];
+    op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
+  } else {
+    context->ReportError(
+        context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs",
+        node->inputs->size);
+    return kTfLiteError;
+  }
+
   // Inferring batch size, number of outputs and number of cells from the
   // input tensors.
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -274,34 +307,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Check that input tensor dimensions matches with each other.
   CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);

-  // Get the pointer to output, output_state and cell_state tensors.
+  // Get the pointer to output, activation_state and cell_state tensors.
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
-  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);

-  // Resize the output, output_state and cell_state tensors.
+  TfLiteTensor* activation_state =
+      &context->tensors[op_data->activation_state_tensor_index];
+  TfLiteTensor* cell_state =
+      &context->tensors[op_data->cell_state_tensor_index];
+
+  if (use_input_variable_states) {
+    // Check the shape of input state tensors.
+    // These tensors may be 1D or 2D. It's fine as long as the total size is
+    // correct.
+    TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
+                      n_batch * n_output);
+    TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+  } else {
+    // If the state tensors are outputs, this function is responsible for
+    // resizing the state tensors.
+    TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2);
+    activation_state_size->data[0] = n_batch;
+    activation_state_size->data[1] = n_output;
+    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
+                                                     activation_state_size));
+
+    TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+    cell_size->data[0] = n_batch;
+    cell_size->data[1] = n_cell;
+    TF_LITE_ENSURE_OK(context,
+                      context->ResizeTensor(context, cell_state, cell_size));
+    // Mark state tensors as persistent tensors.
+    activation_state->allocation_type = kTfLiteArenaRwPersistent;
+    cell_state->allocation_type = kTfLiteArenaRwPersistent;
+  }
+
+  // Resize the output tensors.
   TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
   output_size->data[0] = n_batch;
   output_size->data[1] = n_output;
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, output, output_size));

-  TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
-  output_state_size->data[0] = n_batch;
-  output_state_size->data[1] = n_output;
-  TF_LITE_ENSURE_OK(
-      context, context->ResizeTensor(context, output_state, output_state_size));
-
-  TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
-  cell_size->data[0] = n_batch;
-  cell_size->data[1] = n_cell;
-  TF_LITE_ENSURE_OK(context,
-                    context->ResizeTensor(context, cell_state, cell_size));
-
-  // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - // The weights are of consistent type, so it suffices to check one. // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && @@ -337,7 +383,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (is_hybrid_op) { // Allocate temporary tensors to store quantized values of input, - // output_state and cell_state tensors. + // activation_state and cell_state tensors. node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); input_quantized->type = kTfLiteUInt8; @@ -348,17 +394,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_quantized_size)); } node->temporaries->data[2] = op_data->scratch_tensor_index + 2; - TfLiteTensor* output_state_quantized = + TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); - output_state_quantized->type = kTfLiteUInt8; - output_state_quantized->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqual(output_state_quantized->dims, - output_state->dims)) { - TfLiteIntArray* output_state_quantized_size = - TfLiteIntArrayCopy(output_state->dims); - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, output_state_quantized, - output_state_quantized_size)); + activation_state_quantized->type = kTfLiteUInt8; + activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(activation_state_quantized->dims, + activation_state->dims)) { + TfLiteIntArray* activation_state_quantized_size = + TfLiteIntArrayCopy(activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, activation_state_quantized, + activation_state_quantized_size)); } node->temporaries->data[3] = op_data->scratch_tensor_index + 3; TfLiteTensor* cell_state_quantized = @@ -438,7 +484,7 @@ TfLiteStatus EvalFloat( const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* activation_state, TfLiteTensor* cell_state, TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; @@ -499,7 +545,7 @@ TfLiteStatus EvalFloat( const float* cell_bias_ptr = cell_bias->data.f; const float* output_gate_bias_ptr = output_gate_bias->data.f; - float* output_state_ptr = output_state->data.f; + float* activation_state_ptr = activation_state->data.f; float* cell_state_ptr = cell_state->data.f; float* output_ptr_batch = output->data.f; @@ -512,8 +558,8 @@ TfLiteStatus EvalFloat( cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_batch); + activation_state_ptr, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); return kTfLiteOk; } @@ -536,9 +582,9 @@ TfLiteStatus EvalHybrid( const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, 
     TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
-    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
-    TfLiteTensor* output_state, TfLiteTensor* cell_state,
-    TfLiteTensor* output) {
+    TfLiteTensor* activation_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
   const int n_batch = input->dims->data[0];
   const int n_input = input->dims->data[1];
   // n_cell and n_output will be the same size when there is no projection.
@@ -639,15 +685,15 @@ TfLiteStatus EvalHybrid(
   const float* cell_bias_ptr = cell_bias->data.f;
   const float* output_gate_bias_ptr = output_gate_bias->data.f;

-  float* output_state_ptr = output_state->data.f;
+  float* activation_state_ptr = activation_state->data.f;
   float* cell_state_ptr = cell_state->data.f;
   float* output_ptr_batch = output->data.f;

   // Temporary storage for quantized values and scaling factors.
   int8_t* quantized_input_ptr =
       reinterpret_cast<int8_t*>(input_quantized->data.uint8);
-  int8_t* quantized_output_state_ptr =
-      reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+  int8_t* quantized_activation_state_ptr =
+      reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
   int8_t* quantized_cell_state_ptr =
       reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
   float* scaling_factors_ptr = scaling_factors->data.f;
@@ -672,14 +718,16 @@ TfLiteStatus EvalHybrid(
       input_gate_scratch, forget_gate_scratch, cell_scratch,
       output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
       recovered_cell_weights_ptr, quantized_input_ptr,
-      quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr,
-      cell_state_ptr, output_ptr_batch);
+      quantized_activation_state_ptr, quantized_cell_state_ptr,
+      activation_state_ptr, cell_state_ptr, output_ptr_batch);

   return kTfLiteOk;
 }

 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);

   const TfLiteTensor* input_to_input_weights =
@@ -723,8 +771,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

   // Index the scratch buffers pointers to the global scratch buffer.
   TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
-  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
-  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+  TfLiteTensor* activation_state =
+      &context->tensors[op_data->activation_state_tensor_index];
+  TfLiteTensor* cell_state =
+      &context->tensors[op_data->cell_state_tensor_index];
+
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

   // TODO(mirkov): add a check that weights are all uint8s or all floats.
@@ -738,11 +789,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { cell_to_output_weights, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, projection_bias, params, - scratch_buffer, output_state, cell_state, output); + scratch_buffer, activation_state, cell_state, output); } case kTfLiteUInt8: { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* output_state_quantized = + TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); TfLiteTensor* cell_state_quantized = GetTemporary(context, node, /*index=*/3); @@ -760,8 +811,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, projection_bias, params, scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, - input_quantized, output_state_quantized, cell_state_quantized, - output_state, cell_state, output); + input_quantized, activation_state_quantized, cell_state_quantized, + activation_state, cell_state, output); } default: context->ReportError(context, "Type %d is not currently supported.", diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 6da29a4a92..3f5c44a63e 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -97,6 +97,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -227,6 +233,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index bcad58406a..1c728a4733 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -95,6 +95,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. 
+  input_activation_state_ =
+      AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+  input_cell_state_ =
+      AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
   output_state_ = AddOutput(TensorType_FLOAT32);
   cell_state_ = AddOutput(TensorType_FLOAT32);
   output_ = AddOutput(TensorType_FLOAT32);
@@ -228,6 +234,8 @@ class LSTMOpModel : public SingleOpModel {
   int projection_weights_;
   int projection_bias_;

+  int input_activation_state_;
+  int input_cell_state_;
   int output_;
   int output_state_;

diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index d23ec201b4..9156917140 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -32,8 +32,8 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
   return matchers;
 }

-int SingleOpModel::AddInput(const TensorData& t) {
-  int id = AddTensor<float>(t, {});
+int SingleOpModel::AddInput(const TensorData& t, bool is_variable) {
+  int id = AddTensor<float>(t, {}, is_variable);
   inputs_.push_back(id);
   return id;
 }
@@ -120,6 +120,7 @@ void SingleOpModel::BuildInterpreter(
   CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
       << "Cannot allocate tensors";
+  interpreter_->ResetVariableTensorsToZero();
 }

 void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index db80c0082c..6dcece4af6 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -126,8 +126,10 @@ class SingleOpModel {
   SingleOpModel& operator=(const SingleOpModel&) = delete;

   // Add a TensorType input tensor and return its index.
-  int AddInput(TensorType type) { return AddInput(TensorData{type}); }
-  int AddInput(const TensorData& t);
+  int AddInput(TensorType type, bool is_variable = false) {
+    return AddInput(TensorData{type}, is_variable);
+  }
+  int AddInput(const TensorData& t, bool is_variable = false);

   // Templated version of AddConstInput().
   template <typename T>
@@ -260,7 +262,8 @@ class SingleOpModel {
   }

   template <typename T>
-  int AddTensor(TensorData t, std::initializer_list<T> data) {
+  int AddTensor(TensorData t, std::initializer_list<T> data,
+                bool is_variable = false) {
     int id = tensors_.size();

     // This is slightly different depending on whether we are adding a
@@ -309,7 +312,7 @@ class SingleOpModel {
     tensors_.push_back(CreateTensor(builder_,
                                     builder_.CreateVector<int>(t.shape), t.type,
                                     /*buffer=*/buffer_id,
-                                    /*name=*/0, q_params));
+                                    /*name=*/0, q_params, is_variable));

     tensor_data_[id] = t;

diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 54edfdfb1d..4d08fb5458 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -288,8 +288,8 @@ void TfLiteDriver::ResetLSTMStateTensors() {
   interpreter_->ResetVariableTensorsToZero();

   // Below is a workaround for initializing state tensors for LSTM.
-  // TODO(ycling): Refactoring and find a better way to initialize state
-  // tensors. Maybe write the reset instructions into the test data.
+  // TODO(ycling): Remove the code below after nobody is using the 18-inputs
+  // definition.
   for (auto node_index : interpreter_->execution_plan()) {
     const auto& node_and_reg = interpreter_->node_and_registration(node_index);
     const auto& node = node_and_reg->first;
@@ -299,7 +299,7 @@
     const auto* params =
         reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
     if (params->kernel_type == kTfLiteLSTMFullKernel &&
-        node.outputs->size >= 2) {
+        node.inputs->size == 18 && node.outputs->size >= 2) {
       // The first 2 outputs of LSTM are state tensors.
       for (int i = 0; i < 2; ++i) {
         int node_index = node.outputs->data[i];
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index e6e3dfa1de..46d1fce50e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -74,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
   lstm_cell_op->inputs[kInputTensor] =
       curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT];

+  // Previous states.
+  lstm_cell_op->inputs[kInputActivationStateTensor] =
+      curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT];
+  lstm_cell_op->inputs[kInputCellStateTensor] =
+      curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT];
+
   // Get original weight tensor and decompose 1 tensor to 8 sub tensors.
   Array& kernel =
       model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
@@ -160,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
   // Erase curr lstm op being replaced.
   DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model);
   DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
-  DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT],
-                      model);
-  DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT],
-                      model);
   model->operators.erase(FindOp(*model, curr_op));

   return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
index 1c32a78169..6d8603a113 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
@@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs {
   kOutputGateBiasTensor = 15,
   kProjectionWeightsTensor = 16,  // Optional
   kProjectionBiasTensor = 17,     // Optional
-  kExtendedLstmInputCount = 18
+  kInputActivationStateTensor = 18,
+  // The op can handle 18 inputs or 20 inputs.
+  kInputCellStateTensor = 19,
+  kExtendedLstmInputCount = 20,
 };

 enum ExtendedLstmCellOutputs {
+  // TODO(ycling): Make the 2 output state tensors optional.
   kOutputStateTensor = 0,
   kCellStateTensor = 1,
   kOutputTensor = 2,
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index e1025c6664..a02f90988b 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -24,6 +24,7 @@ cc_library(
     deps = [
         ":types",
         "//tensorflow/contrib/lite/schema:schema_fbs",
+        "//tensorflow/contrib/lite/toco:graph_transformations",
        "//tensorflow/contrib/lite/toco:model",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/memory",
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 669fb9fa08..c93c0a6b90 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -14,6 +14,9 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/contrib/lite/toco/tflite/operator.h"

+// TODO(ycling): Consider refactoring to extract the LSTM definition out of
+// graph_transformation module.
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
 #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
 #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
 #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
@@ -673,18 +676,20 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
     const auto& lstm_op = static_cast<const LstmCellOperator&>(op);

+    std::vector<bool> mutating_input_variables(op.inputs.size(), false);
     switch (lstm_op.kernel_type) {
-      case LstmCellOperator::KERNEL_FULL:
-        // TODO(ycling): Change the full kernel to use the new variable tensor
-        // design. This requires moving the state tensors from output to input.
-        return std::vector<bool>();
+      case LstmCellOperator::KERNEL_FULL: {
+        mutating_input_variables[kInputActivationStateTensor] = true;
+        mutating_input_variables[kInputCellStateTensor] = true;
+        break;
+      }
       case LstmCellOperator::KERNEL_BASIC: {
-        std::vector<bool> mutating_input_variables(op.inputs.size(), false);
         mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
         mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
-        return mutating_input_variables;
+        break;
       }
     }
+    return mutating_input_variables;
   }
 };
-- 
cgit v1.2.3
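
A note on what the 20-input path enables, with a minimal usage sketch rather than part of the patch: once a model is converted with the 20-input LSTM definition, the activation and cell states live in variable tensors marked kTfLiteArenaRwPersistent, so they survive across Invoke() calls and streaming inference no longer needs to copy the state outputs back into the inputs by hand. In the sketch below, the file name "lstm.tflite", the tensor indices, and the step count are illustrative assumptions; the interpreter calls themselves (AllocateTensors, ResetVariableTensorsToZero, Invoke, typed_input_tensor, typed_output_tensor) are the existing TFLite C++ API that this patch already exercises in test_util.cc and tflite_driver.cc.

#include <cstdio>
#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

int main() {
  // Hypothetical model converted with the 20-input LSTM definition.
  auto model = tflite::FlatBufferModel::BuildFromFile("lstm.tflite");
  if (!model) return 1;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter || interpreter->AllocateTensors() != kTfLiteOk) return 1;

  // The variable state tensors (the LSTM's 19th/20th inputs) start with
  // undefined contents; zero them before the first sequence, as the updated
  // SingleOpModel::BuildInterpreter() and TfLiteDriver now do.
  interpreter->ResetVariableTensorsToZero();

  for (int step = 0; step < 4; ++step) {
    // Feed one (illustrative) time step into the model's data input.
    interpreter->typed_input_tensor<float>(0)[0] = 1.0f;
    if (interpreter->Invoke() != kTfLiteOk) return 1;
    // The states persist inside the variable tensors between iterations,
    // so each Invoke() continues the sequence instead of restarting it.
    std::printf("step %d -> %f\n", step,
                interpreter->typed_output_tensor<float>(0)[0]);
  }

  // Clear the states again to start an independent sequence.
  interpreter->ResetVariableTensorsToZero();
  return 0;
}

This also explains the TfLiteDriver change above: for 20-input LSTMs, ResetVariableTensorsToZero() alone resets the state, and the per-output zeroing workaround is kept only for nodes still using the 18-input definition.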