aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc26
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
index f6f95481b5..5400d395ff 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
@@ -41,11 +41,14 @@ bool ComputeFillArray(Model* model, FillOperator* op) {
return true;
}
-bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantFill::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto fill_it = model->operators.begin() + op_index;
auto* base_op = fill_it->get();
if (base_op->type != OperatorType::kFill) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* op = static_cast<FillOperator*>(base_op);
@@ -55,44 +58,44 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& val_array = model->GetArray(op->inputs[1]);
if (!val_array.has_shape()) {
// Yield until the value shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[1])) {
// Yield until the value is constant.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!ComputeFillArray<ArrayDataType::kFloat>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kUint8:
if (!ComputeFillArray<ArrayDataType::kUint8>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt32:
if (!ComputeFillArray<ArrayDataType::kInt32>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt64:
if (!ComputeFillArray<ArrayDataType::kInt64>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
default:
@@ -114,7 +117,8 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(fill_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco