diff options
author | 2018-06-07 18:42:30 -0700 | |
---|---|---|
committer | 2018-06-07 18:45:05 -0700 | |
commit | a9ddfe50eee83b2f18293241ab96f0a1e2b4b05b (patch) | |
tree | 849afa106cb882d918cfa0a633134bb89f7f014f /tensorflow/compiler/xla/service/hlo_instruction_test.cc | |
parent | 2f41346cbc0c8ecb915983a1f8711fd0d0ccc50e (diff) |
[DataFlowAnalysis] Be less conservative on loop fusion nodes when reusing buffer.
- Previously, we say we cannot reuse operand buffer for a loop fusion node if any of the fusion's inputs is a broadcast or reshape. That's too conservative since in theory we can still reuse the operand's buffer if all the users of that particular operand are elementwise. This CL implements that.
- Also fixed a bug in previous code where a dynamic update fusion node that ends with convert (added for bf16) is not caught by the if condition currectly.
PiperOrigin-RevId: 199731488
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction_test.cc | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 313033ddad..76349c4099 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -980,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) { } } +TEST_F(HloInstructionTest, MapIsElementwise) { + auto module = CreateNewModule(); + const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); + HloComputation::Builder builder(TestName()); + HloComputation::Builder map_builder("id"); + map_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {x}, map_computation)); + module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(map->IsElementwise()); +} + TEST_F(HloInstructionTest, PartiallyElementwise) { const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); |