diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index c4a1da28f3..0d3045213d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -253,6 +253,14 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers); + // Helper that infers the shape of the tensor produced by a gather operation + // with the given input shape, gather indices shape and gather dimension + // numbers. + static StatusOr<Shape> InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. |