diff options
Diffstat (limited to 'Chalice/src/Ast.scala')
-rw-r--r-- | Chalice/src/Ast.scala | 237 |
1 files changed, 197 insertions, 40 deletions
diff --git a/Chalice/src/Ast.scala b/Chalice/src/Ast.scala index 25bfb975..0414285d 100644 --- a/Chalice/src/Ast.scala +++ b/Chalice/src/Ast.scala @@ -4,6 +4,7 @@ //
//-----------------------------------------------------------------------------
import scala.util.parsing.input.Position
+import scala.util.parsing.input.NoPosition
import scala.util.parsing.input.Positional
trait ASTNode extends Positional
@@ -124,10 +125,8 @@ sealed abstract class NamedMember(id: String) extends Member { var Parent: Class = null
def FullName = Parent.id + "." + Id
}
-case class Field(id: String, typ: Type) extends NamedMember(id) {
- val IsGhost: Boolean = false
-}
-case class SpecialField(name: String, tp: Type) extends Field(name, tp) { // direct assignments are not allowed to a SpecialField
+case class Field(id: String, typ: Type, isGhost: Boolean) extends NamedMember(id)
+case class SpecialField(name: String, tp: Type) extends Field(name, tp, false) { // direct assignments are not allowed to a SpecialField
override def FullName = id
}
case class Method(id: String, ins: List[Variable], outs: List[Variable], spec: List[Specification], body: List[Statement]) extends NamedMember(id)
@@ -146,12 +145,12 @@ class Variable(name: String, typ: Type) extends ASTNode { val IsGhost: Boolean = false
val IsImmutable: Boolean = false
val UniqueName = {
- val n = S_Variable.VariableCount
- S_Variable.VariableCount = S_Variable.VariableCount + 1
+ val n = Variable.VariableCount
+ Variable.VariableCount = Variable.VariableCount + 1
name + "#" + n
}
}
-object S_Variable { var VariableCount = 0 }
+object Variable { var VariableCount = 0 }
class ImmutableVariable(id: String, t: Type) extends Variable(id, t) {
override val IsImmutable: Boolean = true
}
@@ -234,7 +233,9 @@ case class NewRhs(id: String, initialization: List[Init], lowerBounds: List[Expr case class Init(id: String, e: Expression) extends ASTNode {
var f: Field = null;
}
-sealed abstract class Expression extends RValue
+sealed abstract class Expression extends RValue {
+ def transform(f: Expression => Option[Expression]) = AST.transform(this, f)
+}
sealed abstract class Literal extends Expression
case class IntLiteral(n: Int) extends Literal
case class BoolLiteral(b: Boolean) extends Literal
@@ -273,14 +274,12 @@ case class Epsilons(n: Expression) extends Read // Some(Some(n)) sealed abstract class PermissionExpr(perm: Permission) extends Expression
sealed abstract class WildCardPermission(perm: Permission) extends PermissionExpr(perm)
-case class Access(e: MemberAccess, perm: Permission) extends PermissionExpr(perm) {
- def getMemberAccess = e : MemberAccess;
-}
+case class Access(ma: MemberAccess, perm: Permission) extends PermissionExpr(perm)
case class AccessAll(obj: Expression, perm: Permission) extends WildCardPermission(perm)
case class AccessSeq(s: Expression, f: Option[MemberAccess], perm: Permission) extends WildCardPermission(perm)
case class Credit(e: Expression, n: Option[Expression]) extends Expression {
- def N = n match { case None => IntLiteral(1) case Some(n) => n }
+ val N = n match { case None => IntLiteral(1) case Some(n) => n }
}
case class Holds(e: Expression) extends Expression
@@ -295,10 +294,10 @@ case class FunctionApplication(obj: Expression, id: String, args: List[Expressio }
case class Unfolding(pred: Access, in: Expression) extends Expression
sealed abstract class BinaryExpr(e0: Expression, e1: Expression) extends Expression {
- def E0 = e0
- def E1 = e1
- def ExpectedLhsType: Class = BoolClass // sometimes undefined
- def ExpectedRhsType: Class = BoolClass // sometimes undefined
+ val E0 = e0
+ val E1 = e1
+ val ExpectedLhsType: Class = BoolClass // sometimes undefined
+ val ExpectedRhsType: Class = BoolClass // sometimes undefined
val ResultType: Class = BoolClass
val OpName: String
}
@@ -315,8 +314,8 @@ case class Or(e0: Expression, e1: Expression) extends BinaryExpr(e0,e1) { override val OpName = "||"
}
sealed abstract class ArithmeticExpr(e0: Expression, e1: Expression) extends BinaryExpr(e0,e1) {
- override def ExpectedLhsType = IntClass
- override def ExpectedRhsType = IntClass
+ override val ExpectedLhsType = IntClass
+ override val ExpectedRhsType = IntClass
override val ResultType = IntClass
}
case class Plus(e0: Expression, e1: Expression) extends ArithmeticExpr(e0,e1) {
@@ -335,12 +334,12 @@ case class Mod(e0: Expression, e1: Expression) extends ArithmeticExpr(e0,e1) { override val OpName = "%"
}
sealed abstract class CompareExpr(e0: Expression, e1: Expression) extends BinaryExpr(e0,e1) {
- override def ExpectedLhsType = IntClass
- override def ExpectedRhsType = IntClass
+ override val ExpectedLhsType = IntClass
+ override val ExpectedRhsType = IntClass
}
sealed abstract class EqualityCompareExpr(e0: Expression, e1: Expression) extends CompareExpr(e0,e1) {
- override def ExpectedLhsType = throw new Exception("EqualityCompareExpr does not have a single ExpectedArgsType")
- override def ExpectedRhsType = throw new Exception("EqualityCompareExpr does not have a single ExpectedArgsType")
+ override val ExpectedLhsType = null;
+ override val ExpectedRhsType = null;
}
case class Eq(e0: Expression, e1: Expression) extends EqualityCompareExpr(e0,e1) {
override val OpName = "=="
@@ -361,23 +360,24 @@ case class Greater(e0: Expression, e1: Expression) extends CompareExpr(e0,e1) { override val OpName = ">"
}
case class LockBelow(e0: Expression, e1: Expression) extends CompareExpr(e0,e1) {
- override def ExpectedLhsType = throw new Exception("LockBelow does not have a single ExpectedArgsType")
- override def ExpectedRhsType = throw new Exception("LockBelow does not have a single ExpectedArgsType")
+ override val ExpectedLhsType = null;
+ override val ExpectedRhsType = null;
override val OpName = "<<"
}
-sealed abstract class Quantification(is: List[String], seq: Expression, e: Expression) extends Expression {
- def Quantor: String;
- def Is = is
- def Seq = seq
- def E = e
- var variables = null: List[Variable];
-}
-case class Forall(is: List[String], seq: Expression, e: Expression) extends Quantification(is, seq, e) {
- override def Quantor = "forall"
-}
-case class Exists(is: List[String], seq: Expression, e: Expression) extends Quantification(is, seq, e) {
- override def Quantor = "exists"
-}
+
+// quantifiers
+trait Quant
+object Forall extends Quant
+object Exists extends Quant
+
+sealed abstract class Quantification(q: Quant, is: List[String], e: Expression) extends Expression {
+ val Q = q;
+ val Is = is;
+ val E = e;
+ var variables = null: List[Variable]; // resolved by type checker
+}
+case class SeqQuantification(q: Quant, is: List[String], seq: Expression, e: Expression) extends Quantification(q, is, e)
+case class TypeQuantification(q: Quant, is: List[String], t: Type, e: Expression) extends Quantification(q, is, e)
// sequences
@@ -387,7 +387,11 @@ case class Range(min: Expression, max: Expression /* non-inclusive*/) extends Ex case class Append(s0: Expression, s1: Expression) extends SeqAccess(s0, s1) {
override val OpName = "++"
}
-sealed abstract case class SeqAccess(e0: Expression, e1: Expression) extends BinaryExpr(e0, e1)
+sealed abstract case class SeqAccess(e0: Expression, e1: Expression) extends BinaryExpr(e0, e1) {
+ override val ExpectedLhsType = null
+ override val ExpectedRhsType = null
+ override val ResultType = null
+}
case class Length(e: Expression) extends Expression
case class At(s: Expression, n: Expression) extends SeqAccess(s, n) {
override val OpName = ""
@@ -398,11 +402,10 @@ case class Drop(s: Expression, n: Expression) extends SeqAccess(s, n) { case class Take(s: Expression, n: Expression) extends SeqAccess(s, n) {
override val OpName = ""
}
-case class Contains(s: Expression, n: Expression) extends SeqAccess(s, n) {
+case class Contains(n: Expression, s: Expression) extends SeqAccess(n, s) {
override val OpName = "in"
}
-
// eval
case class Eval(h: EvalState, e: Expression) extends Expression
@@ -419,3 +422,157 @@ case class CallState(token: Expression, obj: Expression, id: String, args: List[ var m = null: Method;
def target() = token;
}
+
+// visitors / operations
+
+object AST {
+ /**
+ * Transforms an expression using f. f must produce expressions of the appropriate type (e.g. not replace int literal with a bool literal)
+ * Ensures that mutable fields of expressions are carried over. f must make sure that mutable fields of its value are filled in.
+ */
+ def transform(expr: Expression, f: Expression => Option[Expression]):Expression = {
+ val func = (e:Expression) => transform(e, f);
+ val x = f(expr);
+ // apply recursively
+ val result = if (x isDefined) x.get else expr match {
+ case _:Literal => expr
+ case _:ThisExpr => expr
+ case _:Result => expr
+ case _:VariableExpr => expr
+ case ma@MemberAccess(e, id) =>
+ val g = MemberAccess(func(e), id);
+ g.f = ma.f;
+ g.predicate = ma.predicate;
+ g.isPredicate = ma.isPredicate;
+ g
+ case Full | Epsilon | Star => expr
+ case Frac(perm) => Frac(func(perm))
+ case Epsilons(perm) => Epsilons(func(perm))
+ case Access(e, perm) => Access(func(e).asInstanceOf[MemberAccess], func(perm).asInstanceOf[Permission]);
+ case AccessAll(obj, perm) => AccessAll(func(obj), func(perm).asInstanceOf[Permission]);
+ case AccessSeq(s, None, perm) => AccessSeq(func(s), None, func(perm).asInstanceOf[Permission])
+ case AccessSeq(s, Some(f), perm) => AccessSeq(func(s), Some(func(f).asInstanceOf[MemberAccess]), func(perm).asInstanceOf[Permission])
+ case Credit(e, None) => Credit(func(e), None)
+ case Credit(e, Some(n)) => Credit(func(e), Some(func(n)))
+ case Holds(e) => Holds(func(e))
+ case RdHolds(e) => RdHolds(func(e))
+ case _: Assigned => expr
+ case Old(e) => Old(func(e))
+ case IfThenElse(con, then, els) => IfThenElse(func(con), func(then), func(els))
+ case Not(e) => Not(func(e))
+ case funapp@FunctionApplication(obj, id, args) =>
+ val appl = FunctionApplication(func(obj), id, args map { arg => func(arg)});
+ appl.f = funapp.f;
+ appl
+ case Unfolding(pred, e) =>
+ Unfolding(func(pred).asInstanceOf[Access], func(e))
+ case Iff(e0,e1) => Iff(func(e0), func(e1))
+ case Implies(e0,e1) => Implies(func(e0), func(e1))
+ case And(e0,e1) => And(func(e0), func(e1))
+ case Or(e0,e1) => Or(func(e0), func(e1))
+ case Eq(e0,e1) => Eq(func(e0), func(e1))
+ case Neq(e0,e1) => Neq(func(e0), func(e1))
+ case Less(e0,e1) => Less(func(e0), func(e1))
+ case AtMost(e0,e1) => AtMost(func(e0), func(e1))
+ case AtLeast(e0,e1) => AtLeast(func(e0), func(e1))
+ case Greater(e0,e1) => Greater(func(e0), func(e1))
+ case LockBelow(e0,e1) => LockBelow(func(e0), func(e1))
+ case Plus(e0,e1) => Plus(func(e0), func(e1))
+ case Minus(e0,e1) => Minus(func(e0), func(e1))
+ case Times(e0,e1) => Times(func(e0), func(e1))
+ case Div(e0,e1) => Div(func(e0), func(e1))
+ case Mod(e0,e1) => Mod(func(e0), func(e1))
+ case ExplicitSeq(es) => ExplicitSeq(es map { e => func(e) })
+ case Range(min, max)=> Range(func(min), func(max))
+ case Append(e0, e1) => Append(func(e0), func(e1))
+ case At(e0, e1) => At(func(e0), func(e1))
+ case Drop(e0, e1) => Drop(func(e0), func(e1))
+ case Take(e0, e1) => Take(func(e0), func(e1))
+ case Length(e) => Length(func(e))
+ case Contains(e0, e1) => Contains(func(e0), func(e1))
+ case qe @ SeqQuantification(q, is, seq, e) =>
+ val result = SeqQuantification(q, is, func(seq), func(e));
+ result.variables = qe.variables;
+ result;
+ case qe @ TypeQuantification(q, is, t, e) =>
+ val result = TypeQuantification(q, is, t, func(e));
+ result.variables = qe.variables;
+ result;
+ case Eval(h, e) =>
+ Eval(h match {
+ case AcquireState(obj) => AcquireState(func(obj))
+ case ReleaseState(obj) => ReleaseState(func(obj))
+ case cs @ CallState(token, obj, i, args) =>
+ val result = CallState(func(token), func(obj), i, args map { a => func(a)});
+ result.m = cs.m;
+ result;
+ }, func(e))
+ };
+
+ // preserve type
+ if (result.typ == null) result.typ = expr.typ;
+ // preserve position
+ if (result.pos == NoPosition) result.pos = expr.pos
+ result
+ }
+
+ // Applies recursively the function f first to the expression and then to its subexpressions (that is members of type RValue)
+ def visit(expr: RValue, f: RValue => Unit) {
+ f(expr);
+ expr match {
+ case _:Literal => ;
+ case _:ThisExpr => ;
+ case _:Result => ;
+ case _:VariableExpr => ;
+ case MemberAccess(e, _) =>
+ visit(e, f);
+
+ case Frac(p) => visit(p, f);
+ case Epsilons(p) => visit(p, f);
+ case Full | Epsilon | Star =>;
+ case Access(e, perm) =>
+ visit(e, f); visit(perm, f);
+ case AccessAll(obj, perm) =>
+ visit(obj, f); visit(perm, f);
+ case AccessSeq(s, _, perm) =>
+ visit(s, f); visit(perm, f);
+
+ case Credit(e, n) =>
+ visit(e, f); n match { case Some(n) => visit(n, f); case _ => }
+ case Holds(e) => visit(e, f);
+ case RdHolds(e) => visit(e, f);
+
+ case e: BinaryExpr =>
+ visit(e.E0, f); visit(e.E1, f);
+ case Range(min, max) =>
+ visit(min, f); visit(max, f);
+ case e: Assigned => e
+ case Old(e) => visit(e, f);
+ case IfThenElse(con, then, els) => visit(con, f); visit(then, f); visit(els, f);
+ case Not(e) => visit(e, f);
+ case funapp@FunctionApplication(obj, id, args) =>
+ visit(obj, f); args foreach { arg => visit(arg, f) };
+ case Unfolding(pred, e) =>
+ visit(pred, f); visit(e, f);
+
+ case SeqQuantification(_, _, seq, e) => visit(seq, f); visit(e, f);
+ case TypeQuantification(_, _, _, e) => visit(e, f);
+
+ case ExplicitSeq(es) =>
+ es foreach { e => visit(e, f) }
+ case Length(e) =>
+ visit(e, f)
+ case Eval(h, e) =>
+ h match {
+ case AcquireState(obj) => visit(obj, f);
+ case ReleaseState(obj) => visit(obj, f);
+ case CallState(token, obj, id, args) =>
+ visit(token, f); visit(obj, f); args foreach {a : Expression => visit(a, f)};
+ }
+ visit(e, f);
+ case NewRhs(_, init, lowerBounds, upperBounds) =>
+ lowerBounds foreach { e => visit(e, f)};
+ upperBounds foreach { e => visit(e, f)};
+ }
+ }
+}
\ No newline at end of file |