aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-07 17:46:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 17:50:07 -0800
commit0e9cc7f3113ade82436729bd541f6b501d023ac0 (patch)
tree797d2a0867bba92008d93d9f6cc416bb3b9f8e57 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parent1667d4dcd2c7c33a3bcade62014931a1f8d9a2e0 (diff)
[XLA] Implement Conditional in XLA service, client ComputationBuilder, and CPU backend.
PiperOrigin-RevId: 178322445
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc322
1 files changed, 322 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index f08f0b1d68..e714b2567f 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
namespace xla {
namespace {
+using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;
// Test is parameterized on a bool which is whether the dataflow analysis is
@@ -77,11 +78,23 @@ class HloDataflowAnalysisTest : public HloTestBase,
analysis_->GetValueDefinedAt(b), *analysis_);
}
+ std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
+ HloOpcode opcode) {
+ HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
+ return builder.Build();
+ }
+
std::unique_ptr<HloModule> module_;
std::unique_ptr<HloDataflowAnalysis> analysis_;
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
+ const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
};
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
@@ -1528,6 +1541,315 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
+TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
+ // Test conditional with identity computations in both true and false cases.
+ //
+ // true_computation(F32[] %true_param):
+ // return %true_param
+ //
+ // false_computation(F32[] %false_param):
+ // return %false_param
+ //
+ // entry:
+ // %pred = Constant(true)
+ // %constant1 = Constant(56.0)
+ // %constant2 = Constant(12.0)
+ // return Conditional(%pred, %constant1, true_computation,
+ // %constant2, false_computation)
+
+ auto true_builder = HloComputation::Builder(TestName() + "_true");
+ auto true_param = true_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
+ HloComputation* true_computation =
+ module_->AddEmbeddedComputation(true_builder.Build());
+
+ auto false_builder = HloComputation::Builder(TestName() + "_false");
+ auto false_param = false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
+ HloComputation* false_computation =
+ module_->AddEmbeddedComputation(false_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
+ scalar_shape_, pred, constant1, true_computation, constant2,
+ false_computation));
+ module_->AddEntryComputation(builder.Build());
+
+ const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
+
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
+
+ EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
+ analysis.GetValueDefinedAt(constant2));
+
+ EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
+ ElementsAre(HloUse{conditional, 0, {}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
+ ElementsAre(HloUse{conditional, 1, {}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
+ ElementsAre(HloUse{conditional, 2, {}}));
+
+ EXPECT_EQ(analysis.values().size(), 3);
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+ EXPECT_THAT(HloValuesAt(conditional),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
+ analysis.GetValueDefinedAt(constant2)));
+}
+
+TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
+ // Test conditional with true and false computations taking a tuple operand.
+ //
+ // true_computation((F32[], F32[]) %true_param):
+ // %true_x = GetTupleElement(%true_param, 0)
+ // %true_y = GetTupleElement(%true_param, 1)
+ // return Add(%true_x, %true_y)
+ //
+ // false_computation((F32[], F32[]) %false_param):
+ // %false_x = GetTupleElement(%false_param, 0)
+ // %false_y = GetTupleElement(%false_param, 1)
+ // return Subtract(%false_x, %false_y)
+ //
+ // entry:
+ // %pred = Constant(true)
+ // %constant1 = Constant(56.0)
+ // %constant2 = Constant(12.0)
+ // %tuple_operand = Tuple(%constant1, %constant2)
+ // return Conditional(%pred, %tuple_operand, true_computation,
+ // %tuple_operand, false_computation)
+
+ auto true_builder = HloComputation::Builder(TestName() + "_true");
+ auto true_param = true_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
+ auto true_x = true_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
+ auto true_y = true_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
+ auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape_, HloOpcode::kAdd, true_x, true_y));
+ HloComputation* true_computation =
+ module_->AddEmbeddedComputation(true_builder.Build());
+
+ auto false_builder = HloComputation::Builder(TestName() + "_false");
+ auto false_param = false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
+ auto false_x = false_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
+ auto false_y = false_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
+ auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape_, HloOpcode::kSubtract, false_x, false_y));
+ HloComputation* false_computation =
+ module_->AddEmbeddedComputation(false_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ auto tuple_operand = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
+ scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
+ false_computation));
+ module_->AddEntryComputation(builder.Build());
+
+ const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
+
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));
+
+ EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
+ analysis.GetValueDefinedAt(tuple_operand));
+ EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
+ analysis.GetValueDefinedAt(tuple_operand));
+ EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
+ analysis.GetValueDefinedAt(constant2));
+ EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
+ analysis.GetValueDefinedAt(constant2));
+
+ EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
+ ElementsAre(HloUse{conditional, 0, {}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
+ UnorderedElementsAre(HloUse{conditional, 1, {0}},
+ HloUse{conditional, 2, {0}},
+ HloUse{add, 0, {}}, HloUse{sub, 0, {}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
+ UnorderedElementsAre(HloUse{conditional, 1, {1}},
+ HloUse{conditional, 2, {1}},
+ HloUse{add, 1, {}}, HloUse{sub, 1, {}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(),
+ UnorderedElementsAre(
+ HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}},
+ HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
+ HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));
+
+ EXPECT_EQ(analysis.values().size(), 6);
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+ EXPECT_THAT(HloValuesAt(conditional),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(add),
+ analysis.GetValueDefinedAt(sub)));
+}
+
+TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
+ // computation1(F32[] %param1):
+ // %ceil = Ceil(%param1)
+ // return %ceil
+ //
+ // computation2(F32[] %param2):
+ // %floor = Floor(%param2)
+ // return %floor
+ //
+ // computation3(F32[] %param3):
+ // %negate = Negate(%param3)
+ // return %negate
+ //
+ // inner_conditional((PRED, F32[], F32[]) %param_cond):
+ // %pred_cond = GetTupleElement(%param_cond, 0)
+ // %true_operand_cond = GetTupleElement(%param_cond, 1)
+ // %false_opearnd_cond = GetTupleElement(%param_cond, 2)
+ // return Conditional(%pred_cond, %true_operand_cond, computation1,
+ // %false_operand_cond, computation2)
+ //
+ // entry:
+ // %pred1 = Constant(true)
+ // %pred2 = Constant(false)
+ // %constant1 = Constant(1.1);
+ // %constant2 = Constant(2.2);
+ // %constant3 = Constant(3.3);
+ // return Conditional(%pred1, (%pred2, %constant1, %constant2),
+ // inner_conditional, %constant3, computation3)
+
+ auto computation1 = module_->AddEmbeddedComputation(
+ CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
+ auto computation2 = module_->AddEmbeddedComputation(
+ CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
+ auto computation3 = module_->AddEmbeddedComputation(
+ CreateR0F32UnaryOpComputation(HloOpcode::kNegate));
+
+ // Build inner_conditional computation.
+ const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
+ const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
+ {scalar_bool_shape, scalar_shape_, scalar_shape_});
+ auto inner_builder =
+ HloComputation::Builder(TestName() + "_inner_conditional");
+ auto param_cond = inner_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
+ auto pred_cond = inner_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
+ auto true_operand_cond = inner_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
+ auto false_operand_cond = inner_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
+ auto inner_conditional =
+ inner_builder.AddInstruction(HloInstruction::CreateConditional(
+ scalar_shape_, pred_cond, true_operand_cond, computation1,
+ false_operand_cond, computation2));
+ auto inner_conditional_computation =
+ module_->AddEmbeddedComputation(inner_builder.Build());
+
+ // Build entry computation.
+ auto builder = HloComputation::Builder(TestName());
+ auto pred1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ auto pred2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.2f)));
+ auto constant3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(3.3f)));
+ auto tuple_operand = builder.AddInstruction(
+ HloInstruction::CreateTuple({pred2, constant1, constant2}));
+ auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
+ scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
+ constant3, computation3));
+ module_->AddEntryComputation(builder.Build());
+
+ const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));
+
+ auto computation1_param = computation1->parameter_instruction(0);
+ auto computation2_param = computation2->parameter_instruction(0);
+ auto computation3_param = computation3->parameter_instruction(0);
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
+ EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
+ analysis.GetValueDefinedAt(constant2));
+ EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
+ analysis.GetValueDefinedAt(constant3));
+
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
+ EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
+ analysis.GetValueDefinedAt(tuple_operand));
+ EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
+ analysis.GetValueDefinedAt(pred2));
+ EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
+ analysis.GetValueDefinedAt(constant2));
+
+ EXPECT_EQ(analysis.values().size(), 9);
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
+ EXPECT_THAT(
+ HloValuesAt(inner_conditional),
+ UnorderedElementsAre(
+ analysis.GetValueDefinedAt(computation1->root_instruction()),
+ analysis.GetValueDefinedAt(computation2->root_instruction())));
+ EXPECT_THAT(
+ HloValuesAt(conditional),
+ UnorderedElementsAre(
+ analysis.GetValueDefinedAt(computation1->root_instruction()),
+ analysis.GetValueDefinedAt(computation2->root_instruction()),
+ analysis.GetValueDefinedAt(computation3->root_instruction())));
+}
+
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));