// File: Source/Core/LoopUnroll.ssc (blob 4335a834b7f3051208c2dae39b5947324cdab7d2)
//-----------------------------------------------------------------------------
//
// Copyright (C) Microsoft Corporation.  All Rights Reserved.
//
//-----------------------------------------------------------------------------
using Microsoft.Contracts;
using System.Collections.Generic;
using Cci = System.Compiler;
using Bpl = Microsoft.Boogie;

namespace Microsoft.Boogie
{
  public class LoopUnroll {
    // Produces copies of the blocks reachable from 'start' in which every back edge
    // leads into a separate "level" of copies.  There are unrollMaxDepth+1 levels; at
    // the last level a cut point is replaced by a stub (see Visit).
    //
    // start:          entry block of the graph to unroll.
    // unrollMaxDepth: number of additional levels below the outermost one; must be >= 0.
    // Returns the list of newly created blocks, with the copy of 'start' first.
    public static List<Block!>! UnrollLoops(Block! start, int unrollMaxDepth)
      requires 0 <= unrollMaxDepth;
    {
      // Step 1: build the GraphNode for every reachable block and classify its
      // outgoing edges as forward edges or back edges.
      Dictionary<Block,GraphNode!> nodeForBlock = new Dictionary<Block,GraphNode!>();
      Cci.HashSet/*Block*/! onTraversalPath = new Cci.HashSet/*Block*/();
      GraphNode entry = GraphNode.ComputeGraphInfo(null, start, nodeForBlock, onTraversalPath);

      // Step 2: copy the graph, starting at the outermost level.
      LoopUnroll outermost = new LoopUnroll(nodeForBlock, unrollMaxDepth, null);
      outermost.Visit(entry);

      // Visit appends a block only after all of its successors have been appended,
      // so the entry copy is last; reversing the list puts it first.
      List<Block!>! unrolled = outermost.newBlockSeqGlobal;
      unrolled.Reverse();
      return unrolled;
    }
    
    class GraphNode {
      // The original block this node stands for.
      public readonly Block! Block;
      // The commands to use for copies of Block; ComputeGraphInfo passes in the result
      // of GetOptimizedBody(Block.Cmds).
      public readonly CmdSeq! Body;
      // Set to true by ComputeGraphInfo when an edge is found that targets this node
      // while this node's block is still in the 'beingVisited' set.
      bool isCutPoint;  // is set during ComputeGraphInfo
      public bool IsCutPoint { get { return isCutPoint; } }
      // Outgoing edges whose target was not in 'beingVisited' when the edge was recorded.
      [Rep] public readonly List<GraphNode!>! ForwardEdges = new List<GraphNode!>();
      // Outgoing edges whose target was in 'beingVisited' when the edge was recorded
      // (that is, the targets are cut points).
      [Rep] public readonly List<GraphNode!>! BackEdges = new List<GraphNode!>();
      // NOTE(review): ComputeGraphInfo sets isCutPoint on the *target* of a back edge but
      // appends to BackEdges of the *source*, so for a loop spanning more than one block
      // this equivalence does not appear to hold for either node -- confirm whether the
      // invariant is intended (and whether it is ever checked).
      invariant isCutPoint <==> BackEdges.Count != 0;
      
      // Creates a node for block 'b' with command sequence 'body'; the edge lists start
      // out empty and isCutPoint starts out false.
      GraphNode(Block! b, CmdSeq! body) {
        this.Block = b;
        this.Body = body;
      }
      
      // Tells whether 'cmd' is a PredicateCmd whose expression is a LiteralExpr with IsFalse set.
      static bool IsLiteralFalsePredicate(Cmd cmd) {
        PredicateCmd predicate = cmd as PredicateCmd;
        if (predicate == null) {
          return false;
        }
        LiteralExpr literal = predicate.Expr as LiteralExpr;
        return literal != null && literal.IsFalse;
      }

      // Cuts 'cmds' off after the first command that is a predicate on the literal false.
      // Returns a new CmdSeq holding the commands up to and including that command, or
      // 'cmds' itself (the same object) when no such command exists.  ComputeGraphInfo
      // relies on that reference identity to detect whether truncation happened.
      static CmdSeq! GetOptimizedBody(CmdSeq! cmds) {
        int seen = 0;
        foreach (Cmd cmd in cmds) {
          seen++;
          if (!IsLiteralFalsePredicate(cmd)) {
            continue;
          }
          // copy the first 'seen' commands, the false predicate being the last of them
          Cmd[] prefix = new Cmd[seen];
          for (int k = 0; k < seen; k++) {
            prefix[k] = cmds[k];
          }
          return new CmdSeq(prefix);
        }
        return cmds;
      }

      // Recursively builds the GraphNode for 'b' and for every block reachable from it,
      // and records the edge from 'from' (if any) to that node.
      //
      // from:         node whose successor 'b' is, or null for the entry block.
      // b:            block to process.
      // gd:           map from blocks to the nodes created so far; extended by this call.
      // beingVisited: blocks whose successors are currently being explored.
      // Returns the node for 'b'.
      public static GraphNode! ComputeGraphInfo(GraphNode from, Block! b, Dictionary<Block,GraphNode!>! gd, Cci.HashSet/*Block*/! beingVisited) {
        GraphNode node;
        if (!gd.TryGetValue(b, out node)) {
          // First time 'b' is reached: create its node and record the edge as a forward edge.
          CmdSeq cmds = GetOptimizedBody(b.Cmds);
          node = new GraphNode(b, cmds);
          gd.Add(b, node);
          if (from != null) {
            from.ForwardEdges.Add(node);
          }

          // GetOptimizedBody returns the very same sequence unless it truncated it after
          // a false predicate; a truncated block has no way through, so its successors
          // are not explored.
          bool bodyUnchanged = cmds == b.Cmds;
          if (bodyUnchanged) {
            beingVisited.Add(b);

            GotoCmd jump = b.TransferCmd as GotoCmd;
            if (jump != null) {
              assume jump.labelTargets != null;
              foreach (Block! target in jump.labelTargets) {
                ComputeGraphInfo(node, target, gd, beingVisited);
              }
            }

            beingVisited.Remove(b);
          }
          return node;
        }

        // 'b' already has a node, so it was reached through some edge before.
        assume from != null;
        assert node != null;
        if (!beingVisited.Contains(b)) {
          from.ForwardEdges.Add(node);
        } else {
          // the target is still being explored: it is a cut point and this is a back edge
          node.isCutPoint = true;
          from.BackEdges.Add(node);
        }
        return node;
      }
    }
    
    // Output list of all blocks created by Visit.  The constructor hands the same list
    // to 'next', so every level appends to one shared list.
    List<Block!>! newBlockSeqGlobal;
    // The unrollMaxDepth this level was constructed with; used as the "#<c>" label
    // suffix of the blocks created at this level.
    readonly int c;
    // The level constructed with depth c-1, or null when c == 0.  Visit sends the
    // targets of back edges to this level.
    readonly LoopUnroll next;
    // NOTE(review): not read or written anywhere else in the class as shown here --
    // possibly unused; confirm before relying on it.
    Dictionary<Block,int>! visitsRemaining = new Dictionary</*cut-point-*/Block,int>();
    // Maps an original block to the copy Visit created for it at this level.
    Dictionary<Block,Block!>! newBlocks = new Dictionary<Block,Block!>();
    
    // Creates the level with depth 'unrollMaxDepth' and, recursively, all levels below it
    // down to depth 0.
    //
    // gd:                only forwarded to the constructor of the next level.
    // unrollMaxDepth:    depth of this level; must be >= 0.
    // newBlockSeqGlobal: list shared by all levels for collecting new blocks, or null to
    //                    have this level allocate it.
    private LoopUnroll(Dictionary<Block,GraphNode!>! gd, int unrollMaxDepth, List<Block!> newBlockSeqGlobal)
      requires 0 <= unrollMaxDepth;
    {
      List<Block!> sharedOutput = newBlockSeqGlobal;
      if (sharedOutput == null) {
        sharedOutput = new List<Block!>();
      }
      this.c = unrollMaxDepth;
      this.newBlockSeqGlobal = sharedOutput;

      // depth 0 is the last level and leaves 'next' null
      bool hasDeeperLevel = unrollMaxDepth != 0;
      if (hasDeeperLevel) {
        next = new LoopUnroll(gd, unrollMaxDepth - 1, sharedOutput);
      }
    }
    
    // Returns this level's copy of node.Block, creating it (and, recursively, the copies
    // of its successors) on first request.  Forward edges stay within this level; back
    // edges lead to the copy made by the next level.  At the last level (next == null)
    // a cut point is not copied in full but replaced by a stub that ends the path.
    // Each new block is appended to newBlockSeqGlobal after its successors.
    Block! Visit(GraphNode! node) {
      Block source = node.Block;
      Block copy;
      if (newBlocks.TryGetValue(source, out copy)) {
        // this level already has a copy of the block
        assert copy != null;
        return copy;
      }

      assert source.TransferCmd != null;
      CmdSeq cmds;
      TransferCmd transfer;

      if (next != null || !node.IsCutPoint) {
        // Ordinary copy: same commands, successors redirected to copies.
        cmds = node.Body;
        BlockSeq targets = new BlockSeq();

        foreach (GraphNode forward in node.ForwardEdges) {
          targets.Add(Visit(forward));
        }

        // Expected to hold at the last level: a node that is copied in full there has
        // no outgoing back edges.
        assert next == null ==> node.BackEdges.Count == 0;
        foreach (GraphNode back in node.BackEdges) {
          assert next != null;  // there is at least one back edge, so this is not the last level
          targets.Add(next.Visit(back));
        }

        if (targets.Length != 0) {
          transfer = new GotoCmd(source.TransferCmd.tok, targets);
        } else {
          transfer = new ReturnCmd(source.TransferCmd.tok);
        }

      } else {
        // Stub for a cut point at the last level: keep the leading run of
        // assert/assume commands and comments (the loop invariant), then assume false.
        cmds = new CmdSeq();
        foreach (Cmd! cmd in node.Body) {
          if (!(cmd is PredicateCmd) && !(cmd is CommentCmd)) {
            break;
          }
          cmds.Add(cmd);
        }
        cmds.Add(new AssumeCmd(source.tok, Bpl.Expr.False));

        transfer = new ReturnCmd(source.TransferCmd.tok);
      }

      // the label suffix is this level's depth
      copy = new Block(source.tok, source.Label + "#" + this.c, cmds, transfer);
      newBlocks.Add(source, copy);
      newBlockSeqGlobal.Add(copy);
      return copy;
    }
  }
}