diff options
author | 2017-12-07 17:46:37 -0800 | |
---|---|---|
committer | 2017-12-07 17:50:07 -0800 | |
commit | 0e9cc7f3113ade82436729bd541f6b501d023ac0 (patch) | |
tree | 797d2a0867bba92008d93d9f6cc416bb3b9f8e57 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | |
parent | 1667d4dcd2c7c33a3bcade62014931a1f8d9a2e0 (diff) |
[XLA] Implement Conditional in XLA service, client ComputationBuilder, and CPU backend.
PiperOrigin-RevId: 178322445
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 322 |
1 files changed, 322 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f08f0b1d68..e714b2567f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace { +using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; // Test is parameterized on a bool which is whether the dataflow analysis is @@ -77,11 +78,23 @@ class HloDataflowAnalysisTest : public HloTestBase, analysis_->GetValueDefinedAt(b), *analysis_); } + std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation( + HloOpcode opcode) { + HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode)); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, opcode, param0)); + return builder.Build(); + } + std::unique_ptr<HloModule> module_; std::unique_ptr<HloDataflowAnalysis> analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { @@ -1528,6 +1541,315 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); } +TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { + // Test conditional with identity computations in both true and false cases. + // + // true_computation(F32[] %true_param): + // return %true_param + // + // false_computation(F32[] %false_param): + // return %false_param + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // return Conditional(%pred, %constant1, true_computation, + // %constant2, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "true_param")); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "false_param")); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, constant1, true_computation, constant2, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + ElementsAre(HloUse{conditional, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + ElementsAre(HloUse{conditional, 2, {}})); + + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); +} + +TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { + // Test conditional with true and false computations taking a tuple operand. + // + // true_computation((F32[], F32[]) %true_param): + // %true_x = GetTupleElement(%true_param, 0) + // %true_y = GetTupleElement(%true_param, 1) + // return Add(%true_x, %true_y) + // + // false_computation((F32[], F32[]) %false_param): + // %false_x = GetTupleElement(%false_param, 0) + // %false_y = GetTupleElement(%false_param, 1) + // return Subtract(%false_x, %false_y) + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // %tuple_operand = Tuple(%constant1, %constant2) + // return Conditional(%pred, %tuple_operand, true_computation, + // %tuple_operand, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "true_param")); + auto true_x = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0)); + auto true_y = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1)); + auto add = true_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, true_x, true_y)); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "false_param")); + auto false_x = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0)); + auto false_y = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1)); + auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kSubtract, false_x, false_y)); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_y), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_y), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {0}}, + HloUse{conditional, 2, {0}}, + HloUse{add, 0, {}}, HloUse{sub, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {1}}, + HloUse{conditional, 2, {1}}, + HloUse{add, 1, {}}, HloUse{sub, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(), + UnorderedElementsAre( + HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}}, + HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, + HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); + + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); +} + +TEST_P(HloDataflowAnalysisTest, NestedConditionals) { + // computation1(F32[] %param1): + // %ceil = Ceil(%param1) + // return %ceil + // + // computation2(F32[] %param2): + // %floor = Floor(%param2) + // return %floor + // + // computation3(F32[] %param3): + // %negate = Negate(%param3) + // return %negate + // + // inner_conditional((PRED, F32[], F32[]) %param_cond): + // %pred_cond = GetTupleElement(%param_cond, 0) + // %true_operand_cond = GetTupleElement(%param_cond, 1) + // %false_opearnd_cond = GetTupleElement(%param_cond, 2) + // return Conditional(%pred_cond, %true_operand_cond, computation1, + // %false_operand_cond, computation2) + // + // entry: + // %pred1 = Constant(true) + // %pred2 = Constant(false) + // %constant1 = Constant(1.1); + // %constant2 = Constant(2.2); + // %constant3 = Constant(3.3); + // return Conditional(%pred1, (%pred2, %constant1, %constant2), + // inner_conditional, %constant3, computation3) + + auto computation1 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kCeil)); + auto computation2 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kFloor)); + auto computation3 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kNegate)); + + // Build inner_conditional computation. + const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {}); + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {scalar_bool_shape, scalar_shape_, scalar_shape_}); + auto inner_builder = + HloComputation::Builder(TestName() + "_inner_conditional"); + auto param_cond = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond")); + auto pred_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0)); + auto true_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1)); + auto false_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2)); + auto inner_conditional = + inner_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred_cond, true_operand_cond, computation1, + false_operand_cond, computation2)); + auto inner_conditional_computation = + module_->AddEmbeddedComputation(inner_builder.Build()); + + // Build entry computation. + auto builder = HloComputation::Builder(TestName()); + auto pred1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))); + auto pred2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(2.2f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(3.3f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({pred2, constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred1, tuple_operand, inner_conditional_computation, + constant3, computation3)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction())); + + auto computation1_param = computation1->parameter_instruction(0); + auto computation2_param = computation2->parameter_instruction(0); + auto computation3_param = computation3->parameter_instruction(0); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param), + analysis.GetValueDefinedAt(constant3)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond)); + EXPECT_EQ(analysis.GetUniqueValueAt(param_cond), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond), + analysis.GetValueDefinedAt(pred2)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); |