aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc21
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
index e880a3f44d..ab1e0bd7a0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
@@ -27,11 +27,14 @@ namespace toco {
// This implementation is looking strictly for all-or-nothing on the select
// condition. It's possible to enhance this by looking per-element and possibly
// producing a Mul op.
-bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantSelect::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kSelect) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const SelectOperator*>(base_op);
@@ -40,23 +43,23 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
// We require the cond input to be constant.
if (!IsConstantParameterArray(*model, op->inputs[0])) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& cond_array = model->GetArray(op->inputs[0]);
CHECK(cond_array.data_type == ArrayDataType::kBool)
<< "Only bool conditions are supported";
const auto& cond_data = cond_array.GetBuffer<ArrayDataType::kBool>().data;
if (cond_data.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Check if the condition is the same for all elements.
@@ -67,12 +70,14 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
"Cannot resolve %s as constant; cond_array has differing "
"per-element values",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
// Pass-through the selected input.
- return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
+ *modified =
+ RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
+ return ::tensorflow::Status::OK();
}
} // namespace toco