diff options
Diffstat (limited to 'Chalice/src/Ast.scala')
-rw-r--r-- | Chalice/src/Ast.scala | 120 |
1 files changed, 113 insertions, 7 deletions
diff --git a/Chalice/src/Ast.scala b/Chalice/src/Ast.scala index aaaf7fc9..f49acb2b 100644 --- a/Chalice/src/Ast.scala +++ b/Chalice/src/Ast.scala @@ -176,15 +176,39 @@ case class LockChange(ee: List[Expression]) extends Specification sealed abstract class Refinement(id: String) extends NamedMember(id) {
var refines = null: NamedMember;
}
-case class MethodTransform(id: String, ins: List[Variable], outs: List[Variable], spec: List[Specification], trans: Transform) extends Refinement(id)
+case class MethodTransform(id: String, ins: List[Variable], outs: List[Variable], spec: List[Specification], trans: Transform) extends Refinement(id) {
+ var body = null:List[Statement];
+}
sealed abstract class Transform extends ASTNode
-case class BlockPattern() extends Transform // pattern within a block (*)
-case class ProgramPattern() extends Transform // can match entire block (*)
-case class IfPattern(thn: Transform, els: Option[Transform]) extends Transform
-case class NonDetPattern(code: BlockStmt) extends Transform // matches var or spec
-case class InsertPattern(code: Statement) extends Transform
-case class SequencePattern(pats: List[Transform]) extends Transform
+/** Pattern matching within a block (zero or more) over deterministic statements */
+case class BlockPat() extends Transform {
+ def matches(s: Statement) = s match {
+ case _:Assert => true
+ case _:Assume => true
+ case _:Assign => true
+ case _:FieldUpdate => true
+ case _:LocalVar => true
+ case _ => false
+ }
+}
+/** Matches any block of code (greedily) and acts as identity */
+case class SkipPat() extends Transform
+/** Replacement pattern for arbitrary block */
+case class ProgramPat(code: List[Statement]) extends Transform
+case class IfPat(thn: Transform, els: Option[Transform]) extends Transform
+case class NonDetPat(is: List[String], code: List[Statement]) extends Transform {
+ def matches(s: Statement) = s match {
+ case _:Call => true
+ case _:SpecStmt => true
+ case _ => false
+ }
+}
+case class InsertPat(code: List[Statement]) extends Transform
+case class SeqPat(pats: List[Transform]) extends Transform {
+ assert(pats.size > 0)
+}
+case class RefinementBlock(ss: List[Statement], original: List[Statement]) extends Statement
/**
* Statements
@@ -460,6 +484,88 @@ case class CallState(token: Expression, obj: Expression, id: String, args: List[ object AST {
/**
+ * Flattens sequences of transforms and merges consecutive blocks
+ */
+ def normalize(trans: Transform): Transform = trans match {
+ case IfPat(thn, Some(els)) => IfPat(normalize(thn), Some(normalize(els)))
+ case IfPat(thn, None) => IfPat(normalize(thn), None)
+ case SeqPat(pats) =>
+ val rec = pats flatMap {pat => normalize(pat) match {
+ case SeqPat(pats) => pats;
+ case x => List(x)
+ }}
+ def noTwoBlocks: List[Transform] => List[Transform] = {
+ case BlockPat() :: BlockPat() :: l => noTwoBlocks(BlockPat() :: l)
+ case x :: l => x :: noTwoBlocks(l)
+ case Nil => Nil
+ }
+ SeqPat(noTwoBlocks(rec))
+ case _ => trans
+ }
+
+ sealed abstract class TransformMatch
+ case class Matched(ss: List[Statement]) extends TransformMatch {
+ def this(s: Statement) = this(List(s))
+ }
+ case class Unmatched(t: Transform) extends TransformMatch
+
+ /**
+ * Matches a proper block to a transform.
+ * Requires: a sequence pattern should not contain a sequence pattern
+ */
+ def refine:(List[Statement], Transform) => TransformMatch = {
+ // order is important!
+ // whole program
+ case (l, ProgramPat(code)) => new Matched(RefinementBlock(code, l))
+ // if pattern
+ case (List(IfStmt(guard, thn, None)), t @ IfPat(thnT, None)) =>
+ refine(thn.ss, thnT) match {
+ case Matched(thn0) => new Matched(IfStmt(guard, BlockStmt(thn0), None))
+ case _ => Unmatched(t)
+ }
+ case (List(IfStmt(guard, thn, Some(els))), t @ IfPat(thnT, Some(elsT))) =>
+ (refine(thn.ss, thnT), refine(List(els), elsT)) match {
+ case (Matched(thn0), Matched(els0)) => new Matched(IfStmt(guard, BlockStmt(thn0), Some(BlockStmt(els0))))
+ case _ => Unmatched(t)
+ }
+ // non det pat
+ case (l @ List(_: Call), NonDetPat(_, code)) => new Matched(RefinementBlock(code, l))
+ case (l @ List(_: SpecStmt), NonDetPat(_, code)) => new Matched(RefinementBlock(code, l))
+ // insert pat
+ case (Nil, InsertPat(code)) => new Matched(RefinementBlock(code, Nil))
+ // reduction of base cases
+ case (l, SeqPat(List(t))) => refine(l, t)
+ case (List(BlockStmt(ss)), t) => refine(ss, t)
+ // block pattern (greedy matching)
+ case (l, bp @ BlockPat()) if (l forall {s => bp matches s}) => Matched(l)
+ case (s :: ss, SeqPat((bp @ BlockPat()) :: ts)) if (bp matches s) =>
+ refine(ss, SeqPat(ts)) match {
+ case Matched(l) => Matched(s :: l)
+ case x => x
+ }
+ case (l, SeqPat((bp @ BlockPat()) :: ts)) if (l.size == 0 || !(bp matches l.head)) =>
+ refine(l, SeqPat(ts))
+ case (l, SkipPat()) => Matched(l)
+ // sequence pattern
+ case (s :: ss, SeqPat((np: NonDetPat) :: ts)) =>
+ (refine(List(s), np), refine(ss, SeqPat(ts))) match {
+ case (Matched(a), Matched(b)) => Matched(a ::: b)
+ case _ => Unmatched(np)
+ }
+ case (s :: ss, SeqPat((ip: IfPat) :: ts)) =>
+ (refine(List(s), ip), refine(ss, SeqPat(ts))) match {
+ case (Matched(a), Matched(b)) => Matched(a ::: b)
+ case _ => Unmatched(ip)
+ }
+ case (l, SeqPat(InsertPat(code) :: ts)) =>
+ refine(l, SeqPat(ts)) match {
+ case Matched(a) => Matched(RefinementBlock(code, Nil) :: a)
+ case x => x
+ }
+ case (_, t) => Unmatched(t)
+ }
+
+ /**
* Transforms an expression using f. f must produce expressions of the appropriate type (e.g. not replace int literal with a bool literal)
* Ensures that mutable fields of expressions are carried over. f must make sure that mutable fields of its value are filled in.
*/
|