aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h32
1 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e169604072..4c5fc759a3 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1053,10 +1053,19 @@ class HloSelectAndScatterInstruction : public HloInstruction {
class HloCustomCallInstruction : public HloInstruction {
public:
- explicit HloCustomCallInstruction(const Shape& shape,
- absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target,
- absl::string_view opaque);
+ HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::string_view opaque);
+
+ // Constructor for a custom call with constrained layout. 'shape' and
+ // 'operands_with_layout' must all have layouts.
+ HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::string_view opaque,
+ absl::Span<const Shape> operand_shapes_with_layout);
+
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1085,6 +1094,16 @@ class HloCustomCallInstruction : public HloInstruction {
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
+ // Returns whether the result and operand layouts are constrained.
+ bool layout_constrained() const { return layout_constrained_; }
+
+ // Returns the shapes (with layout) of the operands. CHECKs if this custom
+ // call does not have constrained layouts.
+ const std::vector<Shape>& operand_shapes_with_layout() const {
+ CHECK(layout_constrained());
+ return operand_shapes_with_layout_;
+ }
+
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
@@ -1106,6 +1125,11 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
// The number of feature groups. This is used for grouped convolutions.
int64 feature_group_count_;
+ // Whether the result and operand layouts are constrained.
+ bool layout_constrained_;
+ // For layout-constrained custom calls, this vector holds the shape with
+ // layout for each operand.
+ std::vector<Shape> operand_shapes_with_layout_;
};
class HloPadInstruction : public HloInstruction {