aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.h
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-28 13:56:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 14:01:14 -0700
commit96de2f020130fc0afbcc5087b520aa46ced5b5af (patch)
tree759a2b4d58448a1cd816a3b84979dc6d40d36467 /tensorflow/compiler/xla/service/hlo_instructions.h
parent9c1f14322484e44a93b77619ffd2e24b9b7a9b1d (diff)
[XLA] Implement kIota for CPU & GPU, extend it w/ broadcast semantics
This extends the Iota HLO to have a broadcast field. This allows for higher rank kIota operations. PiperOrigin-RevId: 210600435
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 29b187300d..ee6e337b6a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1279,6 +1279,30 @@ class HloScatterInstruction : public HloInstruction {
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
};
+class HloIotaInstruction : public HloInstruction {
+ public:
+ explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ int64 iota_dimension() const { return iota_dimension_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ const int64 iota_dimension_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_