aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-08-23 23:57:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-24 00:01:56 -0700
commit0a1b9ed01e48092f4167e366cf7496c2b111ef6d (patch)
tree7cdd04e96b485239bb12df4c0d69b2d4f9cfbb26
parent769f83b04126e10bcb4b18bf7ddc018c3e716e03 (diff)
[XLA] Add and use a layout-sensitive HLO verifier.
For now, this verifier checks some noncontroversial invariants, like: - Fusion operands and fusion computation parameters must have matching layouts. - Same for while loops, calls, kConditional. It's a bit of a pain to add these explicit layout-sensitive and allow-mixed-precision flags everywhere, but I think it's better than adding default args. With default args we can easily mix up the order, and we'd only be able to add new flags to the end of the list. PiperOrigin-RevId: 210059349
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc21
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc69
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.h11
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/defuser_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc33
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc41
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc144
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h58
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc7
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc10
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h4
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc10
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h17
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc33
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc8
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h4
-rw-r--r--tensorflow/compiler/xla/tests/token_hlo_test.cc15
32 files changed, 387 insertions, 177 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index ebf21ac151..bb63ea26d4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -52,7 +52,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
return [](const Shape&, const Shape&) { return false; };
}
-class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
+class AlgebraicSimplifierTest : public HloVerifiedTestBase {
+ public:
+ AlgebraicSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
@@ -2851,7 +2856,12 @@ struct DotOfConcatTestSpec {
class DotOfConcatSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
+ public ::testing::WithParamInterface<DotOfConcatTestSpec> {
+ public:
+ DotOfConcatSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that we transform
// dot(const, concat(A, B, C))
@@ -3024,7 +3034,12 @@ struct DotOfGatherTestSpec {
class DotOfGatherSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
+ public ::testing::WithParamInterface<DotOfGatherTestSpec> {
+ public:
+ DotOfGatherSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// input: dot(DS(ctA), ctB))
// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
index 38f1a5d3a6..b342acb025 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
@@ -24,7 +24,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class BatchDotSimplificationTest : public HloVerifiedTestBase {};
+class BatchDotSimplificationTest : public HloVerifiedTestBase {
+ public:
+ BatchDotSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(BatchDotSimplificationTest,
ElideSingleDegenerateBatchDotDim_VectorVector) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 49ae5320b0..b08705d4c2 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase {
StatusOr<bool> result = normalization.Run(module);
EXPECT_IS_OK(result.status());
- HloVerifier verifier(/*allow_mixed_precision=*/true);
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true);
EXPECT_IS_OK(verifier.Run(module).status());
return result.ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index c43a31b167..6c477da038 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloVerifiedTestBase {
public:
+ ConditionalSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);
};
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index f8500e78b6..e01fecffd0 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -88,6 +88,7 @@ cc_library(
":simple_orc_jit",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ ":target_machine_features",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ef71376cfb..279aa42fe2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -234,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
-} // namespace
-Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
- llvm::TargetMachine* target_machine) {
- LLVMTargetMachineFeatures target_machine_features(target_machine);
+} // namespace
- // Optimization pipeline.
- HloPassPipeline pipeline("CPU");
- pipeline.AddInvariantChecker<HloVerifier>();
+Status CpuCompiler::RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes through layout assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
@@ -259,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
- pipeline.AddPass<ConvCanonicalization>(&target_machine_features);
+ pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
@@ -290,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
}
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
pipeline.AddPass<TransposeFolding>(
- [&target_machine_features](
- const HloInstruction& dot,
+ [&](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
- return PotentiallyImplementedAsEigenDot(dot, target_machine_features)
+ return PotentiallyImplementedAsEigenDot(dot, *target_machine_features)
? candidate_operands
: TransposeFolding::OperandIndices{};
},
@@ -308,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(), &target_machine_features);
+ module->mutable_entry_computation_layout(), target_machine_features);
+ return pipeline.Run(module).status();
+}
+
+Status CpuCompiler::RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes after layout assignment");
+ // After layout assignment, use a layout-sensitive verifier.
+ auto& after_layout_assn =
+ pipeline.AddPass<HloPassPipeline>("after layout assignment");
+ after_layout_assn.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
+
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
- "after layout assignement");
+ "simplification after layout assignement");
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return true; },
@@ -321,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<HloDCE>();
pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
}
+
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
+
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
@@ -334,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
// binary size (and most AOT applications are single-threaded).
// TODO(b/29630486) Support multi-threaded AOT.
pipeline.AddPass<ParallelTaskAssigner>(
- max_parallelism, ShapeSizeBytesFunction(), &target_machine_features);
+ max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
}
- // Copy insertion should be performed immediately before IR emission to avoid
- // inserting unnecessary copies (later pass adds an instruction which
- // materializes the value) or missing a necessary copy (later pass removes an
- // instruction which materializes a value). DCE must be run immediately before
- // (and sometime after) copy insertion, to avoid dead code from interfering
- // with the rewrites.
+ // Copy insertion should be performed immediately before IR emission to
+ // avoid inserting unnecessary copies (later pass adds an instruction which
+ // materializes the value) or missing a necessary copy (later pass removes
+ // an instruction which materializes a value). DCE must be run immediately
+ // before (and sometime after) copy insertion, to avoid dead code from
+ // interfering with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
@@ -349,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
return pipeline.Run(module).status();
}
+Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine) {
+ LLVMTargetMachineFeatures target_machine_features(target_machine);
+ TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
+ &target_machine_features));
+ return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
+ &target_machine_features);
+}
+
namespace {
// Align buffers to 16-byte boundaries.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 04e1c48872..47b5edabff 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler {
Status RunHloPasses(HloModule* module, bool is_aot_compile,
llvm::TargetMachine* target_machine);
+ // Runs HLO passes up to and including layout assignment.
+ Status RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features);
+
+ // Runs HLO passes after layout assignment.
+ Status RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features);
+
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
};
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index 82c276d7ef..a84ee78b19 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -35,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
- : target_machine_features_([](int64 shape_size) {
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index e727ba49cb..37d1895d41 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class DefuserTest : public HloVerifiedTestBase {
+ public:
+ DefuserTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Returns the number of fusion instructions in the module.
int FusionCount() {
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index d9369f00cc..136b8e19aa 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -141,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
@@ -157,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
@@ -204,7 +206,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<CudnnConvolutionRewriter>();
@@ -219,9 +222,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
}
{
- HloPassPipeline pipeline("layout_assignment");
+ // Run layout assignment in a separate pipeline from
+ // "post-layout-assignment" because we want everything after layout
+ // assignment to have a layout-sensitive invariant-checker, but
+ // HloPassPipeline also runs its invariant checker before any passes are
+ // run, meaning, the pipeline that contains layout assignment cannot contain
+ // a layout-sensitive verifier!
+ HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
hlo_module->mutable_entry_computation_layout(), stream_exec);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+
+ {
+ HloPassPipeline pipeline("post-layout_assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -267,7 +283,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>();
+ fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
@@ -277,7 +294,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
- reduce_pipeline.AddInvariantChecker<HloVerifier>();
+ reduce_pipeline.AddInvariantChecker<HloVerifier>(
+ /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -303,7 +321,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
index 99e7580b82..104af48c82 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -29,7 +29,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-using PadForTensorCoresTest = HloVerifiedTestBase;
+class PadForTensorCoresTest : public HloVerifiedTestBase {
+ public:
+ PadForTensorCoresTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
ParseAndVerifyModule(R"(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index cca35316f0..15d1e269cc 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -27,13 +27,22 @@ namespace {
class GpuKernelTilingTest : public GpuCodegenTest {
protected:
- GpuKernelTilingTest() {
+ GpuKernelTilingTest() {}
+
+ // Most tests in this file want to skip layout assignment, but a few need it
+ // enabled.
+ HloModuleConfig ConfigWithLayoutAssignment() {
+ return GetModuleConfigForTest();
+ }
+
+ HloModuleConfig ConfigWithoutLayoutAssignment() {
+ HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
- config_.set_debug_options(debug_options);
// Disable layout_assignment to use the preassigned layouts.
- debug_options.add_xla_disable_hlo_passes("layout_assignment");
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
+ config.set_debug_options(debug_options);
+ return config;
}
- HloModuleConfig config_;
};
TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
@@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ //
+ // We must enable layout assignment in order for this test to work correctly.
+ // AlgebraicSimplifier removes copy1; it's added back by layout assignment,
+ // which respects the module's entry computation layout. But if we don't run
+ // layout assignment...well, nobody else adds the copy back.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
})";
- // Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ // Check that a call to llvm.nvvm.barrier0 is not generated. As in
+ // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment
+ // here.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest,
})";
// Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
index 9622936306..0f2d5568ca 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
@@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) {
HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_max_kernel_unroll_factor(2);
+ // Disable layout assignment for this test. Layout assignment does not expect
+ // fusions to be present, and so it does the wrong thing.
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
config.set_debug_options(debug_options);
const char *const kMultiOutputFusionModule = R"(
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index c5f3906356..40183de96e 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
- HloVerifier verifier;
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 648228b825..79e78ee2d0 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -29,6 +29,11 @@ namespace xla {
namespace {
class HloDomainTest : public HloVerifiedTestBase {
+ public:
+ HloDomainTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 139d70374f..7e85df53a3 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -52,7 +52,10 @@ static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
- HloEvaluatorTest() : use_bfloat16_(GetParam()) {
+ HloEvaluatorTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ use_bfloat16_(GetParam()) {
evaluator_ = absl::make_unique<HloEvaluator>();
}
@@ -1216,7 +1219,12 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
-class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
+class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {
+ public:
+ HloEvaluatorPreciseReduceTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Tests that Reduce doesn't lose precision when adding many numbers (because
// it accumulates its result in a double).
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 66dd23e73f..f60c4eab42 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -123,29 +123,26 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-namespace {
-
-Status CheckIsTokenOperand(const HloInstruction* instruction,
- int64 operand_no) {
+Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no) {
const HloInstruction* token = instruction->operand(operand_no);
if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
return InternalError(
"Expected operand %lld to be token-shaped, actual shape is "
"%s:\n%s",
- operand_no, ShapeUtil::HumanString(token->shape()).c_str(),
+ operand_no, StringifyShape(token->shape()).c_str(),
instruction->ToString().c_str());
}
return Status::OK();
}
-Status CheckOperandAndParameter(const HloInstruction* instruction,
- int64 operand_number,
- const HloComputation* computation,
- int64 parameter_number) {
+Status ShapeVerifier::CheckOperandAndParameter(
+ const HloInstruction* instruction, int64 operand_number,
+ const HloComputation* computation, int64 parameter_number) {
const HloInstruction* operand = instruction->operand(operand_number);
const HloInstruction* parameter =
computation->parameter_instruction(parameter_number);
- if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
+ if (!ShapesSame(operand->shape(), parameter->shape())) {
return InternalError("Operand %s shape does not match parameter's %s in %s",
operand->ToString().c_str(),
parameter->ToString().c_str(),
@@ -154,8 +151,6 @@ Status CheckOperandAndParameter(const HloInstruction* instruction,
return Status::OK();
}
-} // namespace
-
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
@@ -172,13 +167,12 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
// Outfeed has a separate shape field for the value which is outfed to the
// host. The shape of the instruction itself is always a token.
- if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
- outfeed->operand(0)->shape())) {
+ if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
return InternalError(
- "Expected outfeed shape to be compatible with operand's shape %s, "
+ "Expected outfeed shape to be equal to operand's shape %s, "
"actual shape is %s:\n%s",
- ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
+ StringifyShape(outfeed->operand(0)->shape()).c_str(),
+ StringifyShape(outfeed->outfeed_shape()).c_str(),
outfeed->ToString().c_str());
}
return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
@@ -259,8 +253,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
return InternalError(
"Expected sort to have to have the same dimensions for the keys and "
"the values. Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(sort->operand(1)->shape()).c_str());
+ StringifyShape(sort->operand(0)->shape()).c_str(),
+ StringifyShape(sort->operand(1)->shape()).c_str());
}
return CheckVariadicShape(sort);
}
@@ -334,7 +328,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return Status::OK();
}
-Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
+ for (HloInstruction* fused_param : fusion->fused_parameters()) {
+ int64 param_no = fused_param->parameter_number();
+ if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
+ return InternalError(
+ "Shape mismatch between parameter number %lld and its operand in "
+ "%s.",
+ param_no, fusion->ToString().c_str());
+ }
+ }
+ return Status::OK();
+}
Status ShapeVerifier::HandleCall(HloInstruction* call) {
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
@@ -416,12 +421,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
- if (!ShapeUtil::Compatible(conditional_shape,
- ShapeUtil::MakeShape(PRED, {}))) {
+ if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
- ShapeUtil::HumanString(conditional_shape).c_str());
+ StringifyShape(conditional_shape).c_str());
}
// The shape of kWhile should match the shape of the body computation it
// calls.
@@ -599,52 +603,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
// Check if the output shape matches the expected shape.
- bool compatible;
+ //
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
- switch (instruction->opcode()) {
- case HloOpcode::kTupleSelect:
- // TupleSelect only defines the top-level buffer, which in this case is
- // the tuple, so we cannot allow mixed precision.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- case HloOpcode::kGetTupleElement:
- case HloOpcode::kTuple:
- // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
- // precision is disallowed.
- case HloOpcode::kConstant:
- case HloOpcode::kBitcast:
- case HloOpcode::kBitcastConvert:
- case HloOpcode::kCall:
- case HloOpcode::kConditional:
- case HloOpcode::kConvert:
- case HloOpcode::kCustomCall:
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
- case HloOpcode::kParameter:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
- case HloOpcode::kSend:
- case HloOpcode::kSendDone:
- case HloOpcode::kWhile:
- // The above opcodes should match the expected shapes exactly.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- default:
- if (allow_mixed_precision_) {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
- instruction->shape(), inferred_shape);
- } else {
- compatible =
- ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- }
- }
- if (!compatible) {
+ bool equal = [&] {
+ switch (instruction->opcode()) {
+ // The opcodes below can't have implicit layout conversions, nor can they
+ // implicitly transform f32 -> bf16. Fundamentally these are either
+ // reinterpreting existing data (e.g. kBitcast) or shuffling data around
+ // without modifying it (e.g. kGetTupleElement, kTupleSelect).
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConstant:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
+ case HloOpcode::kWhile:
+ return ShapesSame(instruction->shape(), inferred_shape);
+
+ // We allow arbitrary layout and f32->bf16 transformations on all other
+ // instructions, although this may be made more strict pending discussion
+ // in b/112709536.
+ default:
+ if (allow_mixed_precision_) {
+ return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
+ inferred_shape);
+ } else {
+ return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ }
+ }
+ }();
+ if (!equal) {
return InternalError(
- "Expected instruction to have shape compatible with %s, actual "
+ "Expected instruction to have shape equal to %s, actual "
"shape is %s:\n%s",
- ShapeUtil::HumanString(inferred_shape).c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
+ StringifyShape(inferred_shape).c_str(),
+ StringifyShape(instruction->shape()).c_str(),
instruction->ToString().c_str());
}
return Status::OK();
@@ -828,7 +831,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
// Fused parameter instructions must be numbered contiguously and match up
- // (shapes compatible) with their respective operand.
+ // (shapes equal) with their respective operand.
CHECK_EQ(fusion->operands().size(), fused_parameters.size());
std::vector<bool> parameter_numbers(fused_parameters.size(), false);
for (auto fused_param : fused_parameters) {
@@ -849,13 +852,6 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
param_no, fusion->ToString().c_str());
}
parameter_numbers[param_no] = true;
- if (!ShapeUtil::Compatible(fused_param->shape(),
- fusion->operand(param_no)->shape())) {
- return InternalError(
- "Shape mismatch between parameter number %lld and its operand in "
- "%s.",
- param_no, fusion->ToString().c_str());
- }
}
// Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
@@ -917,7 +913,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
- "Found non-compatible shapes for instruction %s.\n"
+ "Found different shapes for instruction %s.\n"
"output: %s\noperand: %s\n",
HloOpcodeString(instruction->opcode()).c_str(),
ShapeUtil::HumanString(out_shape).c_str(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 83b77c84eb..b6093d667c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -28,9 +28,9 @@ namespace xla {
// TODO(b/26024837): Check output shape for all instruction types.
class ShapeVerifier : public DfsHloVisitor {
public:
- explicit ShapeVerifier() : allow_mixed_precision_(false) {}
- explicit ShapeVerifier(bool allow_mixed_precision)
- : allow_mixed_precision_(allow_mixed_precision) {}
+ explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_(allow_mixed_precision) {}
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
@@ -106,13 +106,42 @@ class ShapeVerifier : public DfsHloVisitor {
Status CheckVariadicShape(const HloInstruction* instruction);
private:
- // Return true if the shapes of the two operands have the same element type,
- // and the result shape either has the same element type as the operand
- // shapes or mixed precision is allowed and the result shape and the operand
- // shapes have floating point element types.
+ // Helpers that switch on layout_sensitive_.
+ bool ShapesSame(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::Equal(a, b)
+ : ShapeUtil::Compatible(a, b);
+ }
+ bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
+ : ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
+ }
+ string StringifyShape(const Shape& s) {
+ return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
+ : ShapeUtil::HumanString(s);
+ }
+
+ // Checks that the given operand of the given instruction is of type TOKEN.
+ Status CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no);
+
+ // Checks that the shape of the given operand of the given instruction matches
+ // the given parameter of the given computation.
+ Status CheckOperandAndParameter(const HloInstruction* instruction,
+ int64 operand_number,
+ const HloComputation* computation,
+ int64 parameter_number);
+
+ // Returns true if the shapes of the two operands have the same element type,
+ // and the result shape either has the same element type as the operand shapes
+ // or mixed precision is allowed and the result shape and the operand shapes
+ // have floating point element types.
bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
const Shape& result_shape);
+ // If the verifier is layout-sensitive, shapes must be equal to what's
+ // expected. Otherwise, the shapes must simply be compatible.
+ bool layout_sensitive_;
+
// Whether the inputs and output of an instruction can contain both F32s and
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
// this flag.
@@ -125,14 +154,10 @@ class HloVerifier : public HloPassInterface {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- // Uses standard shape inference.
- explicit HloVerifier()
- : shape_verifier_factory_(
- [] { return absl::make_unique<ShapeVerifier>(false); }) {}
-
- explicit HloVerifier(bool allow_mixed_precision)
- : shape_verifier_factory_([allow_mixed_precision] {
- return absl::make_unique<ShapeVerifier>(allow_mixed_precision);
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
+ return absl::make_unique<ShapeVerifier>(layout_sensitive,
+ allow_mixed_precision);
}) {}
// Uses custom shape verification.
@@ -142,8 +167,7 @@ class HloVerifier : public HloPassInterface {
~HloVerifier() override = default;
absl::string_view name() const override { return "verifier"; }
- // Note: always returns false (no instructions are ever modified by this
- // pass).
+ // Never returns true; no instructions are ever modified by this pass.
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index d764964f3c..70b741353d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -37,13 +37,15 @@ using ::testing::HasSubstr;
class HloVerifierTest : public HloTestBase {
public:
HloVerifierTest()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {}
};
class HloVerifierTestAllowMixedPrecision : public HloTestBase {
public:
HloVerifierTestAllowMixedPrecision()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
TEST_F(HloVerifierTest, NullInstructionParent) {
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index f85d31d522..df88587492 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
+ public:
+ ImplicitBroadcastRemoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
ImplicitBroadcastRemover remover_;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 2d03aebc1a..c34c32f7d3 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -22,6 +22,11 @@ limitations under the License.
namespace xla {
namespace {
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
+ public:
+ IndexedArrayAnalysisTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index fab4e797f3..a395dd5333 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -34,7 +34,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-using ReshapeMoverTest = HloVerifiedTestBase;
+class ReshapeMoverTest : public HloVerifiedTestBase {
+ public:
+ ReshapeMoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 32e69c335b..e14014b961 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
public:
+ WhileLoopInvariantCodeMotionTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation which has one parameter, of the given shape, and always
// returns PRED[]{true}. This is useful as a dummy loop condition.
HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 1c892ba179..cfe4104f6d 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -28,6 +28,11 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
+ public:
+ WhileLoopSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
void MakeModuleWithSimpleLoop(int num_iters);
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 693454bd80..93ea144438 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -86,16 +86,20 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
-HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier)
+HloTestBase::HloTestBase(bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
+ verifier_layout_sensitive,
allow_mixed_precision_in_hlo_verifier) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform,
+ bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier)
: test_runner_(test_platform), reference_runner_(reference_platform) {
- hlo_verifier_ =
- absl::make_unique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
+ hlo_verifier_ = absl::make_unique<HloVerifier>(
+ /*layout_sensitive=*/verifier_layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
}
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index ce39dd78d4..06bcc39741 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -85,12 +85,14 @@ class HloTestBase : public ::testing::Test {
// automatically finds another supported backend as the test backend. If the
// interpreter is the only supported backend, it will be both the test backend
// and the reference backend.
- HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true);
+ HloTestBase(bool verifier_layout_sensitive = false,
+ bool allow_mixed_precision_in_hlo_verifier = true);
// If your test doesn't use interpreter as the reference backend, you can use
// this constructor. Note that your test target is responsible for linking in
// both needed backends.
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
+ bool verifier_layout_sensitive = false,
bool allow_mixed_precision_in_hlo_verifier = true);
~HloTestBase() override {}
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index dd130557b3..8f86c528d0 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -25,8 +25,11 @@ limitations under the License.
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase()
- : shape_verifier_(absl::make_unique<ShapeVerifier>()) {}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
HloVerifiedTestBase::~HloVerifiedTestBase() {
// We can't call the ASSERT or EXPECT test macros in destructors, so we
@@ -51,8 +54,7 @@ void HloVerifiedTestBase::TearDown() {
}
void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- HloVerifier verifier(/*allow_mixed_precision=*/true);
- xla::StatusOr<bool> mutated = verifier.Run(module);
+ xla::StatusOr<bool> mutated = verifier().Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index a2f1185d63..cc6967feed 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -29,7 +29,8 @@ namespace xla {
// performs verification on that module on tear-down.
class HloVerifiedTestBase : public HloTestBase {
protected:
- HloVerifiedTestBase();
+ explicit HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision);
~HloVerifiedTestBase() override;
// Constructs a default shape verifier.
@@ -47,29 +48,25 @@ class HloVerifiedTestBase : public HloTestBase {
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
- // Sets the shape-size function used during hlo verification. If this isn't
- // called, a default ShapeVerifier is used instead.
- void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) {
- shape_verifier_ = std::move(shape_verifier);
- }
-
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
HloModule* CreateNewModule(const string& name = TestName());
+ private:
+ void VerifyModule(HloModule* module);
+
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
- private:
+ //
// Lazily populated. Access via module().
std::unique_ptr<HloModule> module_;
// Populated by calls to CreateNewModule.
std::vector<std::unique_ptr<HloModule>> modules_;
- std::unique_ptr<ShapeVerifier> shape_verifier_;
+
bool tear_down_called_ = false;
- static void VerifyModule(HloModule* module);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index adc0956164..16b77e965d 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -53,12 +53,22 @@ class MultiOutputFusionTest : public HloTestBase {
protected:
MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
+ // Layout assignment assumes that there are no fusions in the input graph.
+ // Since the purpose of this test is to send pre-fused graphs to XLA, we have
+ // to do layout assignment ourselves.
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.add_xla_disable_hlo_passes("layout-assignment");
+ return opts;
+ }
+
void RunTest2D(bool manual_fusion, int64 size) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {});
- const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size});
+ const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
+ const Shape elem_shape2 =
+ ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0});
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
@@ -101,10 +111,10 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal arg1(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
arg1.PopulateWithValue<float>(2.5f);
- Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
auto actual =
ExecuteAndTransfer(std::move(hlo_module),
@@ -116,8 +126,10 @@ class MultiOutputFusionTest : public HloTestBase {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size});
- const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size});
+ const Shape elem_shape_F32 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});
+ const Shape elem_shape_U8 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F64, {size});
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, elem_shape_F32, "0"));
auto param1 = builder.AddInstruction(
@@ -137,12 +149,13 @@ class MultiOutputFusionTest : public HloTestBase {
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(F32, {size, 1}), add));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
+ dot_dnums));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
@@ -162,9 +175,9 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal input0(ShapeUtil::MakeShape(F32, {size}));
+ Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}));
input0.PopulateWithValue(2.5f);
- Literal input1(ShapeUtil::MakeShape(F64, {size}));
+ Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
input1.PopulateWithValue(1.);
Literal expect =
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2f1d97b25d..21c58e075e 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -408,8 +408,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
return std::move(arguments);
}
-Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) {
- return HloVerifier(allow_mixed_precision).Run(module).status();
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision) {
+ return HloVerifier(/*layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision)
+ .Run(module)
+ .status();
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 1aca1d8ef7..277d53d423 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -95,8 +95,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
// Check that a given module satisfies various constraints before trying to
// execute it.
-Status VerifyHloModule(HloModule* const module,
- bool allow_mixed_precision = false);
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision);
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index 870da6efa8..c7eb9e2dbe 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -66,7 +66,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -83,7 +86,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
"param"));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -100,7 +106,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(status.error_message(),
::testing::HasSubstr(