diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc | 108 |
1 files changed, 26 insertions, 82 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index 3ca7f53512..c0b014b45e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -35,6 +35,26 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator( return it; } +bool ValidateSourceOp(const Model& model, const string& array_name, + OperatorType op_type, Operator** source_op) { + if (op_type == OperatorType::kNone) { + CHECK(!source_op); + } else { + CHECK(source_op); + *source_op = GetOpWithOutput(model, array_name); + if (*source_op == nullptr) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((*source_op)->type != op_type) { + return false; + } + } + + return true; +} + // Returns true if the given operator has exactly 1 input, and is connected to // the given op_type. // We use kNone to indicate an input unattached to an operator output. Usually @@ -47,24 +67,10 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((op_type != OperatorType::kNone) && (x == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[0], op_type, connected_op)) { return false; } - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (connected_op) { - *connected_op = x; - } - return true; } @@ -81,40 +87,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } return true; } @@ -132,57 +113,20 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { - return false; - } - - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } // Check if third input is disconnected/connected to an operator - Operator* z = GetOpWithOutput(model, op.inputs[2]); - if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { - return false; - } - if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[2], c_op_type, c_op)) { return false; } - // Check that third operator, if connected, is of correct type - if ((z != nullptr) && (z->type != c_op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } - if (c_op != nullptr) { - *c_op = z; - } return true; } |