diff options
-rw-r--r-- | Chalice/src/Ast.scala | 120 | ||||
-rw-r--r-- | Chalice/src/Parser.scala | 39 | ||||
-rw-r--r-- | Chalice/src/PrettyPrinter.scala | 23 | ||||
-rw-r--r-- | Chalice/src/Resolver.scala | 65 |
4 files changed, 196 insertions, 51 deletions
diff --git a/Chalice/src/Ast.scala b/Chalice/src/Ast.scala index aaaf7fc9..f49acb2b 100644 --- a/Chalice/src/Ast.scala +++ b/Chalice/src/Ast.scala @@ -176,15 +176,39 @@ case class LockChange(ee: List[Expression]) extends Specification sealed abstract class Refinement(id: String) extends NamedMember(id) {
var refines = null: NamedMember;
}
-case class MethodTransform(id: String, ins: List[Variable], outs: List[Variable], spec: List[Specification], trans: Transform) extends Refinement(id)
+case class MethodTransform(id: String, ins: List[Variable], outs: List[Variable], spec: List[Specification], trans: Transform) extends Refinement(id) {
+ var body = null:List[Statement];
+}
sealed abstract class Transform extends ASTNode
-case class BlockPattern() extends Transform // pattern within a block (*)
-case class ProgramPattern() extends Transform // can match entire block (*)
-case class IfPattern(thn: Transform, els: Option[Transform]) extends Transform
-case class NonDetPattern(code: BlockStmt) extends Transform // matches var or spec
-case class InsertPattern(code: Statement) extends Transform
-case class SequencePattern(pats: List[Transform]) extends Transform
+/** Pattern matching within a block (zero or more) over deterministic statements */
+case class BlockPat() extends Transform {
+ def matches(s: Statement) = s match {
+ case _:Assert => true
+ case _:Assume => true
+ case _:Assign => true
+ case _:FieldUpdate => true
+ case _:LocalVar => true
+ case _ => false
+ }
+}
+/** Matches any block of code (greedily) and acts as identity */
+case class SkipPat() extends Transform
+/** Replacement pattern for arbitrary block */
+case class ProgramPat(code: List[Statement]) extends Transform
+case class IfPat(thn: Transform, els: Option[Transform]) extends Transform
+case class NonDetPat(is: List[String], code: List[Statement]) extends Transform {
+ def matches(s: Statement) = s match {
+ case _:Call => true
+ case _:SpecStmt => true
+ case _ => false
+ }
+}
+case class InsertPat(code: List[Statement]) extends Transform
+case class SeqPat(pats: List[Transform]) extends Transform {
+ assert(pats.size > 0)
+}
+case class RefinementBlock(ss: List[Statement], original: List[Statement]) extends Statement
/**
* Statements
@@ -460,6 +484,88 @@ case class CallState(token: Expression, obj: Expression, id: String, args: List[ object AST {
/**
+ * Flattens sequences of transforms and merges consecutive blocks
+ */
+ def normalize(trans: Transform): Transform = trans match {
+ case IfPat(thn, Some(els)) => IfPat(normalize(thn), Some(normalize(els)))
+ case IfPat(thn, None) => IfPat(normalize(thn), None)
+ case SeqPat(pats) =>
+ val rec = pats flatMap {pat => normalize(pat) match {
+ case SeqPat(pats) => pats;
+ case x => List(x)
+ }}
+ def noTwoBlocks: List[Transform] => List[Transform] = {
+ case BlockPat() :: BlockPat() :: l => noTwoBlocks(BlockPat() :: l)
+ case x :: l => x :: noTwoBlocks(l)
+ case Nil => Nil
+ }
+ SeqPat(noTwoBlocks(rec))
+ case _ => trans
+ }
+
+ sealed abstract class TransformMatch
+ case class Matched(ss: List[Statement]) extends TransformMatch {
+ def this(s: Statement) = this(List(s))
+ }
+ case class Unmatched(t: Transform) extends TransformMatch
+
+ /**
+ * Matches a proper block to a transform.
+ * Requires: a sequence pattern should not contain a sequence pattern
+ */
+ def refine:(List[Statement], Transform) => TransformMatch = {
+ // order is important!
+ // whole program
+ case (l, ProgramPat(code)) => new Matched(RefinementBlock(code, l))
+ // if pattern
+ case (List(IfStmt(guard, thn, None)), t @ IfPat(thnT, None)) =>
+ refine(thn.ss, thnT) match {
+ case Matched(thn0) => new Matched(IfStmt(guard, BlockStmt(thn0), None))
+ case _ => Unmatched(t)
+ }
+ case (List(IfStmt(guard, thn, Some(els))), t @ IfPat(thnT, Some(elsT))) =>
+ (refine(thn.ss, thnT), refine(List(els), elsT)) match {
+ case (Matched(thn0), Matched(els0)) => new Matched(IfStmt(guard, BlockStmt(thn0), Some(BlockStmt(els0))))
+ case _ => Unmatched(t)
+ }
+ // non det pat
+ case (l @ List(_: Call), NonDetPat(_, code)) => new Matched(RefinementBlock(code, l))
+ case (l @ List(_: SpecStmt), NonDetPat(_, code)) => new Matched(RefinementBlock(code, l))
+ // insert pat
+ case (Nil, InsertPat(code)) => new Matched(RefinementBlock(code, Nil))
+ // reduction of base cases
+ case (l, SeqPat(List(t))) => refine(l, t)
+ case (List(BlockStmt(ss)), t) => refine(ss, t)
+ // block pattern (greedy matching)
+ case (l, bp @ BlockPat()) if (l forall {s => bp matches s}) => Matched(l)
+ case (s :: ss, SeqPat((bp @ BlockPat()) :: ts)) if (bp matches s) =>
+ refine(ss, SeqPat(ts)) match {
+ case Matched(l) => Matched(s :: l)
+ case x => x
+ }
+ case (l, SeqPat((bp @ BlockPat()) :: ts)) if (l.size == 0 || !(bp matches l.head)) =>
+ refine(l, SeqPat(ts))
+ case (l, SkipPat()) => Matched(l)
+ // sequence pattern
+ case (s :: ss, SeqPat((np: NonDetPat) :: ts)) =>
+ (refine(List(s), np), refine(ss, SeqPat(ts))) match {
+ case (Matched(a), Matched(b)) => Matched(a ::: b)
+ case _ => Unmatched(np)
+ }
+ case (s :: ss, SeqPat((ip: IfPat) :: ts)) =>
+ (refine(List(s), ip), refine(ss, SeqPat(ts))) match {
+ case (Matched(a), Matched(b)) => Matched(a ::: b)
+ case _ => Unmatched(ip)
+ }
+ case (l, SeqPat(InsertPat(code) :: ts)) =>
+ refine(l, SeqPat(ts)) match {
+ case Matched(a) => Matched(RefinementBlock(code, Nil) :: a)
+ case x => x
+ }
+ case (_, t) => Unmatched(t)
+ }
+
+ /**
* Transforms an expression using f. f must produce expressions of the appropriate type (e.g. not replace int literal with a bool literal)
* Ensures that mutable fields of expressions are carried over. f must make sure that mutable fields of its value are filled in.
*/
diff --git a/Chalice/src/Parser.scala b/Chalice/src/Parser.scala index 15b0e33f..6bcecac0 100644 --- a/Chalice/src/Parser.scala +++ b/Chalice/src/Parser.scala @@ -13,8 +13,6 @@ import scala.util.parsing.input.NoPosition import java.io.File
class Parser extends StandardTokenParsers {
-
-
def parseStdin = phrase(programUnit)(new lexical.Scanner(new PagedSeqReader(PagedSeq fromReader Console.in)))
def parseFile(file: File) = phrase(programUnit)(new lexical.Scanner(new PagedSeqReader(PagedSeq fromFile file)))
@@ -33,7 +31,7 @@ class Parser extends StandardTokenParsers { "ite", "fold", "unfold", "unfolding", "in", "forall", "exists",
"seq", "nil", "result", "eval", "token",
"wait", "signal",
- "refines", "transforms"
+ "refines", "transforms", "replaces", "by"
)
// todo: can we remove "nil"?
lexical.delimiters += ("(", ")", "{", "}", "[[", "]]",
@@ -47,6 +45,7 @@ class Parser extends StandardTokenParsers { def programUnit = (classDecl | channelDecl)*
def Semi = ";" ?
var currentLocalVariables = Set[String]() // used in the method context
+ var assumeAllLocals = false;
/**
* Top level declarations
@@ -71,24 +70,25 @@ class Parser extends StandardTokenParsers { * Member declarations
*/
- def memberDecl = positioned(fieldDecl | invariantDecl | methodDecl | conditionDecl | predicateDecl | functionDecl | transformDecl)
+ def memberDecl = {
+ currentLocalVariables = Set[String](); assumeAllLocals = false;
+ positioned(fieldDecl | invariantDecl | methodDecl | conditionDecl | predicateDecl | functionDecl | transformDecl)
+ }
def fieldDecl =
( "var" ~> idType <~ Semi ^^ { case (id,t) => Field(id.v, t, false) }
| "ghost" ~> "var" ~> idType <~ Semi ^^ { case (id,t) => Field(id.v, t, true) }
)
def invariantDecl = positioned("invariant" ~> expression <~ Semi ^^ MonitorInvariant)
- def methodDecl = {
- currentLocalVariables = Set[String]()
+ def methodDecl =
"method" ~> ident ~ formalParameters(true) ~ ("returns" ~> formalParameters(false) ?) ~
(methodSpec*) ~ blockStatement ^^ {
case id ~ ins ~ outs ~ spec ~ body =>
outs match {
case None => Method(id, ins, Nil, spec, body)
case Some(outs) => Method(id, ins, outs, spec, body) }}
- }
def predicateDecl: Parser[Predicate] =
("predicate" ~> ident) ~ ("{" ~> expression <~ "}") ^^ { case id ~ definition => Predicate(id, definition) }
- def functionDecl =
+ def functionDecl =
("function" ~> ident) ~ formalParameters(true) ~ (":" ~> typeDecl) ~ (methodSpec*) ~ opt("{" ~> expression <~ "}") ^^ {
case id ~ ins ~ out ~ specs ~ body => Function(id, ins, out, specs, body)
}
@@ -96,15 +96,14 @@ class Parser extends StandardTokenParsers { "condition" ~> ident ~ ("where" ~> expression ?) <~ Semi ^^ {
case id ~ optE => Condition(id, optE) }
def transformDecl = {
- currentLocalVariables = Set[String]()
+ assumeAllLocals = true;
"transforms" ~> ident ~ formalParameters(true) ~ ("returns" ~> formalParameters(false) ?) ~
(methodSpec*) ~ ("{" ~> transform <~ "}") ^^ {
case id ~ ins ~ outs ~ spec ~ trans =>
- MethodTransform(id, ins, outs match {case None => Nil; case Some(outs) => outs}, spec, trans)
+ MethodTransform(id, ins, outs match {case None => Nil; case Some(outs) => outs}, spec, AST.normalize(trans))
}
}
-
def formalParameters(immutable: Boolean) =
"(" ~> (formalList(immutable) ?) <~ ")" ^^ {
case None => Nil
@@ -182,7 +181,7 @@ class Parser extends StandardTokenParsers { | "free" ~> expression <~ Semi ^^ Free
| Ident ~ ":=" ~ Rhs <~ Semi ^^ {
case lhs ~ _ ~ rhs =>
- if (currentLocalVariables contains lhs.v) {
+ if ((currentLocalVariables contains lhs.v) || assumeAllLocals) {
val varExpr = VariableExpr(lhs.v); varExpr.pos = lhs.pos;
Assign(varExpr, rhs)
} else {
@@ -463,7 +462,7 @@ class Parser extends StandardTokenParsers { | positioned(Ident) ~ opt("(" ~> expressionList <~ ")") ^^ {
case id ~ None =>
val r =
- if (currentLocalVariables contains id.v) {
+ if ((currentLocalVariables contains id.v) || assumeAllLocals) {
VariableExpr(id.v)
} else {
val implicitThis = ImplicitThisExpr(); implicitThis.pos = id.pos
@@ -561,15 +560,19 @@ class Parser extends StandardTokenParsers { def transform: Parser[Transform] = positioned(
"if" ~> ifTransform
- | transformAtom ~ rep(transform) ^^ {case atom ~ t => SequencePattern(atom :: t)}
+ | transformAtom ~ rep(transform) ^^ {case atom ~ t => SeqPat(atom :: t)}
)
def transformAtom: Parser[Transform] = positioned(
- "_" ~ Semi ^^^ BlockPattern()
- | statement ^^ {case s => InsertPattern(s)}
+ "_" ~ Semi ^^^ BlockPat()
+ | "*" ~ Semi ^^^ SkipPat()
+ | "replaces" ~> rep1sep(Ident,",") ~ ("by" ~> blockStatement) ^^ {
+ case ids ~ code => NonDetPat(ids map {x => x.v}, code)
+ }
+ | rep1(statement) ^^ {case s => InsertPat(s)}
)
- def ifTransform: Parser[IfPattern] =
+ def ifTransform: Parser[IfPat] =
("{" ~> transform <~ "}") ~ ("else" ~> ifTransformElse ?) ^^ {
- case thn ~ els => IfPattern(thn, els)
+ case thn ~ els => IfPat(thn, els)
}
def ifTransformElse = (
"if" ~> ifTransform
diff --git a/Chalice/src/PrettyPrinter.scala b/Chalice/src/PrettyPrinter.scala index c10aa9ea..6eb8c9ff 100644 --- a/Chalice/src/PrettyPrinter.scala +++ b/Chalice/src/PrettyPrinter.scala @@ -24,10 +24,8 @@ object PrintProgram { case m: Method =>
print(" method " + m.id)
print("("); VarList(m.ins); print(")")
- if (m.outs != Nil) {
- print(" returns ("); VarList(m.outs); print(")")
- }
- println
+ if (m.outs != Nil) print(" returns ("); VarList(m.outs); print(")")
+ println;
m.spec foreach {
case Precondition(e) => print(" requires "); Expr(e); println(Semi)
case Postcondition(e) => print(" ensures "); Expr(e); println(Semi)
@@ -58,7 +56,20 @@ object PrintProgram { e match {
case Some(e) => print(" { "); Expr(e); println(" }");
case None =>
- }
+ }
+ case m: MethodTransform =>
+ print(" transforms " + m.id);
+ print("("); VarList(m.ins); print(")")
+ if (m.outs != Nil) print(" returns ("); VarList(m.outs); print(")")
+ println;
+ m.spec foreach {
+ case Precondition(e) => print(" requires "); Expr(e); println(Semi)
+ case Postcondition(e) => print(" ensures "); Expr(e); println(Semi)
+ }
+ println(" {");
+ throw new Exception("not yet implemented")
+ // TODO: print out transform
+ println(" }")
}
def Stmt(s: Statement, indent: Int): Unit = s match {
case Assert(e) =>
@@ -67,6 +78,8 @@ object PrintProgram { print("assume "); Expr(e); println(Semi)
case BlockStmt(ss) =>
PrintBlockStmt(ss, indent); println
+ case RefinementBlock(ss, _) =>
+ PrintBlockStmt(ss, indent); println
case IfStmt(guard, BlockStmt(thn), els) =>
print("if ("); Expr(guard); print(") ")
PrintBlockStmt(thn, indent)
diff --git a/Chalice/src/Resolver.scala b/Chalice/src/Resolver.scala index 80cabe62..37a0df60 100644 --- a/Chalice/src/Resolver.scala +++ b/Chalice/src/Resolver.scala @@ -97,21 +97,23 @@ object Resolver { }
// resolve refinement members
- for (List(cl) <- dag.computeTopologicalSort.reverse) {
- if (! cl.IsRefinement) {
- // check has no refinement members
- if (cl.members.exists{case _: Refinement => true; case _ => false})
- return Errors(List((cl.pos, "non-refinement class cannot have refinement members")))
- } else for (member <- cl.members) member match {
- case r: Refinement =>
- if (! cl.refines.LookupMember(r.Id).isDefined)
- return Errors(List((r.pos, "abstract class has no member with name " + r.Id)))
- r.refines = cl.refines.LookupMember(r.Id).get
- case m: NamedMember =>
- if (cl.refines.LookupMember(m.Id).isDefined)
- return Errors(List((m.pos, "member needs to be a refinement since abstract class has a member with the same name")))
- case _ =>
- }
+ for (decl <- prog) decl match {
+ case cl: Class =>
+ if (! cl.IsRefinement) {
+ // check has no refinement members
+ if (cl.members.exists{case _: Refinement => true; case _ => false})
+ return Errors(List((cl.pos, "non-refinement class cannot have refinement members")))
+ } else for (member <- cl.members) member match {
+ case r: Refinement =>
+ if (! cl.refines.LookupMember(r.Id).isDefined)
+ return Errors(List((r.pos, "abstract class has no member with name " + r.Id)))
+ r.refines = cl.refines.LookupMember(r.Id).get
+ case m: NamedMember =>
+ if (cl.refines.LookupMember(m.Id).isDefined)
+ return Errors(List((m.pos, "member needs to be a refinement since abstract class has a member with the same name")))
+ case _ =>
+ }
+ case _ =>
}
// collect errors
@@ -156,7 +158,7 @@ object Resolver { ctx = ctx.AddVariable(v)
}
ResolveExpr(ch.where, ctx, false, true)(false)
- errors = errors ++ context.errors
+ errors = errors ++ context.errors
case cl: Class =>
val context = new ProgramContext(decls, cl)
for (m <- cl.members) {
@@ -216,7 +218,6 @@ object Resolver { case None =>
}
case mt: MethodTransform =>
- ResolveTransform(mt, context)
}
}
errors = errors ++ context.errors
@@ -231,6 +232,18 @@ object Resolver { f.isRecursive = true;
}
+ // resolve refinement transforms
+ for (List(cl) <- dag.computeTopologicalSort.reverse) {
+ val context = new ProgramContext(decls, cl)
+ for (m <- cl.members) m match {
+ case mt: MethodTransform =>
+ context.currentMember = mt;
+ ResolveTransform(mt, context);
+ case _ =>
+ }
+ errors = errors ++ context.errors
+ }
+
if (errors.length == 0) {
Success()
} else {
@@ -1076,14 +1089,24 @@ object Resolver { }
def ResolveTransform(mt: MethodTransform, context: ProgramContext) {
- mt.refines match {
+ val orig = mt.refines match {
case m: Method =>
if (mt.ins != m.ins) context.Error(mt.pos, "Refinement must have same input arguments")
- if (! mt.outs.startsWith(m.outs)) context.Error(mt.pos, "Refinement must declare all concrete output variables")
+ if (! mt.outs.startsWith(m.outs)) context.Error(mt.pos, "Refinement must declare all abstract output variables")
+ m.body
case r: MethodTransform =>
if (mt.ins != r.ins) context.Error(mt.pos, "Refinement must have same input arguments")
- if (! mt.outs.startsWith(r.outs)) context.Error(mt.pos, "Refinement must declare all concrete output variables")
- case _ => context.Error(mt.pos, "Method can only refine another method or a transform")
+ if (! mt.outs.startsWith(r.outs)) context.Error(mt.pos, "Refinement must declare all abstract output variables")
+ assert(r.body != null)
+ r.body
+ case _ =>
+ context.Error(mt.pos, "Transform must refine a method or a transform")
+ Nil
+ }
+ mt.body = AST.refine(orig, mt.trans) match {
+ case AST.Matched(ss) => ss
+ case AST.Unmatched(t) => context.Error(mt.pos, "Cannot match transform around " + t); Nil
}
+ PrintProgram.Stmt(BlockStmt(mt.body), 0)
}
}
|