diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-08-16 01:32:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 01:36:41 -0700 |
commit | 72b829dcca2d1acaeea130e580ce780b1a7d550a (patch) | |
tree | 6c7e26f84f8d7eb5eeaf7f802db716b931757df7 /tensorflow/compiler/xla/service/shape_inference.h | |
parent | 9d97b34bde77762a7499306ee74a56bcc91a95dc (diff) |
Add a feature_group_size parameter to the Convolution HLO op.
This is a first step towards supporting grouped convolutions, which are a
generalization of depthwise convolution.
PiperOrigin-RevId: 208950311
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.h')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index c185b0a1bd..bfd79a4433 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -112,7 +112,8 @@ class ShapeInference { // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr<Shape> InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Infers the shape produced by the given FFT type on the given operand. static StatusOr<Shape> InferFftShape( |