diff options
author | 2018-08-29 17:48:09 -0700 | |
---|---|---|
committer | 2018-08-29 17:51:38 -0700 | |
commit | 7cda8c3a8ad528f2e11fc47b0abf08e01f97af45 (patch) | |
tree | 4abceb5e1c3ca6692f41d53d71f8e4de1c4108fb /tensorflow/compiler/xla/service/hlo_element_type_converter.cc | |
parent | e528493c8cde468451ba1b1995e649ebe9c29b02 (diff) |
[XLA] Switch to using kIota from TF
We were using a broadcast of a constant instead of the kIota HLO.
To make switching to kIota practical, we need to do a few things first:
- Don't constant fold kIota.
- Don't hoist kIota from loops without good cause.
PiperOrigin-RevId: 210825834
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_element_type_converter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_element_type_converter.cc | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index b9244b8e9e..72006e17e7 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -151,7 +151,11 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) { } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); - if (!HasOperandType(hlo, eliminate_type_)) { + bool nullary = hlo->operands().empty(); + bool wrong_element_type = hlo->shape().element_type() == eliminate_type_; + bool should_eliminate_type = (nullary && wrong_element_type) || + HasOperandType(hlo, eliminate_type_); + if (!should_eliminate_type) { // If this CHECK fires, then this was an instruction that does not take // the elimination type as an operand but it does return it. This pass // does not have a feature to change the output type in that case, so |