path: root/tensorflow/contrib
author    Yu-Cheng Ling <ycling@google.com>                2018-06-19 12:35:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-19 12:38:27 -0700
commit    5fab6df2788937bee1cce3a4e8f5b9d1db7497ec (patch)
tree      ba18594841593a0b2a3eda55c076ca78c7bf0d4e /tensorflow/contrib
parent    8f19772410ec20010e9930f9765dbd3aaeb06111 (diff)
Support Variable Tensor API in LSTM Full kernel.
TFLite LSTM now supports three input layouts: 5 inputs (basic kernel), 18 inputs (full kernel, states stored in outputs), and 20 inputs (full kernel, states as variable input tensors).

PiperOrigin-RevId: 201222516
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- tensorflow/contrib/lite/kernels/lstm.cc                                          | 161
-rw-r--r-- tensorflow/contrib/lite/kernels/lstm_test.cc                                     |   8
-rw-r--r-- tensorflow/contrib/lite/kernels/optional_tensor_test.cc                          |   8
-rw-r--r-- tensorflow/contrib/lite/kernels/test_util.cc                                     |   5
-rw-r--r-- tensorflow/contrib/lite/kernels/test_util.h                                      |  11
-rw-r--r-- tensorflow/contrib/lite/testing/tflite_driver.cc                                 |   6
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc |  10
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h                  |   6
-rw-r--r-- tensorflow/contrib/lite/toco/tflite/BUILD                                        |   1
-rw-r--r-- tensorflow/contrib/lite/toco/tflite/operator.cc                                  |  17
10 files changed, 158 insertions(+), 75 deletions(-)
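For orientation before the diffs: a minimal standalone sketch (plain C++, no TFLite headers; the constants match the indices introduced in lstm.cc below, but the function itself is illustrative only) of the new dispatch rule in the full kernel's Prepare():

#include <cstdio>

// Illustrative constants; they match the indices added in lstm.cc below.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;

enum class StateSource { kVariableInputs, kOutputs, kInvalid };

// 20 inputs: states are variable input tensors 18 and 19.
// 18 inputs: legacy layout, states live in outputs 0 and 1.
StateSource ClassifyFullLstmNode(int num_inputs) {
  if (num_inputs == 20) return StateSource::kVariableInputs;
  if (num_inputs == 18) return StateSource::kOutputs;
  return StateSource::kInvalid;  // Prepare() reports an error in this case.
}

int main() {
  std::printf("20 inputs -> %d\n", static_cast<int>(ClassifyFullLstmNode(20)));
  std::printf("18 inputs -> %d\n", static_cast<int>(ClassifyFullLstmNode(18)));
}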
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index eb26a02455..1dda97c101 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -37,14 +37,17 @@ namespace builtin {
namespace lstm {
struct OpData {
- // Which kernel type to use. Full kernel (18-inputs) or basic kernel
- // (5-inputs).
+ // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel
+ // (5 inputs).
TfLiteLSTMKernelType kernel_type;
- // Only used by full kernel.
+
+ // These fields are only used by the full kernel.
+ int activation_state_tensor_index;
+ int cell_state_tensor_index;
int scratch_tensor_index;
};
-// For full inputs kernel (18-inputs).
+// For the full kernel (18 or 20 inputs).
namespace full {
// Input Tensors of size {n_batch, n_input}
@@ -78,7 +81,16 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// If the node has 20 inputs, the following 2 tensors are used as state tensors.
+// These are defined as variable tensors, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 18;
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
+// * If the node has 18 inputs, these 2 tensors are used as state tensors.
+// * If the node has 20 inputs, these 2 tensors are ignored.
+// TODO(ycling): Make the 2 output state tensors optional, and propagate the
+// state to the output tensors when the 2 tensors are present.
constexpr int kOutputStateTensor = 0;
constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
@@ -246,10 +258,31 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- // Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ // True if the node uses input variable state tensors. It means:
+ // * The state tensors are defined as inputs (the 19th and 20th input
+ //   tensors).
+ // * Otherwise, the output tensors are used to store the states.
+ bool use_input_variable_states;
+ if (node->inputs->size == 20) {
+ use_input_variable_states = true;
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ op_data->cell_state_tensor_index =
+ node->inputs->data[kInputCellStateTensor];
+ } else if (node->inputs->size == 18) {
+ use_input_variable_states = false;
+ op_data->activation_state_tensor_index =
+ node->outputs->data[kOutputStateTensor];
+ op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
+ } else {
+ context->ReportError(
+ context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs",
+ node->inputs->size);
+ return kTfLiteError;
+ }
+
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -274,34 +307,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, output_state and cell_state tensors.
+ // Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
+ if (use_input_variable_states) {
+ // Check the shape of input state tensors.
+ // These tensors may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
+ n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+ } else {
+ // If the state tensors are outputs, this function is responsible for
+ // resizing them.
+ TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2);
+ activation_state_size->data[0] = n_batch;
+ activation_state_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
+ activation_state_size));
+
+ TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+ cell_size->data[0] = n_batch;
+ cell_size->data[1] = n_cell;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state, cell_size));
+ // Mark state tensors as persistent tensors.
+ activation_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+ }
+
+ // Resize the output tensor.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
-
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
// The weights are of consistent type, so it suffices to check one.
// TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
@@ -337,7 +383,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (is_hybrid_op) {
// Allocate temporary tensors to store quantized values of input,
- // output_state and cell_state tensors.
+ // activation_state and cell_state tensors.
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
input_quantized->type = kTfLiteUInt8;
@@ -348,17 +394,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
input_quantized_size));
}
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2);
- output_state_quantized->type = kTfLiteUInt8;
- output_state_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(output_state_quantized->dims,
- output_state->dims)) {
- TfLiteIntArray* output_state_quantized_size =
- TfLiteIntArrayCopy(output_state->dims);
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, output_state_quantized,
- output_state_quantized_size));
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
}
node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
TfLiteTensor* cell_state_quantized =
@@ -438,7 +484,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
@@ -499,7 +545,7 @@ TfLiteStatus EvalFloat(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
float* output_ptr_batch = output->data.f;
@@ -512,8 +558,8 @@ TfLiteStatus EvalFloat(
cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, output_ptr_batch);
+ activation_state_ptr, cell_state_ptr, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
return kTfLiteOk;
}
@@ -536,9 +582,9 @@ TfLiteStatus EvalHybrid(
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
+ TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
// n_cell and n_output will be the same size when there is no projection.
@@ -639,15 +685,15 @@ TfLiteStatus EvalHybrid(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
float* output_ptr_batch = output->data.f;
// Temporary storage for quantized values and scaling factors.
int8_t* quantized_input_ptr =
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
int8_t* quantized_cell_state_ptr =
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
@@ -672,14 +718,16 @@ TfLiteStatus EvalHybrid(
input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr,
- cell_state_ptr, output_ptr_batch);
+ quantized_activation_state_ptr, quantized_cell_state_ptr,
+ activation_state_ptr, cell_state_ptr, output_ptr_batch);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
@@ -723,8 +771,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// TODO(mirkov): add a check that weights are all uint8s or all floats.
@@ -738,11 +789,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params,
- scratch_buffer, output_state, cell_state, output);
+ scratch_buffer, activation_state, cell_state, output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, /*index=*/3);
@@ -760,8 +811,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params, scratch_buffer,
scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, output_state_quantized, cell_state_quantized,
- output_state, cell_state, output);
+ input_quantized, activation_state_quantized, cell_state_quantized,
+ activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
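Note the asymmetry in Prepare() above: the 20-input path only validates the element count of the caller-provided state tensors (1D or 2D shapes are both accepted), while the 18-input path resizes the output state tensors and marks them persistent. A standalone sketch of the count check, with NumElements as a stand-in for the TFLite helper of the same name:

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

// Stand-in for TFLite's NumElements(): total element count of a tensor shape.
int NumElements(const std::vector<int>& dims) {
  return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>());
}

int main() {
  const int n_batch = 2, n_output = 16;
  // Both shapes pass the 20-input path's check: only the total size matters.
  std::vector<int> one_d = {n_batch * n_output};
  std::vector<int> two_d = {n_batch, n_output};
  assert(NumElements(one_d) == n_batch * n_output);
  assert(NumElements(two_d) == n_batch * n_output);
}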
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index 6da29a4a92..3f5c44a63e 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -97,6 +97,12 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -227,6 +233,8 @@ class LSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
int output_;
int output_state_;
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index bcad58406a..1c728a4733 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -95,6 +95,12 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -228,6 +234,8 @@ class LSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
int output_;
int output_state_;
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index d23ec201b4..9156917140 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -32,8 +32,8 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
return matchers;
}
-int SingleOpModel::AddInput(const TensorData& t) {
- int id = AddTensor<float>(t, {});
+int SingleOpModel::AddInput(const TensorData& t, bool is_variable) {
+ int id = AddTensor<float>(t, {}, is_variable);
inputs_.push_back(id);
return id;
}
@@ -120,6 +120,7 @@ void SingleOpModel::BuildInterpreter(
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
+ interpreter_->ResetVariableTensorsToZero();
}
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
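The call added to BuildInterpreter() matters because variable tensors, unlike ordinary inputs, hold state across Invoke() calls; zeroing them once after allocation gives every test a clean starting state. A toy model (not TFLite code) of that lifecycle:

#include <cstdio>
#include <vector>

// Toy variable tensor: zeroed once after allocation, then mutated in place
// by each invocation and carried over to the next.
struct FakeVariableTensor {
  std::vector<float> data;
  void ResetToZero() { data.assign(data.size(), 0.0f); }
};

int main() {
  FakeVariableTensor state{std::vector<float>(4, -1.0f)};  // arena garbage
  state.ResetToZero();       // analogous to ResetVariableTensorsToZero()
  state.data[0] = 0.5f;      // an Invoke() writes the new LSTM state in place
  std::printf("state[0] after one step: %f\n", state.data[0]);  // persists
}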
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index db80c0082c..6dcece4af6 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -126,8 +126,10 @@ class SingleOpModel {
SingleOpModel& operator=(const SingleOpModel&) = delete;
// Add a TensorType input tensor and return its index.
- int AddInput(TensorType type) { return AddInput(TensorData{type}); }
- int AddInput(const TensorData& t);
+ int AddInput(TensorType type, bool is_variable = false) {
+ return AddInput(TensorData{type}, is_variable);
+ }
+ int AddInput(const TensorData& t, bool is_variable = false);
// Templated version of AddConstInput().
template <typename T>
@@ -260,7 +262,8 @@ class SingleOpModel {
}
template <typename T>
- int AddTensor(TensorData t, std::initializer_list<T> data) {
+ int AddTensor(TensorData t, std::initializer_list<T> data,
+ bool is_variable = false) {
int id = tensors_.size();
// This is slightly different depending on whether we are adding a
@@ -309,7 +312,7 @@ class SingleOpModel {
tensors_.push_back(CreateTensor(builder_,
builder_.CreateVector<int>(t.shape), t.type,
/*buffer=*/buffer_id,
- /*name=*/0, q_params));
+ /*name=*/0, q_params, is_variable));
tensor_data_[id] = t;
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 54edfdfb1d..4d08fb5458 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -288,8 +288,8 @@ void TfLiteDriver::ResetLSTMStateTensors() {
interpreter_->ResetVariableTensorsToZero();
// Below is a workaround for initializing state tensors for LSTM.
- // TODO(ycling): Refactoring and find a better way to initialize state
- // tensors. Maybe write the reset instructions into the test data.
+ // TODO(ycling): Remove the code below once nobody is using the 18-input
+ // definition.
for (auto node_index : interpreter_->execution_plan()) {
const auto& node_and_reg = interpreter_->node_and_registration(node_index);
const auto& node = node_and_reg->first;
@@ -299,7 +299,7 @@ void TfLiteDriver::ResetLSTMStateTensors() {
const auto* params =
reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
if (params->kernel_type == kTfLiteLSTMFullKernel &&
- node.outputs->size >= 2) {
+ node.inputs->size == 18 && node.outputs->size >= 2) {
// The first 2 outputs of LSTM are state tensors.
for (int i = 0; i < 2; ++i) {
int node_index = node.outputs->data[i];
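The tightened guard means the driver's manual zeroing now fires only for legacy 18-input full-kernel nodes; 20-input nodes are already covered by ResetVariableTensorsToZero() above. A compact model of the predicate:

#include <cassert>

// Mirrors the condition in ResetLSTMStateTensors(): only legacy full-kernel
// LSTM nodes (18 inputs, states in outputs 0 and 1) still need the workaround.
bool NeedsLegacyStateReset(bool is_full_kernel, int num_inputs,
                           int num_outputs) {
  return is_full_kernel && num_inputs == 18 && num_outputs >= 2;
}

int main() {
  assert(NeedsLegacyStateReset(true, 18, 3));   // legacy layout: reset outputs
  assert(!NeedsLegacyStateReset(true, 20, 3));  // variable tensors handle it
  assert(!NeedsLegacyStateReset(false, 5, 2));  // basic kernel: not this path
}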
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index e6e3dfa1de..46d1fce50e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -74,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
lstm_cell_op->inputs[kInputTensor] =
curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT];
+ // Previous states.
+ lstm_cell_op->inputs[kInputActivationStateTensor] =
+ curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT];
+ lstm_cell_op->inputs[kInputCellStateTensor] =
+ curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT];
+
// Get original weight tensor and decompose 1 tensor to 8 sub tensors.
Array& kernel =
model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
@@ -160,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Erase curr lstm op being replaced.
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model);
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
- DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT],
- model);
- DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT],
- model);
model->operators.erase(FindOp(*model, curr_op));
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
index 1c32a78169..6d8603a113 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
@@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs {
kOutputGateBiasTensor = 15,
kProjectionWeightsTensor = 16, // Optional
kProjectionBiasTensor = 17, // Optional
- kExtendedLstmInputCount = 18
+ // The op can handle 18 or 20 inputs.
+ kInputActivationStateTensor = 18,
+ kInputCellStateTensor = 19,
+ kExtendedLstmInputCount = 20,
};
enum ExtendedLstmCellOutputs {
+ // TODO(ycling): Make the 2 output state tensors optional.
kOutputStateTensor = 0,
kCellStateTensor = 1,
kOutputTensor = 2,
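Since toco and the runtime kernel must agree on these indices, a compile-time check is a cheap safeguard. A hedged sketch (the static_assert below is an editor's illustration, not part of the commit):

// Illustrative excerpt of the layout contract from lstm_utils.h.
enum ExtendedLstmCellInputs {
  kProjectionBiasTensor = 17,        // Optional
  kInputActivationStateTensor = 18,  // Only present in the 20-input layout.
  kInputCellStateTensor = 19,        // Only present in the 20-input layout.
  kExtendedLstmInputCount = 20,
};

static_assert(kExtendedLstmInputCount == kInputCellStateTensor + 1,
              "the two state tensors must be the last two inputs");

int main() {}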
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index e1025c6664..a02f90988b 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -24,6 +24,7 @@ cc_library(
deps = [
":types",
"//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 669fb9fa08..c93c0a6b90 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+// TODO(ycling): Consider refactoring to extract the LSTM definition out of
+// the graph_transformations module.
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
@@ -673,18 +676,20 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
const Operator& op) const override {
const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
switch (lstm_op.kernel_type) {
- case LstmCellOperator::KERNEL_FULL:
- // TODO(ycling): Change the full kernel to use the new variable tensor
- // design. This requires moving the state tensors from output to input.
- return std::vector<bool>();
+ case LstmCellOperator::KERNEL_FULL: {
+ mutating_input_variables[kInputActivationStateTensor] = true;
+ mutating_input_variables[kInputCellStateTensor] = true;
+ break;
+ }
case LstmCellOperator::KERNEL_BASIC: {
- std::vector<bool> mutating_input_variables(op.inputs.size(), false);
mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
- return mutating_input_variables;
+ break;
}
}
+ return mutating_input_variables;
}
};
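With both kernel types now returning a populated vector, the exporter uniformly reports which inputs are mutated variable tensors. A standalone model of the result (indices 18/19 come from lstm_utils.h above; the basic-kernel indices 1 and 4 are assumptions about LstmCellOperator's PREV_ACTIV_INPUT/PREV_STATE_INPUT and should be checked against toco's model definition):

#include <cassert>
#include <vector>

// Standalone model of GetMutatingInputVariables() after this change.
// Full kernel: state tensors at 18/19 (see lstm_utils.h above).
// Basic kernel: PREV_ACTIV_INPUT / PREV_STATE_INPUT, assumed here to be
// indices 1 and 4 (verify against toco's LstmCellOperator definition).
std::vector<bool> MutatingInputVariables(bool is_full_kernel, int num_inputs) {
  std::vector<bool> mutating(num_inputs, false);
  if (is_full_kernel) {
    mutating[18] = mutating[19] = true;
  } else {
    mutating[1] = mutating[4] = true;
  }
  return mutating;
}

int main() {
  auto full = MutatingInputVariables(/*is_full_kernel=*/true, 20);
  assert(full[18] && full[19] && !full[0]);
}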