diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc index 36d7dad0ce..6e3a6a69c2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -61,11 +61,14 @@ inline void Gather(const Array& input_array, int input_rank, // Resolves a constant Gather operation. // This simply performs the gather and produces the output array with the // appropriate values. -bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantGather::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kGather) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast<const GatherOperator*>(base_op); @@ -74,28 +77,28 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } if (!op->axis) { // Yield until axis has been set by ResolveGatherAttributes. - return false; + return ::tensorflow::Status::OK(); } if (op->axis.value() != 0) { // Only handling axis=0 for now. AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op), op->axis.value()); - return false; + return ::tensorflow::Status::OK(); } // We require constant inputs. if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(op->inputs[0]); const Array& coords_array = model->GetArray(op->inputs[1]); @@ -142,7 +145,8 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { // Erase the operator. model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |