aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-05 01:00:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 01:03:45 -0700
commitc7bd1589d08e84ca215b3c8c4dc3023986522ef7 (patch)
tree175a6f18ada9b911a482700e592728246db25ece /tensorflow/compiler/xla/service/shape_inference.cc
parent568b763776b7890570d9f6ab9568153329079958 (diff)
Add support for grouped convolutions to the HloEvaluator.
Add a missing check to InferConvolveShape(), the output feature dimension needs to be divisible by feature_group_count. Also fix some tests which took a const reference to the return value of a function which doesn't return a reference. PiperOrigin-RevId: 211592011
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 7758a5dd4d..74bdf2a2e3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1672,6 +1672,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
+ if (kernel_output_features % feature_group_count > 0) {
+ return InvalidArgument(
+ "Expected output feature dimension (value %d) to be divisible by "
+ "feature_group_count (value %d); "
+ "got <conv>(%s, %s)\n"
+ "Dimension numbers: {%s}.",
+ kernel_output_features, feature_group_count,
+ ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
+ dnums.DebugString());
+ }
std::vector<int64> window_dims(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
window_dims[i] = window.dimensions(i).size();