aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-08-16 01:32:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 01:36:41 -0700
commit72b829dcca2d1acaeea130e580ce780b1a7d550a (patch)
tree6c7e26f84f8d7eb5eeaf7f802db716b931757df7 /tensorflow/compiler/xla/service/shape_inference.cc
parent9d97b34bde77762a7499306ee74a56bcc91a95dc (diff)
Add a feature_group_size parameter to the Convolution HLO op.
This is a first step towards supporting grouped convolutions, which are a generalization of depthwise convolution. PiperOrigin-RevId: 208950311
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc9
1 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a4ea2b28f4..ec5743a777 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1530,7 +1530,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
@@ -1640,12 +1640,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 kernel_output_features =
rhs.dimensions(dnums.kernel_output_feature_dimension());
- if (input_features != kernel_input_features) {
+ if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
"Expected LHS feature dimension (value %lld) to match RHS "
- "input feature dimension (value %lld); got <conv>(%s, %s)\n"
+ "input feature dimension * feature_group_count (value %lld); "
+ "got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
- input_features, kernel_input_features,
+ input_features, kernel_input_features * feature_group_count,
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
}