aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc20
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
index 36d7dad0ce..6e3a6a69c2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
@@ -61,11 +61,14 @@ inline void Gather(const Array& input_array, int input_rank,
// Resolves a constant Gather operation.
// This simply performs the gather and produces the output array with the
// appropriate values.
-bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantGather::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kGather) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const GatherOperator*>(base_op);
@@ -74,28 +77,28 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!op->axis) {
// Yield until axis has been set by ResolveGatherAttributes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (op->axis.value() != 0) {
// Only handling axis=0 for now.
AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op),
op->axis.value());
- return false;
+ return ::tensorflow::Status::OK();
}
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
const Array& coords_array = model->GetArray(op->inputs[1]);
@@ -142,7 +145,8 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco