diff options
author | David Majnemer <majnemer@google.com> | 2018-08-28 13:56:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 14:01:14 -0700 |
commit | 96de2f020130fc0afbcc5087b520aa46ced5b5af (patch) | |
tree | 759a2b4d58448a1cd816a3b84979dc6d40d36467 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 9c1f14322484e44a93b77619ffd2e24b9b7a9b1d (diff) |
[XLA] Implement kIota for CPU & GPU, extend it w/ broadcast semantics
This extends the Iota HLO to have a broadcast field. This allows for higher
rank kIota operations.
PiperOrigin-RevId: 210600435
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index b93c758937..ffc74cfedd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2155,4 +2155,34 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl( scatter_dimension_numbers()); } +HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) + : HloInstruction(HloOpcode::kIota, shape), + iota_dimension_(iota_dimension) {} + +HloInstructionProto HloIotaInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(iota_dimension()); + return proto; +} + +std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("iota_dimension=", iota_dimension())}; +} + +bool HloIotaInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloIotaInstruction&>(other); + return iota_dimension() == casted_other.iota_dimension(); +} + +std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + return absl::make_unique<HloIotaInstruction>(shape, iota_dimension()); +} + } // namespace xla |