diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc index 88d06d7dc7..db0fbba528 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc @@ -59,11 +59,14 @@ bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) { return true; } -bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* base_op = it->get(); if (base_op->type != OperatorType::kRandomUniform) { - return false; + return ::tensorflow::Status::OK(); } auto* op = static_cast<RandomUniformOperator*>(base_op); @@ -73,12 +76,12 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return false; + return ::tensorflow::Status::OK(); } if ((op->seed == 0) && (op->seed2 == 0)) { @@ -86,13 +89,13 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { << "\" is truly random (using /dev/random system entropy). " "Therefore, cannot resolve as constant. Set \"seed\" or " "\"seed2\" attr non-zero to fix this"; - return false; + return ::tensorflow::Status::OK(); } switch (output_array.data_type) { case ArrayDataType::kFloat: if (!ComputeRandomUniformArray<ArrayDataType::kFloat>(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; // For future support of double or half. @@ -110,7 +113,8 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { // Erase the operator model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |