diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc index e880a3f44d..ab1e0bd7a0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc @@ -27,11 +27,14 @@ namespace toco { // This implementation is looking strictly for all-or-nothing on the select // condition. It's possible to enhance this by looking per-element and possibly // producing a Mul op. -bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantSelect::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kSelect) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast<const SelectOperator*>(base_op); @@ -40,23 +43,23 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } // We require the cond input to be constant. if (!IsConstantParameterArray(*model, op->inputs[0])) { - return false; + return ::tensorflow::Status::OK(); } const Array& cond_array = model->GetArray(op->inputs[0]); CHECK(cond_array.data_type == ArrayDataType::kBool) << "Only bool conditions are supported"; const auto& cond_data = cond_array.GetBuffer<ArrayDataType::kBool>().data; if (cond_data.empty()) { - return false; + return ::tensorflow::Status::OK(); } // Check if the condition is the same for all elements. @@ -67,12 +70,14 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { "Cannot resolve %s as constant; cond_array has differing " "per-element values", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } } // Pass-through the selected input. - return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2); + *modified = + RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2); + return ::tensorflow::Status::OK(); } } // namespace toco |