using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.Boogie;
using System.Diagnostics.Contracts;
using System.Diagnostics;

namespace Microsoft.Boogie
{
    /// <summary>
    /// Mover classification of an atomic action.
    /// <see cref="ActionInfo.IsRightMover"/> holds for Right and Both;
    /// <see cref="ActionInfo.IsLeftMover"/> holds for Left and Both.
    /// <see cref="MoverTypeChecker.TypeCheck"/> treats Top as "this ensures clause
    /// does not declare an atomic action" and skips it.
    /// </summary>
    public enum MoverType
    {
        Top,
        Atomic,
        Right,
        Left,
        Both
    }

    /// <summary>
    /// The atomic action attached to a procedure, plus a renamed copy of it (the "that"
    /// copy) in which parameters, locals and block labels carry a <c>that_</c> prefix.
    /// Presumably the two copies are composed side by side by the mover checks. TODO confirm.
    /// </summary>
    /// <remarks>
    /// Construction mutates the supplied <see cref="CodeExpr"/>:
    /// <list type="bullet">
    /// <item>leading asserts of its first block are replaced by assumes;</item>
    /// <item>its block labels and goto label names are prefixed with <c>this_</c>.</item>
    /// </list>
    /// </remarks>
    public class ActionInfo
    {
        public Procedure proc;
        public MoverType moverType;
        public int phaseNum;
        // Phases of procedures that call this action; populated by MoverTypeChecker.
        public HashSet<int> callerPhaseNums;
        // Leading asserts of the action's first block (the gate), as originally written.
        public List<AssertCmd> thisGate;
        public CodeExpr thisAction;
        public List<Variable> thisInParams;
        public List<Variable> thisOutParams;
        // thisGate with every parameter/local renamed to its that_ counterpart.
        public List<AssertCmd> thatGate;
        public CodeExpr thatAction;
        public List<Variable> thatInParams;
        public List<Variable> thatOutParams;

        public bool IsRightMover
        {
            get { return moverType == MoverType.Right || moverType == MoverType.Both; }
        }

        public bool IsLeftMover
        {
            get { return moverType == MoverType.Left || moverType == MoverType.Both; }
        }

        /// <summary>
        /// Builds the action info for <paramref name="proc"/> from its atomic specification.
        /// </summary>
        /// <param name="proc">Procedure whose ensures clause declared the action.</param>
        /// <param name="codeExpr">
        /// The action body; it becomes <see cref="thisAction"/> and is mutated (see remarks).
        /// </param>
        /// <param name="moverType">Mover classification of the action.</param>
        /// <param name="phaseNum">Phase at which the action is declared.</param>
        public ActionInfo(Procedure proc, CodeExpr codeExpr, MoverType moverType, int phaseNum)
        {
            this.proc = proc;
            this.moverType = moverType;
            this.phaseNum = phaseNum;
            this.callerPhaseNums = new HashSet<int>();
            this.thisGate = new List<AssertCmd>();
            this.thisAction = codeExpr;
            this.thisInParams = new List<Variable>();
            this.thisOutParams = new List<Variable>();
            this.thatGate = new List<AssertCmd>();
            this.thatInParams = new List<Variable>();
            this.thatOutParams = new List<Variable>();

            ExtractGate();

            // Map every in/out parameter and local of the action to a fresh that_ variable.
            Dictionary<Variable, Expr> map = new Dictionary<Variable, Expr>();
            foreach (Variable x in proc.InParams)
            {
                this.thisInParams.Add(x);
                Variable y = MakeThatVariable(x, true);
                this.thatInParams.Add(y);
                map[x] = new IdentifierExpr(Token.NoToken, y);
            }
            foreach (Variable x in proc.OutParams)
            {
                this.thisOutParams.Add(x);
                Variable y = MakeThatVariable(x, false);
                this.thatOutParams.Add(y);
                map[x] = new IdentifierExpr(Token.NoToken, y);
            }
            List<Variable> otherLocVars = new List<Variable>();
            foreach (Variable x in thisAction.LocVars)
            {
                // NOTE(review): locals are copied as outgoing Formals, not LocalVariables;
                // kept as in the original.
                Variable y = MakeThatVariable(x, false);
                map[x] = new IdentifierExpr(Token.NoToken, y);
                otherLocVars.Add(y);
            }

            // The renaming below does not account for type parameters.
            Contract.Assume(proc.TypeParameters.Count == 0);
            Substitution subst = Substituter.SubstitutionFromHashtable(map);

            foreach (AssertCmd assertCmd in thisGate)
            {
                thatGate.Add((AssertCmd)Substituter.Apply(subst, assertCmd));
            }
            this.thatAction = new CodeExpr(otherLocVars, CloneBlocks(subst));
        }

        // Records the leading run of AssertCmds of the first block in thisGate and replaces
        // each of them in place with an AssumeCmd over the same expression.
        private void ExtractGate()
        {
            var cmds = thisAction.Blocks[0].Cmds;
            for (int i = 0; i < cmds.Count; i++)
            {
                AssertCmd assertCmd = cmds[i] as AssertCmd;
                if (assertCmd == null) break;
                thisGate.Add(assertCmd);
                cmds[i] = new AssumeCmd(assertCmd.tok, assertCmd.Expr);
            }
        }

        // Creates the that_-prefixed formal that stands in for x in the copied action.
        private static Variable MakeThatVariable(Variable x, bool incoming)
        {
            return new Formal(Token.NoToken, new TypedIdent(Token.NoToken, "that_" + x.Name, x.TypedIdent.Type), incoming);
        }

        // Builds the that_-labelled copy of thisAction's blocks, applying subst to every
        // command, and prefixes thisAction's own labels and goto label names with this_.
        private List<Block> CloneBlocks(Substitution subst)
        {
            Dictionary<Block, Block> blockMap = new Dictionary<Block, Block>();
            List<Block> otherBlocks = new List<Block>();

            // Pass 1: copy commands and relabel; transfer commands are fixed up in pass 2.
            foreach (Block block in thisAction.Blocks)
            {
                List<Cmd> otherCmds = new List<Cmd>();
                foreach (Cmd cmd in block.Cmds)
                {
                    otherCmds.Add(Substituter.Apply(subst, cmd));
                }
                Block otherBlock = new Block();
                otherBlock.Cmds = otherCmds;
                // The copy's label must be derived before the original label is rewritten.
                otherBlock.Label = "that_" + block.Label;
                block.Label = "this_" + block.Label;
                otherBlocks.Add(otherBlock);
                blockMap[block] = otherBlock;

                GotoCmd thisGoto = block.TransferCmd as GotoCmd;
                if (thisGoto != null)
                {
                    for (int i = 0; i < thisGoto.labelNames.Count; i++)
                    {
                        thisGoto.labelNames[i] = "this_" + thisGoto.labelNames[i];
                    }
                }
            }

            // Pass 2: every copy exists now, so goto targets can point at the copies.
            foreach (Block block in thisAction.Blocks)
            {
                if (block.TransferCmd is ReturnExprCmd)
                {
                    blockMap[block].TransferCmd = new ReturnCmd(block.TransferCmd.tok);
                    continue;
                }
                List<Block> otherGotoCmdLabelTargets = new List<Block>();
                List<string> otherGotoCmdLabelNames = new List<string>();
                // NOTE(review): assumes every non-return block ends in a GotoCmd.
                GotoCmd gotoCmd = block.TransferCmd as GotoCmd;
                foreach (Block target in gotoCmd.labelTargets)
                {
                    otherGotoCmdLabelTargets.Add(blockMap[target]);
                    otherGotoCmdLabelNames.Add(blockMap[target].Label);
                }
                blockMap[block].TransferCmd = new GotoCmd(block.TransferCmd.tok, otherGotoCmdLabelNames, otherGotoCmdLabelTargets);
            }
            return otherBlocks;
        }
    }

    /// <summary>
    /// Visitor that collects the atomic actions declared on procedures and checks phase and
    /// mover constraints on calls, parallel calls and assertions. Errors are reported through
    /// <see cref="Error"/> and counted in <see cref="errorCount"/>.
    /// </summary>
    public class MoverTypeChecker : StandardVisitor
    {
        CheckingContext checkingContext;
        public int errorCount;
        // Global variables carrying the {:qed} attribute; VisitIdentifierExpr reports every
        // visited reference to one of them.
        HashSet<Variable> globalVariables;
        // Phase of the procedure/implementation being visited (int.MaxValue if it has no action).
        int enclosingProcPhaseNum;
        public Dictionary<Procedure, ActionInfo> procToActionInfo;
        public Program program;
        // {:phase} values of visited requires/ensures/assert (int.MaxValue when the attribute is absent).
        public HashSet<int> assertionPhaseNums;

        /// <summary>
        /// Returns the phase of the atomic action attached to <paramref name="proc"/>, or
        /// <see cref="int.MaxValue"/> when the procedure has none.
        /// </summary>
        public int FindPhaseNumber(Procedure proc)
        {
            ActionInfo actionInfo;
            return procToActionInfo.TryGetValue(proc, out actionInfo) ? actionInfo.phaseNum : int.MaxValue;
        }

        /// <summary>
        /// Registers an <see cref="ActionInfo"/> for every procedure whose ensures clause
        /// declares an atomic action, then visits the whole program.
        /// </summary>
        public void TypeCheck()
        {
            foreach (Declaration decl in program.TopLevelDeclarations)
            {
                Procedure proc = decl as Procedure;
                if (proc == null) continue;
                foreach (Ensures e in proc.Ensures)
                {
                    int phaseNum;
                    MoverType moverType = MoverCheck.GetMoverType(e, out phaseNum);
                    if (moverType == MoverType.Top) continue;
                    CodeExpr codeExpr = e.Condition as CodeExpr;
                    if (codeExpr == null)
                    {
                        Error(e, "An atomic action must be a CodeExpr");
                        continue;
                    }
                    if (procToActionInfo.ContainsKey(proc))
                    {
                        Error(proc, "A procedure can have at most one atomic action");
                        continue;
                    }
                    procToActionInfo[proc] = new ActionInfo(proc, codeExpr, moverType, phaseNum);
                }
            }
            this.VisitProgram(program);
#if QED
            YieldTypeChecker.PerformYieldTypeChecking(this);
#endif
        }

        /// <summary>
        /// Creates a checker for <paramref name="program"/>, recording its {:qed} globals.
        /// </summary>
        public MoverTypeChecker(Program program)
        {
            this.globalVariables = new HashSet<Variable>();
            foreach (var g in program.GlobalVariables())
            {
                if (QKeyValue.FindBoolAttribute(g.Attributes, "qed"))
                    this.globalVariables.Add(g);
            }
            this.procToActionInfo = new Dictionary<Procedure, ActionInfo>();
            this.assertionPhaseNums = new HashSet<int>();
            this.errorCount = 0;
            this.checkingContext = new CheckingContext(null);
            this.program = program;
            this.enclosingProcPhaseNum = int.MaxValue;
        }

        public override Implementation VisitImplementation(Implementation node)
        {
            enclosingProcPhaseNum = FindPhaseNumber(node.Proc);
            return base.VisitImplementation(node);
        }

        public override Procedure VisitProcedure(Procedure node)
        {
            enclosingProcPhaseNum = FindPhaseNumber(node);
            return base.VisitProcedure(node);
        }

        /// <summary>
        /// For synchronous calls, records the caller's phase on a lower-phase callee action, or
        /// reports an error when the callee's phase is not below the caller's (unless neither
        /// procedure has an atomic action).
        /// </summary>
        public override Cmd VisitCallCmd(CallCmd node)
        {
            if (!node.IsAsync)
            {
                int calleePhaseNum = FindPhaseNumber(node.Proc);
                if (enclosingProcPhaseNum > calleePhaseNum)
                {
                    // calleePhaseNum < int.MaxValue here, so the callee has a registered action.
                    procToActionInfo[node.Proc].callerPhaseNums.Add(enclosingProcPhaseNum);
                }
                else if (enclosingProcPhaseNum != int.MaxValue)
                {
                    // Caller phase <= callee phase. That is only allowed when both are
                    // int.MaxValue, i.e. neither procedure has an atomic action.
                    Error(node, "The phase of the caller procedure must be greater than the phase of the callee");
                }
            }
            return base.VisitCallCmd(node);
        }

        /// <summary>
        /// When every callee has a lower phase than the caller, records the caller's phase on
        /// each callee action. In that case it also requires that a multi-call parallel call
        /// is uniformly left- or right-moving. The individual calls are then visited.
        /// </summary>
        public override Cmd VisitParCallCmd(ParCallCmd node)
        {
            int maxCalleePhaseNum = 0;
            foreach (CallCmd iter in node.CallCmds)
            {
                int calleePhaseNum = FindPhaseNumber(iter.Proc);
                if (calleePhaseNum > maxCalleePhaseNum)
                    maxCalleePhaseNum = calleePhaseNum;
            }
            if (enclosingProcPhaseNum > maxCalleePhaseNum)
            {
                bool isLeftMover = true;
                bool isRightMover = true;
                foreach (CallCmd iter in node.CallCmds)
                {
                    // Every callee phase is below int.MaxValue here, so each has an action.
                    ActionInfo actionInfo = procToActionInfo[iter.Proc];
                    isLeftMover = isLeftMover && actionInfo.IsLeftMover;
                    isRightMover = isRightMover && actionInfo.IsRightMover;
                    actionInfo.callerPhaseNums.Add(enclosingProcPhaseNum);
                }
                if (!isLeftMover && !isRightMover && node.CallCmds.Count > 1)
                {
                    Error(node, "The procedures in the parallel call must be all right mover or all left mover");
                }
            }
            // Previously this returned node without visiting the calls, so call arguments
            // escaped VisitIdentifierExpr and callee phases escaped VisitCallCmd's check.
            // base.VisitParCallCmd is expected to dispatch each call to VisitCallCmd.
            return base.VisitParCallCmd(node);
        }

        /// <summary>Reports any visited reference to a {:qed} global variable.</summary>
        public override Expr VisitIdentifierExpr(IdentifierExpr node)
        {
            if (globalVariables.Contains(node.Decl))
            {
                Error(node, "Cannot access global variable");
            }
            return base.VisitIdentifierExpr(node);
        }

        /// <summary>
        /// Records the {:phase} of a non-atomic ensures clause and checks it does not exceed the
        /// enclosing procedure's phase. The condition is not visited, so the {:qed}-global check
        /// in VisitIdentifierExpr does not apply inside it.
        /// </summary>
        public override Ensures VisitEnsures(Ensures ensures)
        {
            if (ensures.IsAtomicSpecification)
                return ensures;
            int phaseNum = QKeyValue.FindIntAttribute(ensures.Attributes, "phase", int.MaxValue);
            assertionPhaseNums.Add(phaseNum);
            if (phaseNum > enclosingProcPhaseNum)
            {
                Error(ensures, "The phase of ensures clause cannot be greater than the phase of enclosing procedure");
            }
            return ensures;
        }

        /// <summary>
        /// Records the {:phase} of a requires clause and checks it does not exceed the enclosing
        /// procedure's phase. The condition is not visited.
        /// </summary>
        public override Requires VisitRequires(Requires requires)
        {
            int phaseNum = QKeyValue.FindIntAttribute(requires.Attributes, "phase", int.MaxValue);
            assertionPhaseNums.Add(phaseNum);
            if (phaseNum > enclosingProcPhaseNum)
            {
                Error(requires, "The phase of requires clause cannot be greater than the phase of enclosing procedure");
            }
            return requires;
        }

        /// <summary>
        /// Records the {:phase} of an assert and checks it does not exceed the enclosing
        /// procedure's phase. The asserted expression is not visited.
        /// </summary>
        public override Cmd VisitAssertCmd(AssertCmd node)
        {
            int phaseNum = QKeyValue.FindIntAttribute(node.Attributes, "phase", int.MaxValue);
            assertionPhaseNums.Add(phaseNum);
            if (phaseNum > enclosingProcPhaseNum)
            {
                Error(node, "The phase of assert cannot be greater than the phase of enclosing procedure");
            }
            return node;
        }

        /// <summary>Reports an error at <paramref name="node"/> and increments errorCount.</summary>
        public void Error(Absy node, string message)
        {
            checkingContext.Error(node, message);
            errorCount++;
        }
    }
}