summaryrefslogtreecommitdiff
path: root/Jennisys/CodeGen.fs
blob: 5d7f86114edb2346495f392224f06aaac7cb9ea9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
module CodeGen

open Ast
open AstUtils
open Utils
open Printer   
open Resolver
open TypeChecker
open DafnyPrinter

/// Number of levels the validity check of a self-typed (recursive) field is
/// unrolled to when GetFieldsValidExprList builds the Valid() conjuncts.
let numLoopUnrolls : int = 2

/// Builds a validity expression for a field chain, unrolled numUnrolls levels.
/// Schematically:
///   fldExpr != null ==> (fldExpr.<validFunName> &&
///                        (fldExpr.<fldName> != null ==> (... && true)))
/// Each level follows `fldName` one step further (via Dot), and the innermost
/// level is TrueLiteral. A numUnrolls of 0 or less yields TrueLiteral.
let rec GetUnrolledFieldValidExpr fldExpr fldName validFunName numUnrolls : Expr =
  // `<= 0` rather than `= 0`: a negative count must terminate instead of
  // recursing until the stack overflows.
  if numUnrolls <= 0 then
    TrueLiteral
  else
    let notNull = BinaryNeq fldExpr (IdLiteral("null"))
    let fldValid = Dot(fldExpr, validFunName)
    let nextLevel =
      GetUnrolledFieldValidExpr (Dot(fldExpr, fldName)) fldName validFunName (numUnrolls - 1)
    BinaryImplies notNull (BinaryAnd fldValid nextLevel)

/// Unrolled validity expression for the field called `fldName`, starting the
/// chain from the bare identifier `fldName` (see GetUnrolledFieldValidExpr).
let GetFieldValidExpr fldName validFunName numUnrolls : Expr =
  let start = IdLiteral(fldName)
  GetUnrolledFieldValidExpr start fldName validFunName numUnrolls

/// Keeps only the fields whose declared type IsUserType accepts (with respect
/// to `prog`). Only these fields contribute a conjunct to the generated Valid().
let GetFieldsForValidExpr allFields prog : VarDecl list =
  let hasUserType = function Var(_, tp) -> IsUserType prog tp
  List.filter hasUserType allFields

/// Produces one validity conjunct per user-typed field of a class:
///  - a field whose type short name equals `clsName` (self-typed) checks
///    "Valid_self()", unrolled numLoopUnrolls levels;
///  - every other field checks "Valid()", unrolled a single level.
let GetFieldsValidExprList clsName allFields prog : Expr list =
  let validExprFor (Var(name, t)) =
    let isSelfTyped =
      match t with
      | Some(ty) -> clsName = (GetTypeShortName ty)
      | None -> false
    if isSelfTyped then
      GetFieldValidExpr name "Valid_self()" numLoopUnrolls
    else
      GetFieldValidExpr name "Valid()" 1
  GetFieldsForValidExpr allFields prog |> List.map validExprFor

/// Prints the Dafny text of the two validity functions for component `comp`.
///   Valid_self(): the conjunction of the component's invariants
///                 (GetInvariantsAsList, desugared with DesugarLst).
///   Valid():      this.Valid_self() && one validity expression per user-typed
///                 field (from GetFieldsValidExprList).
/// Both functions are emitted with `reads *;`.
let PrintValidFunctionCode comp prog : string =
  let idt = "    "
  // Split every expression into conjuncts with SplitIntoConjunts and print each
  // one parenthesized, joined by PrintSep with separator " &&" + newline.
  // When nothing is printed, fall back to "true".
  // List.collect keeps this linear; the previous fold re-copied the accumulator
  // with List.concat on every element.
  let __PrintInvs invs =
    invs |> List.collect SplitIntoConjunts
         |> PrintSep (" &&" + newline) (fun e -> sprintf "%s(%s)" idt (PrintExpr 0 e))
         |> fun s -> if s = "" then (idt + "true") else s
  let clsName = GetClassName comp
  let vars = GetAllFields comp
  let allInvs = GetInvariantsAsList comp |> DesugarLst
  let fieldsValid = GetFieldsValidExprList clsName vars prog

  // TODO: don't hardcode decr vars!!!
//  let decrVars = if List.choose (function Var(n,_) -> Some(n)) vars |> List.exists (fun n -> n = "next") then
//                   ["list"]
//                 else
//                   []
//  (if List.isEmpty decrVars then "" else sprintf "    decreases %s;%s" (PrintSep ", " (fun a -> a) decrVars) newline) +
  "  function Valid_self(): bool" + newline +
  "    reads *;" + newline +
  "  {" + newline + 
  (__PrintInvs allInvs) + newline +
  "  }" + newline +
  newline +
  "  function Valid(): bool" + newline +
  "    reads *;" + newline +
  "  {" + newline + 
  "    this.Valid_self() &&" + newline +
  (__PrintInvs fieldsValid) + newline +
  "  }" + newline

/// Prints a Dafny class for every component of `prog`. Each class contains:
/// the class header with its type parameters, the abstract fields (field
/// members of the class), the concrete fields (from the model), the
/// Valid_self()/Valid() functions, and the text returned by
/// `methodPrinterFunc comp m` for every member selected by
/// FilterConstructorMembers.
let PrintDafnyCodeSkeleton prog methodPrinterFunc: string =
  match prog with
  | Program(components) -> components |> List.fold (fun acc comp -> 
      match comp with  
      | Component(Class(name,typeParams,members), Model(_,_,cVars,_,_), _) ->
        let aVars = FilterFieldMembers members
        let compMethods = FilterConstructorMembers members
        // Now print it as a Dafny program
        acc + 
        (sprintf "class %s%s {" name (PrintTypeParams typeParams)) + newline +
        // the fields: original abstract fields plus concrete fields
        (PrintFields aVars 2 true) + newline +
        (PrintFields cVars 2 false) + newline +
        // generate the Valid function
        (PrintValidFunctionCode comp prog) + newline +
        // call the method printer function on all methods of this component
        (compMethods |> List.fold (fun mAcc m -> mAcc + (methodPrinterFunc comp m)) "") +
        // the end of the class
        "}" + newline + newline
      // NOTE: F# `assert` is only checked in DEBUG builds. In release builds a
      // component of any other shape is silently skipped (contributes "").
      | _ -> assert false; "") ""
  
/// Emits one Dafny line `var <name> := new <Type>;`, at indentation `indent`,
/// for every distinct NewObj constant bound in `env`. Lines come out in Set
/// order. The `heap` and `ctx` parts of the triple are not read.
/// Fails (failwithf) if a collected NewObj carries no type.
let PrintAllocNewObjects (heap,env,ctx) indent = 
  let idt = Indent indent
  // Collect the distinct NewObj values; the Set removes duplicates.
  let newObjs =
    Map.fold (fun found _ v ->
                match v with
                | NewObj(_,_) -> Set.add v found
                | _ -> found) Set.empty env
  let allocLine newObjConst =
    match newObjConst with
    | NewObj(name, Some(tp)) -> sprintf "%svar %s := new %s;%s" idt (PrintGenSym name) (PrintType tp) newline
    | _ -> failwithf "NewObj doesn't have a type: %O" newObjConst
  Set.fold (fun acc c -> acc + allocLine c) "" newObjs

/// Resolves the object reference `o` in (env,ctx) and prints its Dafny name:
/// "this" for a ThisConst, or the printed gensym name for a NewObj.
/// Any other resolution fails with "unresolved object ref: ...".
let PrintObjRefName o (env,ctx) = 
  let resolved = Resolve (env,ctx) o
  match resolved with
  | NewObj(name, _) -> PrintGenSym name
  | ThisConst(_,_) -> "this"
  | _ -> failwith ("unresolved object ref: " + o.ToString())

/// Returns `c` unchanged. If `c` is an Unresolved constant, it first logs a
/// warning that the generated output contains unresolved constants.
let CheckUnresolved c =
  let isUnresolved =
    match c with
    | Unresolved(_) -> true
    | _ -> false
  if isUnresolved then
    Logger.WarnLine "!!! There are some unresolved constants in the output file !!!"
  c

/// Emits one Dafny assignment line `<obj>.<field> := <value>;` per heap entry,
/// in Map.fold order, at indentation `indent`. Each value is resolved with
/// TryResolve and passed through CheckUnresolved before PrintConst prints it.
let PrintVarAssignments (heap,env,ctx) indent = 
  let idt = Indent indent
  let assignment (o,f) l =
    let objRef = PrintObjRefName o (env,ctx)
    let fldName = PrintVarName f
    let value = PrintConst (CheckUnresolved (TryResolve (env,ctx) l))
    sprintf "%s%s.%s := %s;" idt objRef fldName value + newline
  Map.fold (fun acc key l -> acc + assignment key l) "" heap

/// Prints Dafny statements that build the heap for `sol`. Each element of `sol`
/// is a (condition, (heap,env,ctx)) pair, and the elements are handled in order:
///  - condition TrueLiteral: emit its allocations and assignments unguarded,
///    then a blank line, then the code for the remaining elements;
///  - any other condition with more elements after it: emit
///    `if (cond) { <this solution> } else { <remaining elements> }`, with the
///    nested code indented by 2 more;
///  - any other condition on the last element: emit it unguarded (reached
///    recursively, this becomes the body of the preceding `else`).
let rec PrintHeapCreationCode sol indent =    
  let idt = Indent indent
  // Allocation + field-assignment code for one solution at indentation `ind`.
  let printSolution st ind =
    (PrintAllocNewObjects st ind) + (PrintVarAssignments st ind)
  // Structural list patterns replace `List.length rest > 0`, which walked the
  // whole tail at every recursion level (O(n^2) overall).
  match sol with
  | (c, st) :: rest when c = TrueLiteral ->
      (printSolution st indent) +
      newline + 
      (PrintHeapCreationCode rest indent)
  | [ (_, st) ] ->
      printSolution st indent
  | (c, st) :: rest ->
      idt + "if (" + (PrintExpr 0 c) + ") {" + newline +
      (printSolution st (indent+2)) +
      idt + "} else {" + newline + 
      (PrintHeapCreationCode rest (indent+2)) +
      idt + "}" + newline
  | [] -> ""

/// Prints a Dafny method for `mthd` with the given `body` text. The output has
/// `modifies this;`, one `requires` line per conjunct of the precondition, and
/// one `ensures` line per conjunct of `Valid() && post`.
/// Returns "" when `mthd` is not a Method.
let GenConstructorCode mthd body =
  match mthd with
  | Method(methodName, sign, pre, post, _) -> 
      // one "<prefix><conjunct>;" entry per conjunct, separated by newline
      let printClauses prefix expr =
        expr |> SplitIntoConjunts
             |> PrintSep newline (fun e -> prefix + (PrintExpr 0 e) + ";")
      let ensuresExpr = BinaryAnd (IdLiteral("Valid()")) post
      "  method " + methodName + (PrintSig sign) + newline +
      "    modifies this;" + newline +
      (printClauses "    requires " pre) + newline +
      (printClauses "    ensures " ensuresExpr) + newline +
      "  {" + newline + 
      body + 
      "  }" + newline
  | _ -> ""

// solutions: (comp, constructor) |--> list of (condition, (heap, env, ctx))
/// Prints the whole Dafny program for `prog`. Each (component, method) pair
/// listed by `methodsToPrintFunc prog` is printed with a body generated from
/// its entry in `solutions` (PrintHeapCreationCode at indent 4). When that pair
/// has no entry, the body is a "//unable to synthesize" placeholder. Members not
/// listed by `methodsToPrintFunc` are printed as "".
let PrintImplCode prog solutions methodsToPrintFunc =
  let methods = methodsToPrintFunc prog
  let printMethod comp mthd =
    if not (Utils.ListContains (comp,mthd) methods) then
      ""
    else
      let mthdBody =
        match Map.tryFind (comp,mthd) solutions with
        | Some(sol) -> PrintHeapCreationCode sol 4
        | None -> "    //unable to synthesize" + newline
      (GenConstructorCode mthd mthdBody) + newline
  PrintDafnyCodeSkeleton prog printMethod