diff options
author | 2018-08-29 14:51:13 -0700 | |
---|---|---|
committer | 2018-08-29 14:55:33 -0700 | |
commit | 866b44882cfa453464df9ac0fa71a53590f7c623 (patch) | |
tree | 5b51e76c9f6f9a1c680832593752b3471c83e9e1 /tensorflow/compiler/xla/service/instruction_fusion.cc | |
parent | 91c7cd5676624d8c364d7dc56bb50300bb9d210c (diff) |
[XLA] Update InstructionFusion::EffectivelyAtMostUnary for kIota
It should behave like kBroadcast with respect to it being "effectively unary."
PiperOrigin-RevId: 210795483
Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion.cc | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 83313c7ec1..4b5285031b 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -172,7 +172,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { }); return std::count_if(hlo->operands().begin(), hlo->operands().end(), [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { return false; } if (operand->opcode() == HloOpcode::kConstant && |