diff options
author | 2018-06-11 10:26:40 -0700 | |
---|---|---|
committer | 2018-06-11 10:29:43 -0700 | |
commit | 7b8c64ef05c7fdddb3f3a32fd3189e1e4b7e8985 (patch) | |
tree | 7ec426f6943994f63d3050a16cfe3d147776a6aa | |
parent | 59259fd74a7cdf766b54e1de00abae88438d1978 (diff) |
Remove dead code to use a map in BatchnormExpander
PiperOrigin-RevId: 200072055
4 files changed, 12 insertions, 98 deletions
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index a9f4aead59..ec13fadbc7 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_inference_op, bool rewrite_grad_op, - bool use_map_instructions); + bool rewrite_inference_op, bool rewrite_grad_op); // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } @@ -70,22 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, - bool use_map_instructions) + bool rewrite_grad_op) : computation_(computation), rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_map_instructions_(use_map_instructions) {} + rewrite_grad_op_(rewrite_grad_op) {} HloComputation* GetOrCreateScalarAddComputation( PrimitiveType primitive_type) { - HloComputation** scalar_add_computation = - &scalar_add_computations_[primitive_type]; - if (*scalar_add_computation) { - return *scalar_add_computation; - } - HloComputation::Builder b("scalar_add_computation"); Shape shape = ShapeUtil::MakeShape(primitive_type, {}); auto scalar_lhs = b.AddInstruction( @@ -94,44 +85,13 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - *scalar_add_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_add_computation; - } - - // TODO(b/80534766): Remove maps after performance issues with scalar - // broadcasts are resolved on all backends. - HloComputation* GetOrCreateScalarRsqrtComputation( - PrimitiveType primitive_type) { - HloComputation** scalar_rsqrt_computation = - &scalar_rsqrt_computations_[primitive_type]; - if (*scalar_rsqrt_computation) { - return *scalar_rsqrt_computation; - } - - HloComputation::Builder b("scalar_add_computation"); - Shape shape = ShapeUtil::MakeShape(primitive_type, {}); - auto scalar_lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( - shape, b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0<float>(-0.5f))))); - auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kPower, scalar_lhs, scalar_rhs)); - *scalar_rsqrt_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_rsqrt_computation; + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } std::unique_ptr<HloInstruction> Rsqrt( HloInstruction* operand, const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>& add_instruction) { - if (use_map_instructions_) { - return HloInstruction::CreateMap( - operand->shape(), {operand}, - GetOrCreateScalarRsqrtComputation(operand->shape().element_type())); - } HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( operand->shape(), add_instruction(HloInstruction::CreateConvert( @@ -143,40 +103,10 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { operand, exponent); } - HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type, - int64 element_count) { - HloComputation** scalar_mean_computation = - &scalar_mean_computations_[std::pair<PrimitiveType, int64>( - primitive_type, element_count)]; - if (*scalar_mean_computation) { - return *scalar_mean_computation; - } - - HloComputation::Builder b("scalar_add_computation"); - Shape shape = ShapeUtil::MakeShape(primitive_type, {}); - auto scalar_lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( - shape, b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>( - 1.0f / static_cast<float>(element_count)))))); - auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs)); - *scalar_mean_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_mean_computation; - } - std::unique_ptr<HloInstruction> Mean( int64 element_count, HloInstruction* operand, const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>& add_instruction) { - if (use_map_instructions_) { - return HloInstruction::CreateMap( - operand->shape(), {operand}, - GetOrCreateScalarMeanComputation(operand->shape().element_type(), - element_count)); - } HloInstruction* elem_count_recip = add_instruction(HloInstruction::CreateBroadcast( operand->shape(), @@ -218,18 +148,9 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_map_instructions_; // Whether rewrite has occurred. bool changed_ = false; - - // Cached computations for adding two scalars. - tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*> - scalar_add_computations_; - tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*> - scalar_rsqrt_computations_; - tensorflow::gtl::FlatMap<std::pair<PrimitiveType, int64>, HloComputation*> - scalar_mean_computations_; }; } // namespace @@ -237,14 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, - bool use_map_instructions) { + bool rewrite_grad_op) { BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_map_instructions=*/use_map_instructions); + /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -668,8 +587,8 @@ StatusOr<bool> BatchNormExpander::Run(HloModule* module) { bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, - rewrite_inference_op_, rewrite_grad_op_, - use_map_instructions_)) { + rewrite_inference_op_, + rewrite_grad_op_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 8826636416..7ae202c583 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -31,12 +31,10 @@ class BatchNormExpander : public HloPassInterface { // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, - bool rewrite_grad_op = false, - bool use_map_instructions = false) + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_map_instructions_(use_map_instructions) {} + rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; tensorflow::StringPiece name() const override { return "batchnorm_expander"; } @@ -48,7 +46,6 @@ class BatchNormExpander : public HloPassInterface { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_map_instructions_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index d6b7b7d2d8..4c0e189e78 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -264,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass<BatchNormExpander>( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_map_instructions=*/false); + /*rewrite_grad_op=*/true); pass.AddPass<AlgebraicSimplifier>( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index cc33847c5c..afefc740d7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -163,8 +163,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pass.AddPass<BatchNormExpander>( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_map_instructions=*/false); + /*rewrite_grad_op=*/true); // Rewrite gather ops into smaller ones. pass.AddPass<GatherExpander>(); |