diff options
author | David Majnemer <majnemer@google.com> | 2018-09-04 11:17:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 11:41:30 -0700 |
commit | 5d183ab7fc7b82f1dea0b9fa9c6412c39ade15a1 (patch) | |
tree | 79a4f6fcf270617fc56082702b0209240425ae8c /tensorflow/compiler/xla/service/algebraic_simplifier.cc | |
parent | 9ae8214229960c634c9f82c00f2c0df287c27a9d (diff) |
[XLA] Make kConvolution, kDot HLO attributes mandatory
HLO transformations would forget to propagate the feature depth attribute.
Making these attributes mandatory, while slightly less convenient for tests,
makes HLO transformations more robust.
PiperOrigin-RevId: 211490160
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 7c078f07d7..3d18fe3be2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -950,9 +950,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( new_dot_rhs = rhs_slice; } - auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums)); - new_dot->set_precision_config(dot.precision_config()); + auto* new_dot = computation_->AddInstruction( + HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, + new_dot_dnums, dot.precision_config())); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1053,9 +1053,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather( const int n = right_operand->shape().dimensions(1 - rhs_contracting_dimension); auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); - auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( - memoized_shape, left_operand, right_operand, dnums)); - memoized_inst->set_precision_config(dot->precision_config()); + auto* memoized_inst = computation_->AddInstruction( + HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, + dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1151,9 +1151,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), - rhs->mutable_operand(0), lhs->mutable_operand(0), - dot_dimension_numbers)); - new_dot->set_precision_config(dot->precision_config()); + rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, + dot->precision_config())); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -2477,8 +2476,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); - dot->set_precision_config(convolution->precision_config()); + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, + convolution->precision_config())); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } |