aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar Bixia Zheng <bixia@google.com>2018-10-02 22:39:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 22:44:19 -0700
commitbbe15eee6779941c54e145d12e16f6473738857c (patch)
tree80b87b5fa9334b4a0d0a1a0159cf0d43956df5c5 /tensorflow/compiler/xla
parent65b5190065db0074f8722b09ba43423438c40258 (diff)
[XLA] Modify the function that determines whether an instruction can change
layout so that it can be used by the HLO verifier. Change the function to a static member function of the LayoutAssignment class. Add an std::function member to LayoutAssignment to store the function object passed down from the backend compiler class and use it to decide whether an instruction can change layouts. Fix affected test cases. PiperOrigin-RevId: 215515611
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h5
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h5
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc17
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/interpreter/compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc18
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h18
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc3
10 files changed, 59 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 18fc144efe..ea8c200dee 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -308,7 +308,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(), target_machine_features);
+ module->mutable_entry_computation_layout(),
+ LayoutAssignment::InstructionCanChangeLayout, target_machine_features);
return pipeline.Run(module).status();
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
index 3c4fe68b83..f4da35dd37 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h
@@ -30,8 +30,11 @@ class CpuLayoutAssignment : public LayoutAssignment {
public:
explicit CpuLayoutAssignment(
ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func,
const TargetMachineFeatures* target_machine_features)
- : LayoutAssignment(entry_computation_layout),
+ : LayoutAssignment(entry_computation_layout,
+ std::move(instruction_can_change_layout_func)),
target_machine_features_(*target_machine_features) {}
~CpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 4668f3872d..97659b88a7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -54,8 +54,9 @@ class CpuLayoutAssignmentTest : public HloTestBase {
[](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
- cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout,
- &target_machine_features);
+ cpu::CpuLayoutAssignment layout_assignment(
+ entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ &target_machine_features);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
};
@@ -321,8 +322,9 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
[](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
});
- cpu::CpuLayoutAssignment layout_assignment(&computation_layout,
- &target_machine_features);
+ cpu::CpuLayoutAssignment layout_assignment(
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ &target_machine_features);
TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something,
layout_assignment.Run(module));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index e2b96a81d4..4ba7989e9c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -30,8 +30,11 @@ namespace gpu {
class GpuLayoutAssignment : public LayoutAssignment {
public:
explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func,
se::StreamExecutor* stream_executor)
- : LayoutAssignment(entry_computation_layout),
+ : LayoutAssignment(entry_computation_layout,
+ std::move(instruction_can_change_layout_func)),
stream_executor_(stream_executor) {}
~GpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index fbc8ddf599..04681cfcec 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
ShapeLayout(result_shape_with_layout);
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
@@ -348,8 +352,9 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- GpuLayoutAssignment layout_assignment(&computation_layout,
- backend().default_stream_executor());
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
Shape expected_shape =
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 0b3b429710..b4ae2e42c7 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -232,7 +232,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// a layout-sensitive verifier!
HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_entry_computation_layout(), stream_exec);
+ hlo_module->mutable_entry_computation_layout(),
+ LayoutAssignment::InstructionCanChangeLayout, stream_exec);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index bb69cb9c47..27fe89375d 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -44,7 +44,8 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<LayoutAssignment>(
- hlo_module->mutable_entry_computation_layout());
+ hlo_module->mutable_entry_computation_layout(),
+ LayoutAssignment::InstructionCanChangeLayout);
return pipeline.Run(hlo_module).status();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 25d5327561..68a08a0886 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -974,10 +974,15 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
LayoutAssignment::LayoutAssignment(
ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
+
saved_entry_computation_layout_(*entry_computation_layout),
- channel_layout_constraints_(channel_constraints) {
+ channel_layout_constraints_(channel_constraints),
+ instruction_can_change_layout_func_(
+ std::move(instruction_can_change_layout_func)) {
if (channel_layout_constraints_ != nullptr) {
// Save a copy of the input ChannelLayoutConstraints so that we can reset it
// if we have to undo previous operations (ClearPreviousPassSideEffects()).
@@ -998,7 +1003,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
ShapeUtil::Rank(instruction->shape()) &&
- InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) {
+ !instruction_can_change_layout_func_(instruction)) {
// Propagate the result layout to the operand layout if the instruction
// requires the same layout out for the result and the operand.
//
@@ -1076,7 +1081,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
if (!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) &&
- InstructionRequiresInputLayoutEqualToOutputLayout(user)) {
+ !instruction_can_change_layout_func_(user)) {
// Assign users the same layout as the operand.
return absl::make_unique<Layout>(operand_layout);
}
@@ -1842,7 +1847,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
return true;
}
-bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
+/* static */
+bool LayoutAssignment::InstructionCanChangeLayout(
const HloInstruction* instruction) {
switch (instruction->opcode()) {
case HloOpcode::kAbs:
@@ -1908,7 +1914,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
case HloOpcode::kTanh:
case HloOpcode::kTupleSelect:
case HloOpcode::kWhile:
- return true;
+ return false;
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
@@ -1939,7 +1945,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout(
case HloOpcode::kTrace:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
- return false;
+ return true;
}
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 15f0adcaaf..2d48e12263 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -286,6 +286,11 @@ class LayoutAssignment : public HloModulePass {
// entry_computation_layout is modified to populate a layout for the result in
// the case that no particular layout is requested.
//
+ // instruction_can_change_layout_func is a function object that determines
+ // whether an instruction can change layouts. An instruction not being able to
+ // change layout means that it requires operands with the same rank as the
+ // output to have the same layout as the output.
+ //
// channel_constraints is both an input and output. Any sends or recvs that
// are present in channel_constraints will be laid out as constrained. Any
// unconstrained sends or recvs will be laid out as locally optimal and their
@@ -295,6 +300,8 @@ class LayoutAssignment : public HloModulePass {
// within any module passed to `Run`.
explicit LayoutAssignment(
ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = InstructionCanChangeLayout,
ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment() override {}
absl::string_view name() const override { return "layout-assignment"; }
@@ -303,10 +310,10 @@ class LayoutAssignment : public HloModulePass {
// (any layouts were changed).
StatusOr<bool> Run(HloModule* module) override;
- // Returns true if the instruction requires that operands with the same rank
- // as the output have to have the same layout as the output.
- virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
- const HloInstruction* instruction);
+ // Determines whether an instruction can change layouts. An instruction not
+ // being able to change layout means that it requires operands with the same
+ // rank as the output to have the same layout as the output.
+ static bool InstructionCanChangeLayout(const HloInstruction* instruction);
protected:
// These methods, invoked by PropagateConstraints, propagate a layout
@@ -522,6 +529,9 @@ class LayoutAssignment : public HloModulePass {
// The set of HLO instructions which lacked any layout constraint, thus
// receiving propagated default layouts.
absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;
+
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 10f9a95121..15c16d667c 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -55,7 +55,8 @@ class LayoutAssignmentTest : public HloVerifiedTestBase {
ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr) {
LayoutAssignment layout_assignment(
- entry_computation_layout, /*channel_constraints=*/channel_constraints);
+ entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ /*channel_constraints=*/channel_constraints);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}