aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc24
1 files changed, 15 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index 9f5d8b9450..fc49fbda59 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -48,20 +48,26 @@ void RerouteEdges(const string& from_array, const string& to_array,
} // namespace
bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
- Model* model, std::size_t op_index) {
+ Model* model, std::size_t op_index,
+ int input_index) {
const auto passthru_it = model->operators.begin() + op_index;
auto* passthru_op = passthru_it->get();
CHECK_EQ(passthru_op->outputs.size(), 1);
CHECK_GE(passthru_op->inputs.size(), 1);
- int count_nonconstant_input_arrays = 0;
- // We call 'main input' the unique nonconstant input array if there is one,
- // or else the 0-th input.
+
int main_input_array_index = 0;
- for (int i = 0; i < passthru_op->inputs.size(); i++) {
- if (!model->GetArray(passthru_op->inputs[i]).buffer) {
- count_nonconstant_input_arrays++;
- if (count_nonconstant_input_arrays == 1) {
- main_input_array_index = i;
+ if (input_index != -1) {
+ main_input_array_index = input_index;
+ } else {
+ // We call 'main input' the unique nonconstant input array if there is one,
+ // or else the 0-th input.
+ int count_nonconstant_input_arrays = 0;
+ for (int i = 0; i < passthru_op->inputs.size(); i++) {
+ if (!model->GetArray(passthru_op->inputs[i]).buffer) {
+ count_nonconstant_input_arrays++;
+ if (count_nonconstant_input_arrays == 1) {
+ main_input_array_index = i;
+ }
}
}
}