aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-04-14 18:14:39 -0400
committerGravatar Jason Gross <jgross@mit.edu>2017-04-14 18:14:39 -0400
commite3d27bfc58b601d4c3f344670ce7d04597ac5e61 (patch)
tree2a6a3c5f974bd5ac5f9efe4f4e5fb4605d84ec46 /src
parent8dc2c5c001b9b0e63ecf8324969b603694486d8b (diff)
Add for-loop combinator
Diffstat (limited to 'src')
-rw-r--r--src/Util/ForLoop.v67
-rw-r--r--src/Util/ForLoop/Tests.v7
-rw-r--r--src/Util/Notations.v10
3 files changed, 84 insertions, 0 deletions
diff --git a/src/Util/ForLoop.v b/src/Util/ForLoop.v
new file mode 100644
index 000000000..caa853a9a
--- /dev/null
+++ b/src/Util/ForLoop.v
@@ -0,0 +1,67 @@
+(** * Definition and Notations for [for (int i = i₀; i < i∞; i += Δi)] *)
+Require Import Coq.ZArith.BinInt.
+Require Import Crypto.Util.Notations.
+
+Section with_body.
+ Context {stateT : Type}
+ (body : nat -> stateT -> stateT).
+
+ Fixpoint repeat_function (count : nat) (st : stateT) : stateT
+ := match count with
+ | O => st
+ | S count' => repeat_function count' (body count' st)
+ end.
+End with_body.
+
+Local Open Scope bool_scope.
+Local Open Scope Z_scope.
+
+Definition for_loop (i0 finish : Z) (step : Z) {stateT} (initial : stateT) (body : Z -> stateT -> stateT)
+ : stateT
+ := let signed_step := (finish - i0) / step in
+ let count := Z.to_nat ((finish - i0) / signed_step) in
+ repeat_function (fun c => body (i0 + signed_step * Z.of_nat (count - c))) count initial.
+
+
+Notation "'for' i (:= i0 ; += step ; < finish ) 'updating' ( state := initial ) {{ body }}"
+ := (for_loop i0 finish step initial (fun i state => body))
+ : core_scope.
+
+Delimit Scope for_notation_scope with for_notation.
+Notation "x += y" := (x = Z.pos y) : for_notation_scope.
+Notation "x -= y" := (x = Z.neg y) : for_notation_scope.
+Notation "++ x" := (x += 1)%for_notation : for_notation_scope.
+Notation "-- x" := (x -= 1)%for_notation : for_notation_scope.
+Notation "x ++" := (x += 1)%for_notation : for_notation_scope.
+Notation "x --" := (x -= 1)%for_notation : for_notation_scope.
+Infix "<" := Z.ltb : for_notation_scope.
+Infix ">" := Z.gtb : for_notation_scope.
+Infix "<=" := Z.leb : for_notation_scope.
+Infix ">=" := Z.geb : for_notation_scope.
+
+Class class_eq {A} (x y : A) := make_class_eq : x = y.
+Global Instance class_eq_refl {A x} : @class_eq A x x := eq_refl.
+
+Class for_loop_is_good (i0 : Z) (step : Z) (finish : Z) (cmp : Z -> Z -> bool)
+ := make_good :
+ ((Z.sgn step =? Z.sgn (finish - i0))
+ && (cmp i0 finish))
+ = true.
+Hint Extern 0 (for_loop_is_good _ _ _ _) => vm_compute; reflexivity : typeclass_instances.
+
+Definition for_loop_notation {i0 : Z} {step : Z} {finish : Z} {stateT} {initial : stateT}
+ {cmp : Z -> Z -> bool}
+ step_expr finish_expr (body : Z -> stateT -> stateT)
+ {Hstep : class_eq (fun i => i = step) step_expr}
+ {Hfinish : class_eq (fun i => cmp i finish) finish_expr}
+ {Hgood : for_loop_is_good i0 step finish cmp}
+ : stateT
+ := for_loop i0 finish step initial body.
+
+Notation "'for' ( 'int' i = i0 ; finish_expr ; step_expr ) 'updating' ( state1 .. staten = initial ) {{ body }}"
+ := (@for_loop_notation
+ i0%Z _ _ _ initial%Z _
+ (fun i : Z => step_expr%for_notation)
+ (fun i : Z => finish_expr%for_notation)
+ (fun (i : Z) => (fun state1 => .. (fun staten => body) .. ))
+ eq_refl eq_refl _).
diff --git a/src/Util/ForLoop/Tests.v b/src/Util/ForLoop/Tests.v
new file mode 100644
index 000000000..7b800ddbc
--- /dev/null
+++ b/src/Util/ForLoop/Tests.v
@@ -0,0 +1,7 @@
+Require Import Coq.ZArith.BinInt.
+Require Import Crypto.Util.ForLoop.
+
+Local Open Scope Z_scope.
+
+Check (for i (:= 0; += 1; < 10) updating (v := 5) {{ v + i }}).
+Check (for (int i = 0; i < 5; i++) updating ( '(v1, v2) = (0, 0) ) {{ (v1 + i, v2 + i) }}).
diff --git a/src/Util/Notations.v b/src/Util/Notations.v
index 962117925..7570533ea 100644
--- a/src/Util/Notations.v
+++ b/src/Util/Notations.v
@@ -84,3 +84,13 @@ Reserved Notation "x ::> ( max_bitwidth = v )"
(at level 70, no associativity, format "x ::> ( max_bitwidth = v )").
Reserved Notation "r[ l ~> u ]" (format "r[ l ~> u ]").
Reserved Notation "b[ l ~> u ]" (format "b[ l ~> u ]").
+Reserved Notation "'for' i (:= i0 ; += step ; < finish ) 'updating' ( state := initial ) {{ body }}"
+ (at level 70, format "'[v ' 'for' i (:= i0 ; += step ; < finish ) 'updating' ( state := initial ) {{ '//' body ']' '//' }}").
+Reserved Notation "'for' ( 'int' i = i0 ; step_expr ; finish_expr ) 'updating' ( state1 .. staten = initial ) {{ body }}"
+ (at level 70, i at level 10, state1 binder, staten binder, format "'[v ' 'for' ( 'int' i = i0 ; step_expr ; finish_expr ) 'updating' ( state1 .. staten = initial ) {{ '//' body ']' '//' }}").
+Reserved Notation "x += y" (at level 70, no associativity).
+Reserved Notation "x -= y" (at level 70, no associativity).
+Reserved Notation "x ++" (at level 60, format "x ++").
+Reserved Notation "x --" (at level 60, format "x --").
+Reserved Notation "++ x" (at level 60, format "++ x").
+Reserved Notation "-- x" (at level 60, format "-- x").