diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-18 12:55:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 12:59:31 -0700 |
commit | f1a1496148ea8a828e37201b8d0ab5d7e4979a1a (patch) | |
tree | f33c2a456b66af0204833e4e77c04cd3eda1a7eb /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 9fd62257907ed2df20bd89f9346dbe91ada1fbbc (diff) |
Fixing gather to support round-tripping non-zero axis.
PiperOrigin-RevId: 205122011
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index a250db9975..4275ee9a03 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1043,17 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { return; } + // Yield until the axis has been resolved. + if (!op->axis) { + return; + } + int axis = op->axis.value(); + const auto& input_shape = input_array.shape(); const auto& indices_shape = indices_array.shape(); QCHECK_GE(input_shape.dimensions_count(), 1); op->input_rank = input_shape.dimensions_count(); + QCHECK_LT(axis, op->input_rank); - // Copy the input dimensions to the output except for dimension 0, + // Copy the input dimensions to the output except for the axis dimensions // where the dimension of indices_shape is used. - // TODO(mgubin): if axis != 0 this is not true, change when it's supported. auto output_dims = output_array.mutable_shape()->mutable_dims(); - output_dims->push_back(indices_shape.dims(0)); - for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + for (int dim = 0; dim < axis; ++dim) { + output_dims->push_back(input_shape.dims(dim)); + } + for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) { + output_dims->push_back(indices_shape.dims(dim)); + } + for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) { output_dims->push_back(input_shape.dims(dim)); } } |