about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu
diff options
context:
space:
mode:
author: Bixia Zheng <bixia@google.com> 2018-10-02 22:39:09 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-02 22:44:19 -0700
commitbbe15eee6779941c54e145d12e16f6473738857c (patch)
tree80b87b5fa9334b4a0d0a1a0159cf0d43956df5c5 /tensorflow/compiler/xla/service/gpu
parent65b5190065db0074f8722b09ba43423438c40258 (diff)
[XLA] Modify the function that determines whether an instruction can change
layout so that it can be used by the HLO verifier.

Change the function to a static member function of the LayoutAssignment class.
Add a std::function member to LayoutAssignment to store the function object
passed down from the backend compiler class, and use it to decide whether an
instruction can change layouts. Fix affected test cases.

PiperOrigin-RevId: 215515611
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu')
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h5
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc17
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc3
3 files changed, 17 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index e2b96a81d4..4ba7989e9c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -30,8 +30,11 @@ namespace gpu {
class GpuLayoutAssignment : public LayoutAssignment {
public:
explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func,
se::StreamExecutor* stream_executor)
- : LayoutAssignment(entry_computation_layout),
+ : LayoutAssignment(entry_computation_layout,
+ std::move(instruction_can_change_layout_func)),
stream_executor_(stream_executor) {}
~GpuLayoutAssignment() override {}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index fbc8ddf599..04681cfcec 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
ShapeLayout(result_shape_with_layout);
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
}
GpuLayoutAssignment layout_assignment(
- &computation_layout, backend().default_stream_executor());
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
@@ -348,8 +352,9 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
- GpuLayoutAssignment layout_assignment(&computation_layout,
- backend().default_stream_executor());
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+ backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
Shape expected_shape =
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 0b3b429710..b4ae2e42c7 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -232,7 +232,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// a layout-sensitive verifier!
HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_entry_computation_layout(), stream_exec);
+ hlo_module->mutable_entry_computation_layout(),
+ LayoutAssignment::InstructionCanChangeLayout, stream_exec);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}