aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-28 13:56:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 14:01:14 -0700
commit96de2f020130fc0afbcc5087b520aa46ced5b5af (patch)
tree759a2b4d58448a1cd816a3b84979dc6d40d36467 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent9c1f14322484e44a93b77619ffd2e24b9b7a9b1d (diff)
[XLA] Implement kIota for CPU & GPU, extend it w/ broadcast semantics
This extends the Iota HLO to have a broadcast field. This allows for higher rank kIota operations. PiperOrigin-RevId: 210600435
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index b93c758937..ffc74cfedd 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2155,4 +2155,34 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
scatter_dimension_numbers());
}
+HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
+ : HloInstruction(HloOpcode::kIota, shape),
+ iota_dimension_(iota_dimension) {}
+
+HloInstructionProto HloIotaInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.add_dimensions(iota_dimension());
+ return proto;
+}
+
+std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("iota_dimension=", iota_dimension())};
+}
+
+bool HloIotaInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloIotaInstruction&>(other);
+ return iota_dimension() == casted_other.iota_dimension();
+}
+
+std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
+}
+
} // namespace xla