aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/buffer_assignment_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-07 17:46:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 17:50:07 -0800
commit0e9cc7f3113ade82436729bd541f6b501d023ac0 (patch)
tree797d2a0867bba92008d93d9f6cc416bb3b9f8e57 /tensorflow/compiler/xla/service/buffer_assignment_test.cc
parent1667d4dcd2c7c33a3bcade62014931a1f8d9a2e0 (diff)
[XLA] Implement Conditional in XLA service, client ComputationBuilder, and CPU backend.
PiperOrigin-RevId: 178322445
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc59
1 files changed, 59 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 09681b34e7..6fc9d783f1 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -166,6 +166,15 @@ class BufferAssignmentTest : public HloTestBase {
return builder.Build();
}
+ std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
+ HloOpcode opcode, const string& name) {
+ auto builder = HloComputation::Builder(name);
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
+ builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
+ return builder.Build();
+ }
+
// Verifies that the given instruction hlo has a valid input buffer assigned,
// i.e., the parameter number matches the op's.
const BufferAllocation& GetAssignedInputAllocation(
@@ -740,6 +749,56 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
<< " instructions; total buffer size " << size0 + sizec + sizeb;
}
+TEST_F(BufferAssignmentTest, ExampleConditional) {
+ auto module = CreateNewModule();
+ auto true_computation = module->AddEmbeddedComputation(
+ BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
+ auto false_computation = module->AddEmbeddedComputation(
+ BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));
+
+ auto builder = HloComputation::Builder(TestName());
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
+ auto const2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(12.4f)));
+ auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
+ r0f32_, pred, const1, true_computation, const2, false_computation));
+ module->AddEntryComputation(builder.Build());
+
+ const std::vector<const HloInstruction*> conditional_instrs =
+ GetInstructions(conditional);
+ const std::vector<const HloInstruction*> true_instrs =
+ GetInstructions(true_computation->root_instruction());
+ const std::vector<const HloInstruction*> false_instrs =
+ GetInstructions(false_computation->root_instruction());
+ EXPECT_EQ(4, conditional_instrs.size());
+ EXPECT_EQ(2, true_instrs.size());
+ EXPECT_EQ(2, false_instrs.size());
+
+ auto buffers = RunBufferAssignment(module.get());
+ ValidateBuffers(conditional_instrs, *buffers);
+ ValidateBuffers(true_instrs, *buffers);
+ ValidateBuffers(false_instrs, *buffers);
+
+ EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
+ << "Should be reuse between conditional and true computation.";
+ EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
+ << "Should be reuse between conditional and false computation.";
+ EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
+ << "Should be reuse between true and false computations.";
+
+ const BufferAllocation& conditional_buffer =
+ GetTopLevelAllocation(*buffers, conditional);
+ const BufferAllocation& true_buffer =
+ GetTopLevelAllocation(*buffers, true_computation->root_instruction());
+ const BufferAllocation& false_buffer =
+ GetTopLevelAllocation(*buffers, false_computation->root_instruction());
+ EXPECT_EQ(conditional_buffer.size(), true_buffer.size());
+ EXPECT_EQ(conditional_buffer.size(), false_buffer.size());
+}
+
TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
// param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
auto builder = HloComputation::Builder(TestName());