aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-18 12:55:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 12:59:31 -0700
commitf1a1496148ea8a828e37201b8d0ab5d7e4979a1a (patch)
treef33c2a456b66af0204833e4e77c04cd3eda1a7eb /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent9fd62257907ed2df20bd89f9346dbe91ada1fbbc (diff)
Fixing gather to support round-tripping non-zero axis.
PiperOrigin-RevId: 205122011
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc19
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index a250db9975..4275ee9a03 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1043,17 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) {
return;
}
+ // Yield until the axis has been resolved.
+ if (!op->axis) {
+ return;
+ }
+ int axis = op->axis.value();
+
const auto& input_shape = input_array.shape();
const auto& indices_shape = indices_array.shape();
QCHECK_GE(input_shape.dimensions_count(), 1);
op->input_rank = input_shape.dimensions_count();
+ QCHECK_LT(axis, op->input_rank);
- // Copy the input dimensions to the output except for dimension 0,
+ // Copy the input dimensions to the output except for the axis dimensions
// where the dimension of indices_shape is used.
- // TODO(mgubin): if axis != 0 this is not true, change when it's supported.
auto output_dims = output_array.mutable_shape()->mutable_dims();
- output_dims->push_back(indices_shape.dims(0));
- for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
+ for (int dim = 0; dim < axis; ++dim) {
+ output_dims->push_back(input_shape.dims(dim));
+ }
+ for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
+ output_dims->push_back(indices_shape.dims(dim));
+ }
+ for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
output_dims->push_back(input_shape.dims(dim));
}
}