diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cse_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_cse_test.cc | 97 |
1 files changed, 66 insertions, 31 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 16db374566..76b9c66651 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -53,9 +53,9 @@ TEST_F(HloCseTest, CombineTwoConstants) { // Test that two identical constants are commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(42.0f, constant->literal().Get<float>({})); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR0<float>(84.0); + auto expected = LiteralUtil::CreateR0<float>(84.0); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -81,10 +81,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { // the pass is not layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -113,10 +113,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { // if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -144,20 +144,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { auto builder = HloComputation::Builder(TestName()); std::vector<HloInstruction*> constants; constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<uint64>(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64>(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<int64>(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<double>(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)))); // Duplicate the float constant to verify something happens. constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)))); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); for (int64 i = 0; i < constants.size(); ++i) { @@ -188,13 +188,13 @@ TEST_F(HloCseTest, NonscalarConstants) { // Test that identical nonscalar constants are merged. auto builder = HloComputation::Builder(TestName()); auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); // Create a constant which has the same shape but a different value. auto uncommon_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}))); + LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}))); // Tie the constants together with a tuple. This makes it easier to refer to // the constant instructions via their use. @@ -223,7 +223,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test that three identical instructions are commoned. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -253,7 +253,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { // commoned if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -284,7 +284,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { // the pass is layout insensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -362,7 +362,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { // The *1 instructions should be merged with the *2 instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNegate, constant)); @@ -400,9 +400,9 @@ TEST_F(HloCseTest, DoNotCombineRng) { // Test that two RNG ops are not commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); @@ -442,9 +442,9 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName() + "_rng_fun"); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); auto rng = builder.AddInstruction(HloInstruction::CreateRng( scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); auto param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -459,7 +459,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1<float>({5.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f}))); auto rng1 = builder.AddInstruction( HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); auto rng2 = builder.AddInstruction( @@ -521,9 +521,9 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { // in this case) are not collapsed. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42))); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -536,5 +536,40 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); } +TEST_F(HloCseTest, Domain) { + auto module = ParseHloString(R"( +HloModule module +ENTRY %entry { + %param = f32[] parameter(0), sharding={maximal device=0} + %domain.0 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.1 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.2 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}} + %negate.0 = f32[] negate(%domain.0) + %negate.1 = f32[] negate(%domain.1) + %negate.2 = f32[] negate(%domain.2) + %domain.3 = f32[] domain(%negate.0), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.4 = f32[] domain(%negate.1), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.5 = f32[] domain(%negate.2), + domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} + %add = f32[] add(%domain.3, %domain.4) + ROOT %sub = f32[] subtract(%add, %domain.5) +})") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + LOG(INFO) << "AAAAA " << module->ToString(); + const HloInstruction* sub = module->entry_computation()->root_instruction(); + const HloInstruction* add = sub->operand(0); + EXPECT_EQ(add->operand(0), add->operand(1)); + EXPECT_NE(add->operand(0), sub->operand(1)); + EXPECT_NE(add->operand(1), sub->operand(1)); +} + } // namespace } // namespace xla |