aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_verifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc34
1 files changed, 32 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index c22ee03388..fad3b14ec2 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1042,7 +1042,10 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) {
// not check result shape as that is checked in the ShapeVerifier.
class InstructionVerifier : public DfsHloVisitorWithDefault {
public:
- InstructionVerifier() {}
+ explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
+ : instruction_can_change_layout_func_(
+ instruction_can_change_layout_func) {}
Status DefaultAction(HloInstruction*) override { return Status::OK(); }
@@ -1143,8 +1146,34 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ Status Postprocess(HloInstruction* instruction) override {
+ if (instruction_can_change_layout_func_ &&
+ LayoutUtil::IsDenseArray(instruction->shape()) &&
+ !instruction_can_change_layout_func_(instruction)) {
+ const Shape& result_shape = instruction->shape();
+ const Layout& result_layout = result_shape.layout();
+ for (HloInstruction* operand : instruction->operands()) {
+ const Shape& operand_shape = operand->shape();
+ if (LayoutUtil::IsDenseArray(operand_shape) &&
+ ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) {
+ const Layout& operand_layout = operand_shape.layout();
+ TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
+ << "Instruction shouldn't change layouts "
+ << instruction->ToString() << " From "
+ << ShapeUtil::HumanString(result_shape) << " To "
+ << ShapeUtil::HumanString(operand_shape);
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
private:
absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
+ // Determines whether an instruction can change layouts.
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
};
} // namespace
@@ -1158,7 +1187,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
- InstructionVerifier instruction_verifier;
+ InstructionVerifier instruction_verifier(
+ instruction_can_change_layout_func_);
TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
}