aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index 8e66323bd7..e6e3dfa1de 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -33,9 +33,10 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
return false;
}
- // Already an extended LstmCell with kExtendedLstmInputCount of inputs,
- // do not need to split cell inputs.
- if (curr_op->inputs.size() == kExtendedLstmInputCount) {
+ const auto* curr_lstm_op = static_cast<LstmCellOperator*>(curr_op);
+ // Already an extended LstmCell. Do not need to split cell inputs.
+ if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC ||
+ curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) {
return false;
}
@@ -56,6 +57,7 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc).
auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+ lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_FULL;
lstm_cell_op->inputs.resize(kExtendedLstmInputCount);
int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT])
.shape()