aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction_test.cc
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-06-07 18:42:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 18:45:05 -0700
commita9ddfe50eee83b2f18293241ab96f0a1e2b4b05b (patch)
tree849afa106cb882d918cfa0a633134bb89f7f014f /tensorflow/compiler/xla/service/hlo_instruction_test.cc
parent2f41346cbc0c8ecb915983a1f8711fd0d0ccc50e (diff)
[DataFlowAnalysis] Be less conservative on loop fusion nodes when reusing buffer.
- Previously, we said we could not reuse the operand buffer for a loop fusion node if any of the fusion's inputs was a broadcast or reshape. That is too conservative, since in theory we can still reuse the operand's buffer if all the users of that particular operand are elementwise. This CL implements that. - Also fixed a bug in the previous code where a dynamic update fusion node that ends with a convert (added for bf16) was not caught by the if condition correctly. PiperOrigin-RevId: 199731488
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 313033ddad..76349c4099 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -980,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) {
}
}
+TEST_F(HloInstructionTest, MapIsElementwise) {  // Checks that a map instruction reports IsElementwise() == true.
+ auto module = CreateNewModule();
+ const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});  // F32 {10, 10} shape built with an explicit layout argument {1, 0}.
+ HloComputation::Builder builder(TestName());  // Builder for the computation later added as the module's entry.
+ HloComputation::Builder map_builder("id");  // Builder for the computation the map will apply.
+ map_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));  // Only instruction added to "id": parameter 0, F32 with empty dimensions.
+ auto map_computation = module->AddEmbeddedComputation(map_builder.Build());  // Added to the module as an embedded computation.
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x"));  // Sole operand of the map; has shape r2f32.
+ auto map = builder.AddInstruction(
+ HloInstruction::CreateMap(r2f32, {x}, map_computation));  // Map over x using map_computation; result shape is also r2f32.
+ module->AddEntryComputation(builder.Build());  // The computation containing x and map becomes the module's entry.
+
+ EXPECT_TRUE(map->IsElementwise());  // The property under test.
+}
+
TEST_F(HloInstructionTest, PartiallyElementwise) {
const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});