diff options
author | 2018-08-16 01:32:25 -0700 | |
---|---|---|
committer | 2018-08-16 01:36:41 -0700 | |
commit | 72b829dcca2d1acaeea130e580ce780b1a7d550a (patch) | |
tree | 6c7e26f84f8d7eb5eeaf7f802db716b931757df7 /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | 9d97b34bde77762a7499306ee74a56bcc91a95dc (diff) |
Add a feature_group_size parameter to the Convolution HLO op.
This is a first step towards supporting grouped convolutions, which are a
generalization of depthwise convolution.
PiperOrigin-RevId: 208950311
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a4ea2b28f4..ec5743a777 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1530,7 +1530,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1640,12 +1640,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (input_features != kernel_input_features) { + if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension (value %lld); got <conv>(%s, %s)\n" + "input feature dimension * feature_group_count (value %lld); " + "got <conv>(%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features, + input_features, kernel_input_features * feature_group_count, ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); } |