diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-07 17:46:37 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-07 17:50:07 -0800 |
commit | 0e9cc7f3113ade82436729bd541f6b501d023ac0 (patch) | |
tree | 797d2a0867bba92008d93d9f6cc416bb3b9f8e57 /tensorflow/compiler/xla/service/buffer_assignment_test.cc | |
parent | 1667d4dcd2c7c33a3bcade62014931a1f8d9a2e0 (diff) |
[XLA] Implement Conditional in XLA service, client ComputationBuilder, and CPU backend.
PiperOrigin-RevId: 178322445
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment_test.cc | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 09681b34e7..6fc9d783f1 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -166,6 +166,15 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation( + HloOpcode opcode, const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param)); + return builder.Build(); + } + // Verifies that the given instruction hlo has a valid input buffer assigned, // i.e., the parameter number matches the op's. const BufferAllocation& GetAssignedInputAllocation( @@ -740,6 +749,56 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { << " instructions; total buffer size " << size0 + sizec + sizeb; } +TEST_F(BufferAssignmentTest, ExampleConditional) { + auto module = CreateNewModule(); + auto true_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); + auto false_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor")); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f))); + auto const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(12.4f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + r0f32_, pred, const1, true_computation, const2, false_computation)); + module->AddEntryComputation(builder.Build()); + + const std::vector<const HloInstruction*> conditional_instrs = + GetInstructions(conditional); + const std::vector<const HloInstruction*> true_instrs = + GetInstructions(true_computation->root_instruction()); + const std::vector<const HloInstruction*> false_instrs = + GetInstructions(false_computation->root_instruction()); + EXPECT_EQ(4, conditional_instrs.size()); + EXPECT_EQ(2, true_instrs.size()); + EXPECT_EQ(2, false_instrs.size()); + + auto buffers = RunBufferAssignment(module.get()); + ValidateBuffers(conditional_instrs, *buffers); + ValidateBuffers(true_instrs, *buffers); + ValidateBuffers(false_instrs, *buffers); + + EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers)) + << "Should be reuse between conditional and true computation."; + EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers)) + << "Should be reuse between conditional and false computation."; + EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers)) + << "Should be reuse between true and false computations."; + + const BufferAllocation& conditional_buffer = + GetTopLevelAllocation(*buffers, conditional); + const BufferAllocation& true_buffer = + GetTopLevelAllocation(*buffers, true_computation->root_instruction()); + const BufferAllocation& false_buffer = + GetTopLevelAllocation(*buffers, false_computation->root_instruction()); + EXPECT_EQ(conditional_buffer.size(), true_buffer.size()); + EXPECT_EQ(conditional_buffer.size(), false_buffer.size()); +} + TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); |