summaryrefslogtreecommitdiff
path: root/Test/dafny4/BinarySearch.dfy
blob: b11fc0d1deffdd4f1f546eb99735e841e86ae61a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// RUN: %dafny /compile:0 /dprint:"%t.dprint" "%s" > "%t"
// RUN: %diff "%s.expect" "%t"

// Binary search using standard integers: looks for `key` in the sorted array `a`.
// Yields an index r with a[r] == key when one is found, and -1 once the search range is empty.
method BinarySearch(a: array<int>, key: int) returns (r: int)
  requires a != null
  requires forall i,j :: 0 <= i < j < a.Length ==> a[i] <= a[j]  // a is sorted in non-decreasing order
  ensures 0 <= r ==> r < a.Length && a[r] == key  // a non-negative result is an index holding key
  ensures r < 0 ==> key !in a[..]  // a negative result means key occurs nowhere in a
{
  var lo, hi := 0, a.Length;  // the half-open range [lo, hi) is what remains to be searched
  while lo < hi
    invariant 0 <= lo <= hi <= a.Length
    invariant key !in a[..lo] && key !in a[hi..]  // key is in neither discarded part
  {
    var mid := (lo + hi) / 2;  // int is unbounded in Dafny, so the sum cannot overflow
    if key < a[mid] {
      hi := mid;  // sortedness rules out index mid and everything after it
    } else if a[mid] < key {
      lo := mid + 1;  // sortedness rules out index mid and everything before it
    } else {
      return mid;  // neither smaller nor larger, so a[mid] == key
    }
  }
  return -1;  // range is empty, so by the invariant key is absent
}

// Binary search using bounded integers

newtype int32 = x | -0x8000_0000 <= x < 0x8000_0000  // integers in [-2^31, 2^31); results outside this range are verification errors

method BinarySearchInt32_bad(a: array<int>, key: int) returns (r: int32)  // flawed variant: the midpoint below is marked as an expected error
  requires a != null && a.Length < 0x8000_0000  // the length must be representable as an int32
  requires forall i,j :: 0 <= i < j < a.Length ==> a[i] <= a[j]  // a is sorted in non-decreasing order
  ensures 0 <= r ==> r < int32(a.Length) && a[r] == key  // a non-negative result is an index holding key
  ensures r < 0 ==> key !in a[..]  // a negative result means key occurs nowhere in a
{
  var lo, hi := 0, int32(a.Length);  // conversion is in range because of the a.Length < 0x8000_0000 precondition
  while lo < hi
    invariant 0 <= lo <= hi <= int32(a.Length)
    invariant key !in a[..lo] && key !in a[hi..]  // key is in neither discarded part
  {
    var mid := (lo + hi) / 2;  // error: possible overflow -- lo + hi may exceed the int32 upper bound before the division
    if key < a[mid] {
      hi := mid;  // sortedness rules out index mid and everything after it
    } else if a[mid] < key {
      lo := mid + 1;  // sortedness rules out index mid and everything before it
    } else {
      return mid;  // neither smaller nor larger, so a[mid] == key
    }
  }
  return -1;  // range is empty, so by the invariant key is absent
}

// Binary search over the sorted array `a` with int32 indices, written so that
// no intermediate result leaves the int32 range. It differs from
// BinarySearchInt32_bad only in how the midpoint is computed.
// Yields an index r with a[r] == key when one is found, and -1 once the
// search range is empty.
method BinarySearchInt32_good(a: array<int>, key: int) returns (r: int32)
  requires a != null && a.Length < 0x8000_0000  // the length must be representable as an int32
  requires forall i,j :: 0 <= i < j < a.Length ==> a[i] <= a[j]  // a is sorted in non-decreasing order
  ensures 0 <= r ==> r < int32(a.Length) && a[r] == key  // a non-negative result is an index holding key
  ensures r < 0 ==> key !in a[..]  // a negative result means key occurs nowhere in a
{
  var lo, hi := 0, int32(a.Length);  // conversion is in range because of the a.Length < 0x8000_0000 precondition
  while lo < hi
    invariant 0 <= lo <= hi <= int32(a.Length)
    invariant key !in a[..lo] && key !in a[hi..]  // key is in neither discarded part
  {
    // Since 0 <= lo < hi here, hi - lo and lo + (hi - lo) / 2 both stay
    // between 0 and hi, so neither can violate the int32 constraint.
    var mid := lo + (hi - lo) / 2;  // this is how to do it
    if key < a[mid] {
      hi := mid;  // sortedness rules out index mid and everything after it
    } else if a[mid] < key {
      lo := mid + 1;  // sortedness rules out index mid and everything before it
    } else {
      return mid;  // neither smaller nor larger, so a[mid] == key
    }
  }
  return -1;  // range is empty, so by the invariant key is absent
}