diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e169604072..4c5fc759a3 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1053,10 +1053,19 @@ class HloSelectAndScatterInstruction : public HloInstruction { class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction(const Shape& shape, - absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target, - absl::string_view opaque); + HloCustomCallInstruction(const Shape& shape, + absl::Span<HloInstruction* const> operands, + absl::string_view custom_call_target, + absl::string_view opaque); + + // Constructor for a custom call with constrained layout. 'shape' and + // 'operands_with_layout' must all have layouts. + HloCustomCallInstruction(const Shape& shape, + absl::Span<HloInstruction* const> operands, + absl::string_view custom_call_target, + absl::string_view opaque, + absl::Span<const Shape> operand_shapes_with_layout); + const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1085,6 +1094,16 @@ class HloCustomCallInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns whether the result and operand layouts are constrained. + bool layout_constrained() const { return layout_constrained_; } + + // Returns the shapes (with layout) of the operands. CHECKs if this custom + // call does not have constrained layouts. + const std::vector<Shape>& operand_shapes_with_layout() const { + CHECK(layout_constrained()); + return operand_shapes_with_layout_; + } + private: std::vector<string> ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1106,6 +1125,11 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. int64 feature_group_count_; + // Whether the result and operand layouts are constrained. + bool layout_constrained_; + // For layout-constrained custom calls, this vector holds the shape with + // layout for each operand. + std::vector<Shape> operand_shapes_with_layout_; }; class HloPadInstruction : public HloInstruction { |