diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index c22ee03388..fad3b14ec2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1042,7 +1042,10 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) { // not check result shape as that is checked in the ShapeVerifier. class InstructionVerifier : public DfsHloVisitorWithDefault { public: - InstructionVerifier() {} + explicit InstructionVerifier(std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func) + : instruction_can_change_layout_func_( + instruction_can_change_layout_func) {} Status DefaultAction(HloInstruction*) override { return Status::OK(); } @@ -1143,8 +1146,34 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } + Status Postprocess(HloInstruction* instruction) override { + if (instruction_can_change_layout_func_ && + LayoutUtil::IsDenseArray(instruction->shape()) && + !instruction_can_change_layout_func_(instruction)) { + const Shape& result_shape = instruction->shape(); + const Layout& result_layout = result_shape.layout(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsDenseArray(operand_shape) && + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + const Layout& operand_layout = operand_shape.layout(); + TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) + << "Instruction shouldn't change layouts " + << instruction->ToString() << " From " + << ShapeUtil::HumanString(result_shape) << " To " + << ShapeUtil::HumanString(operand_shape); + } + } + } + + return Status::OK(); + } + private: absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_; + // Determines whether an instruction can change layouts. + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func_; }; } // namespace @@ -1158,7 +1187,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_(); TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); - InstructionVerifier instruction_verifier; + InstructionVerifier instruction_verifier( + instruction_can_change_layout_func_); TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } |