aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h6
2 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index e6e3dfa1de..46d1fce50e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -74,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
lstm_cell_op->inputs[kInputTensor] =
curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT];
+ // Previous states.
+ lstm_cell_op->inputs[kInputActivationStateTensor] =
+ curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT];
+ lstm_cell_op->inputs[kInputCellStateTensor] =
+ curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT];
+
// Get original weight tensor and decompose 1 tensor to 8 sub tensors.
Array& kernel =
model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
@@ -160,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Erase curr lstm op being replaced.
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model);
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
- DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT],
- model);
- DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT],
- model);
model->operators.erase(FindOp(*model, curr_op));
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
index 1c32a78169..6d8603a113 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
@@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs {
kOutputGateBiasTensor = 15,
kProjectionWeightsTensor = 16, // Optional
kProjectionBiasTensor = 17, // Optional
- kExtendedLstmInputCount = 18
+ kInputActivationStateTensor = 18,
+ // The op can handle 18 inputs or 20 inputs.
+ kInputCellStateTensor = 19,
+ kExtendedLstmInputCount = 20,
};
enum ExtendedLstmCellOutputs {
+ // TODO(ycling): Make the 2 output state tensors optional.
kOutputStateTensor = 0,
kCellStateTensor = 1,
kOutputTensor = 2,