Diffstat:
 tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 21
 tensorflow/compiler/xla/service/batch_dot_simplification_test.cc | 7
 tensorflow/compiler/xla/service/bfloat16_normalization_test.cc | 3
 tensorflow/compiler/xla/service/conditional_simplifier_test.cc | 4
 tensorflow/compiler/xla/service/cpu/BUILD | 1
 tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 69
 tensorflow/compiler/xla/service/cpu/cpu_compiler.h | 11
 tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc | 4
 tensorflow/compiler/xla/service/defuser_test.cc | 5
 tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 33
 tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc | 7
 tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc | 41
 tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc | 3
 tensorflow/compiler/xla/service/gpu/while_transformer_test.cc | 3
 tensorflow/compiler/xla/service/hlo_domain_test.cc | 5
 tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 12
 tensorflow/compiler/xla/service/hlo_verifier.cc | 144
 tensorflow/compiler/xla/service/hlo_verifier.h | 58
 tensorflow/compiler/xla/service/hlo_verifier_test.cc | 6
 tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc | 5
 tensorflow/compiler/xla/service/indexed_array_analysis_test.cc | 5
 tensorflow/compiler/xla/service/reshape_mover_test.cc | 7
 tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc | 4
 tensorflow/compiler/xla/service/while_loop_simplifier_test.cc | 5
 tensorflow/compiler/xla/tests/hlo_test_base.cc | 10
 tensorflow/compiler/xla/tests/hlo_test_base.h | 4
 tensorflow/compiler/xla/tests/hlo_verified_test_base.cc | 10
 tensorflow/compiler/xla/tests/hlo_verified_test_base.h | 17
 tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 33
 tensorflow/compiler/xla/tests/test_utils.cc | 8
 tensorflow/compiler/xla/tests/test_utils.h | 4
 tensorflow/compiler/xla/tests/token_hlo_test.cc | 15
 32 files changed, 387 insertions(+), 177 deletions(-)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index ebf21ac151..bb63ea26d4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -52,7 +52,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
return [](const Shape&, const Shape&) { return false; };
}
-class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
+class AlgebraicSimplifierTest : public HloVerifiedTestBase {
+ public:
+ AlgebraicSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that A + 0 is simplified to A
TEST_F(AlgebraicSimplifierTest, AddZero) {
@@ -2851,7 +2856,12 @@ struct DotOfConcatTestSpec {
class DotOfConcatSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
+ public ::testing::WithParamInterface<DotOfConcatTestSpec> {
+ public:
+ DotOfConcatSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Test that we transform
// dot(const, concat(A, B, C))
@@ -3024,7 +3034,12 @@ struct DotOfGatherTestSpec {
class DotOfGatherSimplificationTest
: public HloVerifiedTestBase,
- public ::testing::WithParamInterface<DotOfGatherTestSpec> {};
+ public ::testing::WithParamInterface<DotOfGatherTestSpec> {
+ public:
+ DotOfGatherSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// input: dot(DS(ctA), ctB))
// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}.
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
index 38f1a5d3a6..b342acb025 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc
@@ -24,7 +24,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-class BatchDotSimplificationTest : public HloVerifiedTestBase {};
+class BatchDotSimplificationTest : public HloVerifiedTestBase {
+ public:
+ BatchDotSimplificationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(BatchDotSimplificationTest,
ElideSingleDegenerateBatchDotDim_VectorVector) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 49ae5320b0..b08705d4c2 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase {
StatusOr<bool> result = normalization.Run(module);
EXPECT_IS_OK(result.status());
- HloVerifier verifier(/*allow_mixed_precision=*/true);
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true);
EXPECT_IS_OK(verifier.Run(module).status());
return result.ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index c43a31b167..6c477da038 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers;
class ConditionalSimplifierTest : public HloVerifiedTestBase {
public:
+ ConditionalSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation that contains a conditional with constant predicate.
HloComputation* MakeConditional(HloModule* module);
};
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index f8500e78b6..e01fecffd0 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -88,6 +88,7 @@ cc_library(
":simple_orc_jit",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ ":target_machine_features",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index ef71376cfb..279aa42fe2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -234,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
-} // namespace
-Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
- llvm::TargetMachine* target_machine) {
- LLVMTargetMachineFeatures target_machine_features(target_machine);
+} // namespace
- // Optimization pipeline.
- HloPassPipeline pipeline("CPU");
- pipeline.AddInvariantChecker<HloVerifier>();
+Status CpuCompiler::RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes through layout assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<CpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
@@ -259,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
- pipeline.AddPass<ConvCanonicalization>(&target_machine_features);
+ pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
@@ -290,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
}
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
pipeline.AddPass<TransposeFolding>(
- [&target_machine_features](
- const HloInstruction& dot,
+ [&](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
- return PotentiallyImplementedAsEigenDot(dot, target_machine_features)
+ return PotentiallyImplementedAsEigenDot(dot, *target_machine_features)
? candidate_operands
: TransposeFolding::OperandIndices{};
},
@@ -308,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_entry_computation_layout(), &target_machine_features);
+ module->mutable_entry_computation_layout(), target_machine_features);
+ return pipeline.Run(module).status();
+}
+
+Status CpuCompiler::RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features) {
+ HloPassPipeline pipeline("HLO passes after layout assignment");
+ // After layout assignment, use a layout-sensitive verifier.
+ auto& after_layout_assn =
+ pipeline.AddPass<HloPassPipeline>("after layout assignment");
+ after_layout_assn.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
+
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
- "after layout assignement");
+ "simplification after layout assignement");
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return true; },
@@ -321,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<HloDCE>();
pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
}
+
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
+
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
module->config().intra_op_parallelism_threads() > 0
@@ -334,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
// binary size (and most AOT applications are single-threaded).
// TODO(b/29630486) Support multi-threaded AOT.
pipeline.AddPass<ParallelTaskAssigner>(
- max_parallelism, ShapeSizeBytesFunction(), &target_machine_features);
+ max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
}
- // Copy insertion should be performed immediately before IR emission to avoid
- // inserting unnecessary copies (later pass adds an instruction which
- // materializes the value) or missing a necessary copy (later pass removes an
- // instruction which materializes a value). DCE must be run immediately before
- // (and sometime after) copy insertion, to avoid dead code from interfering
- // with the rewrites.
+ // Copy insertion should be performed immediately before IR emission to
+ // avoid inserting unnecessary copies (later pass adds an instruction which
+ // materializes the value) or missing a necessary copy (later pass removes
+ // an instruction which materializes a value). DCE must be run immediately
+ // before (and sometime after) copy insertion, to avoid dead code from
+ // interfering with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
@@ -349,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
return pipeline.Run(module).status();
}
+Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
+ llvm::TargetMachine* target_machine) {
+ LLVMTargetMachineFeatures target_machine_features(target_machine);
+ TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
+ &target_machine_features));
+ return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
+ &target_machine_features);
+}
+
namespace {
// Align buffers to 16-byte boundaries.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index 04e1c48872..47b5edabff 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler {
Status RunHloPasses(HloModule* module, bool is_aot_compile,
llvm::TargetMachine* target_machine);
+ // Runs HLO passes up to and including layout assignment.
+ Status RunHloPassesThroughLayoutAssn(
+ HloModule* module, bool /*is_aot_compile*/,
+ LLVMTargetMachineFeatures* target_machine_features);
+
+ // Runs HLO passes after layout assignment.
+ Status RunHloPassesAfterLayoutAssn(
+ HloModule* module, bool is_aot_compile,
+ LLVMTargetMachineFeatures* target_machine_features);
+
TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
};
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index 82c276d7ef..a84ee78b19 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -35,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
- : target_machine_features_([](int64 shape_size) {
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index e727ba49cb..37d1895d41 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class DefuserTest : public HloVerifiedTestBase {
+ public:
+ DefuserTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Returns the number of fusion instructions in the module.
int FusionCount() {
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index d9369f00cc..136b8e19aa 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -141,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
pipeline.AddPass<GpuHloSupportChecker>();
ReducePrecisionInsertion::AddPasses(
&pipeline, hlo_module->config().debug_options(),
@@ -157,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
- pass.AddInvariantChecker<HloVerifier>();
+ pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
@@ -204,7 +206,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
// TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<CudnnConvolutionRewriter>();
@@ -219,9 +222,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
}
{
- HloPassPipeline pipeline("layout_assignment");
+ // Run layout assignment in a separate pipeline from
+ // "post-layout-assignment" because we want everything after layout
+ // assignment to have a layout-sensitive invariant-checker, but
+ // HloPassPipeline also runs its invariant checker before any passes are
+ // run, meaning the pipeline that contains layout assignment cannot contain
+ // a layout-sensitive verifier!
+ HloPassPipeline pipeline("layout assignment");
pipeline.AddPass<GpuLayoutAssignment>(
hlo_module->mutable_entry_computation_layout(), stream_exec);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+
+ {
+ HloPassPipeline pipeline("post-layout_assignment");
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -267,7 +283,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>();
+ fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
@@ -277,7 +294,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
- reduce_pipeline.AddInvariantChecker<HloVerifier>();
+ reduce_pipeline.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true, /*allow_mixed_precision=*/false);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -303,7 +321,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>();
+ pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
index 99e7580b82..104af48c82 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc
@@ -29,7 +29,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
-using PadForTensorCoresTest = HloVerifiedTestBase;
+class PadForTensorCoresTest : public HloVerifiedTestBase {
+ public:
+ PadForTensorCoresTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
ParseAndVerifyModule(R"(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index cca35316f0..15d1e269cc 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -27,13 +27,22 @@ namespace {
class GpuKernelTilingTest : public GpuCodegenTest {
protected:
- GpuKernelTilingTest() {
+ GpuKernelTilingTest() {}
+
+ // Most tests in this file want to skip layout assignment, but a few need it
+ // enabled.
+ HloModuleConfig ConfigWithLayoutAssignment() {
+ return GetModuleConfigForTest();
+ }
+
+ HloModuleConfig ConfigWithoutLayoutAssignment() {
+ HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
- config_.set_debug_options(debug_options);
// Disable layout_assignment to use the preassigned layouts.
- debug_options.add_xla_disable_hlo_passes("layout_assignment");
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
+ config.set_debug_options(debug_options);
+ return config;
}
- HloModuleConfig config_;
};
TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
@@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ //
+ // We must enable layout assignment in order for this test to work correctly.
+ // AlgebraicSimplifier removes copy1; it's added back by layout assignment,
+ // which respects the module's entry computation layout. But if we don't run
+ // layout assignment...well, nobody else adds the copy back.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
})";
- // Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ // Check that a call to llvm.nvvm.barrier0 is not generated. As in
+ // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment
+ // here.
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @copy
@@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) {
})";
// Check that a call to llvm.nvvm.barrier0 is generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
@@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest,
})";
// Check that a call to llvm.nvvm.barrier0 is not generated.
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+ auto hlo_module =
+ ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: define void @fusion
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
index 9622936306..0f2d5568ca 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
@@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) {
HloModuleConfig config;
auto debug_options = HloTestBase::GetDebugOptionsForTest();
debug_options.set_xla_gpu_max_kernel_unroll_factor(2);
+ // Disable layout assignment for this test. Layout assignment does not expect
+ // fusions to be present, and so it does the wrong thing.
+ debug_options.add_xla_disable_hlo_passes("layout-assignment");
config.set_debug_options(debug_options);
const char *const kMultiOutputFusionModule = R"(
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index c5f3906356..40183de96e 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
- HloVerifier verifier;
+ HloVerifier verifier(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false);
TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 648228b825..79e78ee2d0 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -29,6 +29,11 @@ namespace xla {
namespace {
class HloDomainTest : public HloVerifiedTestBase {
+ public:
+ HloDomainTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 139d70374f..7e85df53a3 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -52,7 +52,10 @@ static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
- HloEvaluatorTest() : use_bfloat16_(GetParam()) {
+ HloEvaluatorTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false),
+ use_bfloat16_(GetParam()) {
evaluator_ = absl::make_unique<HloEvaluator>();
}
@@ -1216,7 +1219,12 @@ TEST_P(HloEvaluatorTest,
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
-class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
+class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {
+ public:
+ HloEvaluatorPreciseReduceTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
// Tests that Reduce doesn't lose precision when adding many numbers (because
// it accumulates its result in a double).
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 66dd23e73f..f60c4eab42 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -123,29 +123,26 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-namespace {
-
-Status CheckIsTokenOperand(const HloInstruction* instruction,
- int64 operand_no) {
+Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no) {
const HloInstruction* token = instruction->operand(operand_no);
if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
return InternalError(
"Expected operand %lld to be token-shaped, actual shape is "
"%s:\n%s",
- operand_no, ShapeUtil::HumanString(token->shape()).c_str(),
+ operand_no, StringifyShape(token->shape()).c_str(),
instruction->ToString().c_str());
}
return Status::OK();
}
-Status CheckOperandAndParameter(const HloInstruction* instruction,
- int64 operand_number,
- const HloComputation* computation,
- int64 parameter_number) {
+Status ShapeVerifier::CheckOperandAndParameter(
+ const HloInstruction* instruction, int64 operand_number,
+ const HloComputation* computation, int64 parameter_number) {
const HloInstruction* operand = instruction->operand(operand_number);
const HloInstruction* parameter =
computation->parameter_instruction(parameter_number);
- if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) {
+ if (!ShapesSame(operand->shape(), parameter->shape())) {
return InternalError("Operand %s shape does not match parameter's %s in %s",
operand->ToString().c_str(),
parameter->ToString().c_str(),
@@ -154,8 +151,6 @@ Status CheckOperandAndParameter(const HloInstruction* instruction,
return Status::OK();
}
-} // namespace
-
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
@@ -172,13 +167,12 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
// Outfeed has a separate shape field for the value which is outfed to the
// host. The shape of the instruction itself is always a token.
- if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
- outfeed->operand(0)->shape())) {
+ if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
return InternalError(
- "Expected outfeed shape to be compatible with operand's shape %s, "
+ "Expected outfeed shape to be equal to operand's shape %s, "
"actual shape is %s:\n%s",
- ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
+ StringifyShape(outfeed->operand(0)->shape()).c_str(),
+ StringifyShape(outfeed->outfeed_shape()).c_str(),
outfeed->ToString().c_str());
}
return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
@@ -259,8 +253,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
return InternalError(
"Expected sort to have to have the same dimensions for the keys and "
"the values. Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(),
- ShapeUtil::HumanString(sort->operand(1)->shape()).c_str());
+ StringifyShape(sort->operand(0)->shape()).c_str(),
+ StringifyShape(sort->operand(1)->shape()).c_str());
}
return CheckVariadicShape(sort);
}
@@ -334,7 +328,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return Status::OK();
}
-Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
+ for (HloInstruction* fused_param : fusion->fused_parameters()) {
+ int64 param_no = fused_param->parameter_number();
+ if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
+ return InternalError(
+ "Shape mismatch between parameter number %lld and its operand in "
+ "%s.",
+ param_no, fusion->ToString().c_str());
+ }
+ }
+ return Status::OK();
+}
Status ShapeVerifier::HandleCall(HloInstruction* call) {
for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
@@ -416,12 +421,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
const Shape& conditional_shape =
xla_while->while_condition()->root_instruction()->shape();
- if (!ShapeUtil::Compatible(conditional_shape,
- ShapeUtil::MakeShape(PRED, {}))) {
+ if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
return InternalError(
"Conditional computation shape does not lead to a scalar predicate "
"shape: %s",
- ShapeUtil::HumanString(conditional_shape).c_str());
+ StringifyShape(conditional_shape).c_str());
}
// The shape of kWhile should match the shape of the body computation it
// calls.
@@ -599,52 +603,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
// Check if the output shape matches the expected shape.
- bool compatible;
+ //
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
- switch (instruction->opcode()) {
- case HloOpcode::kTupleSelect:
- // TupleSelect only defines the top-level buffer, which in this case is
- // the tuple, so we cannot allow mixed precision.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- case HloOpcode::kGetTupleElement:
- case HloOpcode::kTuple:
- // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
- // precision is disallowed.
- case HloOpcode::kConstant:
- case HloOpcode::kBitcast:
- case HloOpcode::kBitcastConvert:
- case HloOpcode::kCall:
- case HloOpcode::kConditional:
- case HloOpcode::kConvert:
- case HloOpcode::kCustomCall:
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
- case HloOpcode::kParameter:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
- case HloOpcode::kSend:
- case HloOpcode::kSendDone:
- case HloOpcode::kWhile:
- // The above opcodes should match the expected shapes exactly.
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- break;
- default:
- if (allow_mixed_precision_) {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
- instruction->shape(), inferred_shape);
- } else {
- compatible =
- ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- }
- }
- if (!compatible) {
+ bool equal = [&] {
+ switch (instruction->opcode()) {
+ // The opcodes below can't have implicit layout conversions, nor can they
+ // implicitly transform f32 -> bf16. Fundamentally these are either
+ // reinterpreting existing data (e.g. kBitcast) or shuffling data around
+ // without modifying it (e.g. kGetTupleElement, kTupleSelect).
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCall:
+ case HloOpcode::kConditional:
+ case HloOpcode::kConstant:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
+ case HloOpcode::kWhile:
+ return ShapesSame(instruction->shape(), inferred_shape);
+
+ // We allow arbitrary layout and f32->bf16 transformations on all other
+ // instructions, although this may be made more strict pending discussion
+ // in b/112709536.
+ default:
+ if (allow_mixed_precision_) {
+ return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
+ inferred_shape);
+ } else {
+ return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
+ }
+ }
+ }();
+ if (!equal) {
return InternalError(
- "Expected instruction to have shape compatible with %s, actual "
+ "Expected instruction to have shape equal to %s, actual "
"shape is %s:\n%s",
- ShapeUtil::HumanString(inferred_shape).c_str(),
- ShapeUtil::HumanString(instruction->shape()).c_str(),
+ StringifyShape(inferred_shape).c_str(),
+ StringifyShape(instruction->shape()).c_str(),
instruction->ToString().c_str());
}
return Status::OK();
@@ -828,7 +831,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
// Fused parameter instructions must be numbered contiguously and match up
- // (shapes compatible) with their respective operand.
+ // (shapes equal) with their respective operand.
CHECK_EQ(fusion->operands().size(), fused_parameters.size());
std::vector<bool> parameter_numbers(fused_parameters.size(), false);
for (auto fused_param : fused_parameters) {
@@ -849,13 +852,6 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
param_no, fusion->ToString().c_str());
}
parameter_numbers[param_no] = true;
- if (!ShapeUtil::Compatible(fused_param->shape(),
- fusion->operand(param_no)->shape())) {
- return InternalError(
- "Shape mismatch between parameter number %lld and its operand in "
- "%s.",
- param_no, fusion->ToString().c_str());
- }
}
// Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
@@ -917,7 +913,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
- "Found non-compatible shapes for instruction %s.\n"
+ "Found different shapes for instruction %s.\n"
"output: %s\noperand: %s\n",
HloOpcodeString(instruction->opcode()).c_str(),
ShapeUtil::HumanString(out_shape).c_str(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 83b77c84eb..b6093d667c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -28,9 +28,9 @@ namespace xla {
// TODO(b/26024837): Check output shape for all instruction types.
class ShapeVerifier : public DfsHloVisitor {
public:
- explicit ShapeVerifier() : allow_mixed_precision_(false) {}
- explicit ShapeVerifier(bool allow_mixed_precision)
- : allow_mixed_precision_(allow_mixed_precision) {}
+ explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_(allow_mixed_precision) {}
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
@@ -106,13 +106,42 @@ class ShapeVerifier : public DfsHloVisitor {
Status CheckVariadicShape(const HloInstruction* instruction);
private:
- // Return true if the shapes of the two operands have the same element type,
- // and the result shape either has the same element type as the operand
- // shapes or mixed precision is allowed and the result shape and the operand
- // shapes have floating point element types.
+ // Helpers that switch on layout_sensitive_.
+ bool ShapesSame(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::Equal(a, b)
+ : ShapeUtil::Compatible(a, b);
+ }
+ bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
+ return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
+ : ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
+ }
+ string StringifyShape(const Shape& s) {
+ return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
+ : ShapeUtil::HumanString(s);
+ }
+
+ // Checks that the given operand of the given instruction is of type TOKEN.
+ Status CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no);
+
+ // Checks that the shape of the given operand of the given instruction
+ // matches the shape of the given parameter of the given computation.
+ Status CheckOperandAndParameter(const HloInstruction* instruction,
+ int64 operand_number,
+ const HloComputation* computation,
+ int64 parameter_number);
+
+ // Returns true if the shapes of the two operands have the same element type,
+ // and the result shape either has the same element type as the operand shapes
+ // or mixed precision is allowed and the result shape and the operand shapes
+ // have floating point element types.
bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
const Shape& result_shape);
+ // If the verifier is layout-sensitive, shapes must be equal to what's
+ // expected. Otherwise, the shapes must simply be compatible.
+ bool layout_sensitive_;
+
// Whether the inputs and output of an instruction can contain both F32s and
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
// this flag.
@@ -125,14 +154,10 @@ class HloVerifier : public HloPassInterface {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- // Uses standard shape inference.
- explicit HloVerifier()
- : shape_verifier_factory_(
- [] { return absl::make_unique<ShapeVerifier>(false); }) {}
-
- explicit HloVerifier(bool allow_mixed_precision)
- : shape_verifier_factory_([allow_mixed_precision] {
- return absl::make_unique<ShapeVerifier>(allow_mixed_precision);
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
+ return absl::make_unique<ShapeVerifier>(layout_sensitive,
+ allow_mixed_precision);
}) {}
// Uses custom shape verification.
@@ -142,8 +167,7 @@ class HloVerifier : public HloPassInterface {
~HloVerifier() override = default;
absl::string_view name() const override { return "verifier"; }
- // Note: always returns false (no instructions are ever modified by this
- // pass).
+ // Never returns true; no instructions are ever modified by this pass.
StatusOr<bool> Run(HloModule* module) override;
private:
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index d764964f3c..70b741353d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -37,13 +37,15 @@ using ::testing::HasSubstr;
class HloVerifierTest : public HloTestBase {
public:
HloVerifierTest()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/false) {}
};
class HloVerifierTestAllowMixedPrecision : public HloTestBase {
public:
HloVerifierTestAllowMixedPrecision()
- : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {}
+ : HloTestBase(/*verifier_layout_sensitive=*/false,
+ /*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
TEST_F(HloVerifierTest, NullInstructionParent) {
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index f85d31d522..df88587492 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -26,6 +26,11 @@ namespace xla {
namespace {
class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase {
+ public:
+ ImplicitBroadcastRemoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
ImplicitBroadcastRemover remover_;
};
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 2d03aebc1a..c34c32f7d3 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -22,6 +22,11 @@ limitations under the License.
namespace xla {
namespace {
class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
+ public:
+ IndexedArrayAnalysisTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
void AssertArrayForRootExpressionIs(const string& hlo_text,
const string& root_expression) {
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index fab4e797f3..a395dd5333 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -34,7 +34,12 @@ namespace {
namespace op = xla::testing::opcode_matchers;
-using ReshapeMoverTest = HloVerifiedTestBase;
+class ReshapeMoverTest : public HloVerifiedTestBase {
+ public:
+ ReshapeMoverTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+};
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 32e69c335b..e14014b961 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers;
class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
public:
+ WhileLoopInvariantCodeMotionTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
// Makes a computation which has one parameter, of the given shape, and always
// returns PRED[]{true}. This is useful as a dummy loop condition.
HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 1c892ba179..cfe4104f6d 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -28,6 +28,11 @@ namespace {
namespace op = xla::testing::opcode_matchers;
class WhileLoopSimplifierTest : public HloVerifiedTestBase {
+ public:
+ WhileLoopSimplifierTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/false) {}
+
protected:
// Makes an HloModule that contains a loop with `num_iters` iteration.
void MakeModuleWithSimpleLoop(int num_iters);
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 693454bd80..93ea144438 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -86,16 +86,20 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
-HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier)
+HloTestBase::HloTestBase(bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
+ verifier_layout_sensitive,
allow_mixed_precision_in_hlo_verifier) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform,
+ bool verifier_layout_sensitive,
bool allow_mixed_precision_in_hlo_verifier)
: test_runner_(test_platform), reference_runner_(reference_platform) {
- hlo_verifier_ =
- absl::make_unique<HloVerifier>(allow_mixed_precision_in_hlo_verifier);
+ hlo_verifier_ = absl::make_unique<HloVerifier>(
+ /*layout_sensitive=*/verifier_layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
}
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index ce39dd78d4..06bcc39741 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -85,12 +85,14 @@ class HloTestBase : public ::testing::Test {
// automatically finds another supported backend as the test backend. If the
// interpreter is the only supported backend, it will be both the test backend
// and the reference backend.
- HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true);
+ HloTestBase(bool verifier_layout_sensitive = false,
+ bool allow_mixed_precision_in_hlo_verifier = true);
// If your test doesn't use interpreter as the reference backend, you can use
// this constructor. Note that your test target is responsible for linking in
// both needed backends.
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
+ bool verifier_layout_sensitive = false,
bool allow_mixed_precision_in_hlo_verifier = true);
~HloTestBase() override {}
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index dd130557b3..8f86c528d0 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -25,8 +25,11 @@ limitations under the License.
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase()
- : shape_verifier_(absl::make_unique<ShapeVerifier>()) {}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
HloVerifiedTestBase::~HloVerifiedTestBase() {
// We can't call the ASSERT or EXPECT test macros in destructors, so we
@@ -51,8 +54,7 @@ void HloVerifiedTestBase::TearDown() {
}
void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- HloVerifier verifier(/*allow_mixed_precision=*/true);
- xla::StatusOr<bool> mutated = verifier.Run(module);
+ xla::StatusOr<bool> mutated = verifier().Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index a2f1185d63..cc6967feed 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -29,7 +29,8 @@ namespace xla {
// performs verification on that module on tear-down.
class HloVerifiedTestBase : public HloTestBase {
protected:
- HloVerifiedTestBase();
+ explicit HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision);
~HloVerifiedTestBase() override;
// Constructs a default shape verifier.
@@ -47,29 +48,25 @@ class HloVerifiedTestBase : public HloTestBase {
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
- // Sets the shape-size function used during hlo verification. If this isn't
- // called, a default ShapeVerifier is used instead.
- void SetShapeVerifier(std::unique_ptr<ShapeVerifier> shape_verifier) {
- shape_verifier_ = std::move(shape_verifier);
- }
-
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
HloModule* CreateNewModule(const string& name = TestName());
+ private:
+ void VerifyModule(HloModule* module);
+
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
- private:
+ //
// Lazily populated. Access via module().
std::unique_ptr<HloModule> module_;
// Populated by calls to CreateNewModule.
std::vector<std::unique_ptr<HloModule>> modules_;
- std::unique_ptr<ShapeVerifier> shape_verifier_;
+
bool tear_down_called_ = false;
- static void VerifyModule(HloModule* module);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index adc0956164..16b77e965d 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -53,12 +53,22 @@ class MultiOutputFusionTest : public HloTestBase {
protected:
MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
+ // Layout assignment assumes that there are no fusions in the input graph.
+ // Since the purpose of this test is to send pre-fused graphs to XLA, we have
+ // to do layout assignment ourselves.
+ DebugOptions GetDebugOptionsForTest() override {
+ auto opts = HloTestBase::GetDebugOptionsForTest();
+ opts.add_xla_disable_hlo_passes("layout-assignment");
+ return opts;
+ }
+
void RunTest2D(bool manual_fusion, int64 size) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {});
- const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size});
+ const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
+ const Shape elem_shape2 =
+ ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0});
auto const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
@@ -101,10 +111,10 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal arg1(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
arg1.PopulateWithValue<float>(2.5f);
- Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
+ Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
auto actual =
ExecuteAndTransfer(std::move(hlo_module),
@@ -116,8 +126,10 @@ class MultiOutputFusionTest : public HloTestBase {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size});
- const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size});
+ const Shape elem_shape_F32 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size});
+ const Shape elem_shape_U8 =
+ ShapeUtil::MakeShapeWithDescendingLayout(F64, {size});
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, elem_shape_F32, "0"));
auto param1 = builder.AddInstruction(
@@ -137,12 +149,13 @@ class MultiOutputFusionTest : public HloTestBase {
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(F32, {size, 1}), add));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
- ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums));
+ ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
+ dot_dnums));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
@@ -162,9 +175,9 @@ class MultiOutputFusionTest : public HloTestBase {
nullptr);
}
- Literal input0(ShapeUtil::MakeShape(F32, {size}));
+ Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}));
input0.PopulateWithValue(2.5f);
- Literal input1(ShapeUtil::MakeShape(F64, {size}));
+ Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
input1.PopulateWithValue(1.);
Literal expect =
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2f1d97b25d..21c58e075e 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -408,8 +408,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
return std::move(arguments);
}
-Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) {
- return HloVerifier(allow_mixed_precision).Run(module).status();
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision) {
+ return HloVerifier(/*layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision=*/allow_mixed_precision)
+ .Run(module)
+ .status();
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 1aca1d8ef7..277d53d423 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -95,8 +95,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
// Check that a given module satisfies various constraints before trying to
// execute it.
-Status VerifyHloModule(HloModule* const module,
- bool allow_mixed_precision = false);
+Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
+ bool allow_mixed_precision);
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index 870da6efa8..c7eb9e2dbe 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -66,7 +66,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -83,7 +86,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
"param"));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(
status.error_message(),
@@ -100,7 +106,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
module->AddEntryComputation(builder.Build());
- Status status = HloVerifier().Run(module.get()).status();
+ Status status =
+ HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(status.error_message(),
::testing::HasSubstr(