aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc20
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
index ce825c91af..69209b8dec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
@@ -24,20 +24,25 @@ limitations under the License.
namespace toco {
-bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveGatherAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto* gather_op = model->operators[op_index].get();
- if (gather_op->type != OperatorType::kGather) return false;
+ if (gather_op->type != OperatorType::kGather)
+ return ::tensorflow::Status::OK();
auto* op = static_cast<GatherOperator*>(gather_op);
if (op->axis) {
// Attributes already resolved
- return false;
+ return ::tensorflow::Status::OK();
}
- if (op->inputs.size() != 3) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (op->inputs.size() != 3) return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[2]))
+ return ::tensorflow::Status::OK();
const auto& indices_array = model->GetArray(op->inputs[2]);
- if (!indices_array.has_shape()) return false;
+ if (!indices_array.has_shape()) return ::tensorflow::Status::OK();
const auto& axis_data = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
CHECK_EQ(axis_data.size(), 1)
<< "Multidimensional gather not supported on " << LogName(*op);
@@ -47,7 +52,8 @@ bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUsedOnce(op->inputs[2], model);
op->inputs.resize(2);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco