aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-05 01:00:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 01:03:45 -0700
commitc7bd1589d08e84ca215b3c8c4dc3023986522ef7 (patch)
tree175a6f18ada9b911a482700e592728246db25ece /tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
parent568b763776b7890570d9f6ab9568153329079958 (diff)
Add support for grouped convolutions to the HloEvaluator.
Add a missing check to InferConvolveShape(): the output feature dimension needs to be divisible by feature_group_count. Also fix some tests that took a const reference to the return value of a function that does not return a reference. PiperOrigin-RevId: 211592011
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h36
1 files changed, 34 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index dc16a84246..6a09bb08f4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1047,9 +1047,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto lhs_literal_data = lhs_literal.data<ReturnT>();
auto rhs_literal_data = rhs_literal.data<ReturnT>();
+ int64 feature_group_count = conv->feature_group_count();
+
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
- rhs_literal_data](absl::Span<const int64> out_index) {
+ rhs_literal_data,
+ feature_group_count](absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1061,6 +1064,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 output_z_dim = dnums.output_feature_dimension();
const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ const int64 output_z_size =
+ ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -1069,6 +1074,33 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
for (int64 iz = 0; iz < z_size; ++iz) {
+ int64 rhs_iz = iz;
+ // Handle grouped convolutions.
+ if (feature_group_count > 1) {
+ // The size of a feature group.
+ int64 feature_group_size = z_size / feature_group_count;
+ rhs_iz = iz % feature_group_size;
+
+ // The output feature dimension is a concatenation of convolution
+ // results from the different groups.
+ int64 output_feature_group_size =
+ output_z_size / feature_group_count;
+
+ // Calculate the group index to which the current input feature
+ // index belongs.
+ int64 input_group_index = iz / feature_group_size;
+
+ // Calculate the group index to which the current output index
+ // belongs.
+ int64 output_group_index =
+ out_index[output_z_dim] / output_feature_group_size;
+ if (input_group_index != output_group_index) {
+ // The current input feature belongs to a different feature group
+ // than the current output feature, so skip this input feature.
+ continue;
+ }
+ }
+
int64 lhs_linear_index = 0;
lhs_linear_index += out_index[output_batch_dim] *
lhs_dim_multipliers[input_batch_dim];
@@ -1077,7 +1109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
int64 rhs_linear_index = 0;
rhs_linear_index += out_index[output_z_dim] *
rhs_dim_multipliers[kernel_output_z_dim];
- rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
+ rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {