diff options
author | 2018-06-13 13:28:20 -0700 | |
---|---|---|
committer | 2018-06-13 13:33:30 -0700 | |
commit | fbd920a6997e2d507b4247c194574a5b2b10f926 (patch) | |
tree | 0f67acae5fa56ae9afa14121417c7f68e0a15306 /tensorflow/compiler/xla/service/hlo_instructions.h | |
parent | 7b033a1c26670f99562ee6c8a86bfc2721101165 (diff) |
Split out HloInfeedIndexInstruction and HloOutfeedInstruction as subclasses from HloInstruction.
PiperOrigin-RevId: 200443508
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 6749d87555..9f810c0a14 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -722,6 +722,67 @@ class HloReducePrecisionInstruction : public HloInstruction { int32 exponent_bits_ = 0; int32 mantissa_bits_ = 0; }; + +class HloInfeedInstruction : public HloInstruction { + public: + explicit HloInfeedInstruction(const Shape& shape, const string& config); + // Returns the infeed configuration string. The infeed configuration includes + // any metadata needed for the backend compiler (e.g., infeed buffer address) + // and is target-dependent. + string infeed_config() const { return infeed_config_; } + void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + // The string representation of the infeed configuration. + string infeed_config_; +}; + +class HloOutfeedInstruction : public HloInstruction { + public: + explicit HloOutfeedInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config); + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + return outfeed_shape_; + } + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const { return outfeed_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + // Shape of outfeed request. + Shape outfeed_shape_; + // Outfeed configuration information, only present for kOutfeed. + string outfeed_config_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ |