aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
index debe298a5a..36d7dad0ce 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
@@ -69,7 +69,7 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
}
const auto* op = static_cast<const GatherOperator*>(base_op);
- CHECK_EQ(op->inputs.size(), 2);
+ CHECK_GE(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
@@ -81,10 +81,14 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
return false;
}
- // Only handling axis=0 for now.
- if (op->axis != 0) {
+ if (!op->axis) {
+ // Yield until axis has been set by ResolveGatherAttributes.
+ return false;
+ }
+ if (op->axis.value() != 0) {
+ // Only handling axis=0 for now.
AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op),
- op->axis);
+ op->axis.value());
return false;
}