author    Yunxing Dai <yunxing@google.com>    2018-06-11 10:26:40 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-11 10:29:43 -0700
commit    7b8c64ef05c7fdddb3f3a32fd3189e1e4b7e8985 (patch)
tree      7ec426f6943994f63d3050a16cfe3d147776a6aa
parent    59259fd74a7cdf766b54e1de00abae88438d1978 (diff)
Remove the dead map-based code path in BatchNormExpander
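The `use_map_instructions` option was never enabled: the header defaulted it to false and both the CPU and GPU compilers passed `/*use_map_instructions=*/false`, so the map form existed only as a workaround for scalar-broadcast performance (see the removed TODO for b/80534766) and was never exercised. With it gone, Rsqrt and Mean are always expanded as a broadcast of a scalar constant followed by a single element-wise op, rather than a kMap over a cached scalar computation. A minimal sketch of the surviving expansion (reconstructed from the hunks below; the middle of the Rsqrt body is elided by the diff context, so this is an assumption rather than verbatim source, and it relies on the includes already present in batchnorm_expander.cc):

  // Removed (dead) map path: apply a cached scalar x^-0.5 computation
  // element-wise via a kMap instruction:
  //   return HloInstruction::CreateMap(
  //       operand->shape(), {operand},
  //       GetOrCreateScalarRsqrtComputation(operand->shape().element_type()));
  //
  // Surviving path: broadcast the scalar exponent -0.5 to the operand's
  // shape and emit a single element-wise kPower.
  std::unique_ptr<HloInstruction> Rsqrt(
      HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast(
        operand->shape(),
        add_instruction(HloInstruction::CreateConvert(
            ShapeUtil::MakeShape(operand->shape().element_type(), {}),
            add_instruction(HloInstruction::CreateConstant(
                Literal::CreateR0<float>(-0.5f))))),
        /*broadcast_dimensions=*/{}));
    return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
                                        operand, exponent);
  }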
PiperOrigin-RevId: 200072055
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc | 97
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.h  |  7
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc   |  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc   |  3
4 files changed, 12 insertions(+), 98 deletions(-)
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index a9f4aead59..ec13fadbc7 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
// Runs the visitor on a computation.
static bool Run(HloComputation* computation, bool rewrite_training_op,
- bool rewrite_inference_op, bool rewrite_grad_op,
- bool use_map_instructions);
+ bool rewrite_inference_op, bool rewrite_grad_op);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
@@ -70,22 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
explicit BatchNormExpanderVisitor(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
- bool rewrite_grad_op,
- bool use_map_instructions)
+ bool rewrite_grad_op)
: computation_(computation),
rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
- rewrite_grad_op_(rewrite_grad_op),
- use_map_instructions_(use_map_instructions) {}
+ rewrite_grad_op_(rewrite_grad_op) {}
HloComputation* GetOrCreateScalarAddComputation(
PrimitiveType primitive_type) {
- HloComputation** scalar_add_computation =
- &scalar_add_computations_[primitive_type];
- if (*scalar_add_computation) {
- return *scalar_add_computation;
- }
-
HloComputation::Builder b("scalar_add_computation");
Shape shape = ShapeUtil::MakeShape(primitive_type, {});
auto scalar_lhs = b.AddInstruction(
@@ -94,44 +85,13 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
- *scalar_add_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_add_computation;
- }
-
- // TODO(b/80534766): Remove maps after performance issues with scalar
- // broadcasts are resolved on all backends.
- HloComputation* GetOrCreateScalarRsqrtComputation(
- PrimitiveType primitive_type) {
- HloComputation** scalar_rsqrt_computation =
- &scalar_rsqrt_computations_[primitive_type];
- if (*scalar_rsqrt_computation) {
- return *scalar_rsqrt_computation;
- }
-
- HloComputation::Builder b("scalar_add_computation");
- Shape shape = ShapeUtil::MakeShape(primitive_type, {});
- auto scalar_lhs = b.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
- shape, b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(-0.5f)))));
- auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kPower, scalar_lhs, scalar_rhs));
- *scalar_rsqrt_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_rsqrt_computation;
+ return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
}
std::unique_ptr<HloInstruction> Rsqrt(
HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
add_instruction) {
- if (use_map_instructions_) {
- return HloInstruction::CreateMap(
- operand->shape(), {operand},
- GetOrCreateScalarRsqrtComputation(operand->shape().element_type()));
- }
HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast(
operand->shape(),
add_instruction(HloInstruction::CreateConvert(
@@ -143,40 +103,10 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
operand, exponent);
}
- HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type,
- int64 element_count) {
- HloComputation** scalar_mean_computation =
- &scalar_mean_computations_[std::pair<PrimitiveType, int64>(
- primitive_type, element_count)];
- if (*scalar_mean_computation) {
- return *scalar_mean_computation;
- }
-
- HloComputation::Builder b("scalar_add_computation");
- Shape shape = ShapeUtil::MakeShape(primitive_type, {});
- auto scalar_lhs = b.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
- shape, b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(
- 1.0f / static_cast<float>(element_count))))));
- auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs));
- *scalar_mean_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_mean_computation;
- }
-
std::unique_ptr<HloInstruction> Mean(
int64 element_count, HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
add_instruction) {
- if (use_map_instructions_) {
- return HloInstruction::CreateMap(
- operand->shape(), {operand},
- GetOrCreateScalarMeanComputation(operand->shape().element_type(),
- element_count));
- }
HloInstruction* elem_count_recip =
add_instruction(HloInstruction::CreateBroadcast(
operand->shape(),
@@ -218,18 +148,9 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
- bool use_map_instructions_;
// Whether rewrite has occurred.
bool changed_ = false;
-
- // Cached computations for adding two scalars.
- tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
- scalar_add_computations_;
- tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
- scalar_rsqrt_computations_;
- tensorflow::gtl::FlatMap<std::pair<PrimitiveType, int64>, HloComputation*>
- scalar_mean_computations_;
};
} // namespace
@@ -237,14 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool BatchNormExpanderVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
- bool rewrite_grad_op,
- bool use_map_instructions) {
+ bool rewrite_grad_op) {
BatchNormExpanderVisitor visitor(
computation,
/*rewrite_training_op=*/rewrite_training_op,
/*rewrite_inference_op=*/rewrite_inference_op,
- /*rewrite_grad_op=*/rewrite_grad_op,
- /*use_map_instructions=*/use_map_instructions);
+ /*rewrite_grad_op=*/rewrite_grad_op);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@@ -668,8 +587,8 @@ StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
- rewrite_inference_op_, rewrite_grad_op_,
- use_map_instructions_)) {
+ rewrite_inference_op_,
+ rewrite_grad_op_)) {
changed = true;
}
}
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 8826636416..7ae202c583 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -31,12 +31,10 @@ class BatchNormExpander : public HloPassInterface {
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
bool rewrite_inference_op = false,
- bool rewrite_grad_op = false,
- bool use_map_instructions = false)
+ bool rewrite_grad_op = false)
: rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
- rewrite_grad_op_(rewrite_grad_op),
- use_map_instructions_(use_map_instructions) {}
+ rewrite_grad_op_(rewrite_grad_op) {}
~BatchNormExpander() = default;
tensorflow::StringPiece name() const override { return "batchnorm_expander"; }
@@ -48,7 +46,6 @@ class BatchNormExpander : public HloPassInterface {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
- bool use_map_instructions_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index d6b7b7d2d8..4c0e189e78 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -264,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
- /*rewrite_grad_op=*/true,
- /*use_map_instructions=*/false);
+ /*rewrite_grad_op=*/true);
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; },
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index cc33847c5c..afefc740d7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -163,8 +163,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
- /*rewrite_grad_op=*/true,
- /*use_map_instructions=*/false);
+ /*rewrite_grad_op=*/true);
// Rewrite gather ops into smaller ones.
pass.AddPass<GatherExpander>();