aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-03-06 15:12:08 +0100
committerGravatar jadephilipoom <jade.philipoom@gmail.com>2018-04-03 09:00:55 -0400
commitee89ae1f28e6e23917c7cd1bd68bdba1063334c1 (patch)
tree4fe38ead6974089a3f6bcabedf68ec7411809b7f /src
parentd809687fc691725f0347ceb4d76b35c373965980 (diff)
Make Montgomery example use row-wise flatten (involves adding Nat.max, List.tl, and List.hd to the pipeline)
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v179
1 files changed, 116 insertions, 63 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 4d0af6491..96e54bebb 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -987,9 +987,9 @@ Module Rows.
Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed.
Hint Rewrite eval_app : push_eval.
- Definition extract_row (inp : cols) : cols * list Z := (map (@tl _) inp, map (hd 0) inp).
+ Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp).
- Definition max_column_size (x:cols) := fold_right Nat.max 0%nat (map (@length Z) x).
+ Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x).
Definition from_columns' n start_state : cols * rows :=
fold_right (fun _ (state : cols * rows) =>
@@ -1158,10 +1158,11 @@ Module Rows.
Local Notation fw := (fun i => weight (S i) / weight i) (only parsing).
Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z :=
- fold_left (fun (state : list Z * Z) next =>
+ fold_right (fun next (state : list Z * Z) =>
let i := length (fst state) in (* length of output accumulator tells us the index of [next] *)
- let sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in
- (fst state ++ [fst sum_carry], snd sum_carry)) (combine row1 row2) start_state.
+ dlet_nd next := next in (* makes the output correctly bind variables *)
+ dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in
+ (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)).
Definition sum_rows := sum_rows' (nil,0).
Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z :=
@@ -1189,8 +1190,8 @@ Module Rows.
/\ snd (sum_rows' start_state row1 row2)
= (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm).
Proof.
- cbv [sum_rows'].
- induction row1 as [|x1 row1]; intros;
+ cbv [sum_rows' Let_In].
+ induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *;
destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
repeat match goal with
| _ => progress autorewrite with natsimplify list
@@ -1201,6 +1202,7 @@ Module Rows.
specialize (IHrow1 (pred n) (S m)).
replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ rewrite <-fold_left_rev_right.
apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
repeat match goal with
| H : ?LHS = _ |- _ =>
@@ -1266,8 +1268,8 @@ Module Rows.
nth_default 0 (fst (sum_rows' start_state row1 row2)) i
= ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i).
Proof.
- cbv [sum_rows'].
- induction row1 as [|x1 row1]; intros;
+ cbv [sum_rows' Let_In].
+ induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *;
destruct row2 as [|x2 row2]; distr_length; [ subst n | ];
repeat match goal with
| _ => progress cbn [fold_left]
@@ -1279,6 +1281,7 @@ Module Rows.
specialize (IHrow1 (pred n) (S m)).
replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega.
rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2').
+ rewrite <-fold_left_rev_right.
apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega;
repeat match goal with
| _ => progress intros
@@ -1510,10 +1513,10 @@ Module MulConverted.
let p3_a := Associational.mul p1_a p2_a in
(* important not to use Positional.carry here; we don't want to accumulate yet *)
let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in
- fst (Columns.flatten w (Columns.from_associational w n3 p3'_a)).
+ fst (Rows.flatten w (Rows.from_associational w n3 p3'_a)).
Hint Rewrite
- @Columns.eval_from_associational
+ @Rows.eval_from_associational
@Associational.eval_carry
@Associational.eval_mul
@Positional.eval_to_associational
@@ -1531,20 +1534,18 @@ Module MulConverted.
Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2).
Proof.
cbv [mul_converted]; intros.
- rewrite Columns.flatten_mod by auto using Columns.length_from_associational.
+ rewrite Rows.flatten_mod by eauto using Rows.length_from_associational.
autorewrite with push_eval. auto using Z.mod_small.
Qed.
Hint Rewrite eval_mul_converted : push_eval.
- Hint Rewrite @Columns.length_from_associational : distr_length.
-
Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat):
length p1 = n1 -> length p2 = n2 ->
0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 ->
nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1).
Proof.
intros; cbv [mul_converted].
- erewrite Columns.flatten_partitions by (auto; distr_length).
+ rewrite Rows.flatten_partitions with (n:=n3) by (eauto using Rows.length_from_associational; omega).
autorewrite with distr_length push_eval natsimplify.
rewrite w_0; autorewrite with zsimplify.
reflexivity.
@@ -1559,7 +1560,7 @@ Module MulConverted.
nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1).
Proof.
intros; subst n3; cbv [mul_converted].
- erewrite Columns.flatten_partitions by (auto; distr_length).
+ rewrite Rows.flatten_partitions with (n:=2%nat) by (eauto using Rows.length_from_associational; omega).
autorewrite with distr_length push_eval.
rewrite Z.mod_small; omega.
Qed.
@@ -2245,6 +2246,7 @@ Module Compilers.
| primitive {t:type.primitive} (v : interp t) : ident () t
| Let_In {tx tC} : ident (tx * (tx -> tC)) tC
| Nat_succ : ident nat nat
+ | Nat_max : ident (nat * nat) nat
| Nat_mul : ident (nat * nat) nat
| Nat_add : ident (nat * nat) nat
| nil {t} : ident () (list t)
@@ -2265,6 +2267,8 @@ Module Compilers.
| List_partition {A} : ident ((A -> bool) * list A) (list A * list A)
| List_app {A} : ident (list A * list A) (list A)
| List_rev {A} : ident (list A) (list A)
+ | List_tl {A} : ident (list A) (list A)
+ | List_hd {A} : ident (A * list A) A
| List_fold_right {A B} : ident ((B * A -> A) * A * list B) A
| List_update_nth {T} : ident (nat * (T -> T) * list T) (list T)
| List_nth_default {T} : ident (T * list T * nat) T
@@ -2313,6 +2317,7 @@ Module Compilers.
| Nat_succ => Nat.succ
| Nat_add => curry2 Nat.add
| Nat_mul => curry2 Nat.mul
+ | Nat_max => curry2 Nat.max
| nil t => curry0 (@Datatypes.nil (type.interp t))
| cons t => curry2 (@Datatypes.cons (type.interp t))
| fst A B => @Datatypes.fst (type.interp A) (type.interp B)
@@ -2331,6 +2336,8 @@ Module Compilers.
| List_partition A => curry2 (@List.partition (type.interp A))
| List_app A => curry2 (@List.app (type.interp A))
| List_rev A => @List.rev (type.interp A)
+ | List_tl A => @List.tl (type.interp A)
+ | List_hd A => curry2 (@List.hd (type.interp A))
| List_fold_right A B => curry3_1 (@List.fold_right (type.interp A) (type.interp B))
| List_update_nth T => curry3 (@update_nth (type.interp T))
| List_nth_default T => curry3 (@List.nth_default (type.interp T))
@@ -2362,6 +2369,7 @@ Module Compilers.
| Nat.succ ?x => mkAppIdent Nat_succ x
| Nat.add ?x ?y => mkAppIdent Nat_add (x, y)
| Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y)
+ | Nat.max ?x ?y => mkAppIdent Nat_max (x, y)
| S ?x => mkAppIdent Nat_succ x
| @Datatypes.nil ?T
=> let rT := type.reify T in
@@ -2451,6 +2459,12 @@ Module Compilers.
| @List.rev ?A ?ls
=> let rA := type.reify A in
mkAppIdent (@ident.List_rev rA) ls
+ | @List.tl ?A ?ls
+ => let rA := type.reify A in
+ mkAppIdent (@ident.List_tl rA) ls
+ | @List.hd ?A ?x ?ls
+ => let rA := type.reify A in
+ mkAppIdent (@ident.List_hd rA) (x, ls)
| @List.fold_right ?A ?B (fun b a => ?f) ?a0 ?ls
=> let rA := type.reify A in
let rB := type.reify B in
@@ -2505,6 +2519,8 @@ Module Compilers.
Notation partition := List_partition.
Notation app := List_app.
Notation rev := List_rev.
+ Notation tl := List_tl.
+ Notation hd := List_hd.
Notation fold_right := List_fold_right.
Notation update_nth := List_update_nth.
Notation nth_default := List_nth_default.
@@ -2533,6 +2549,7 @@ Module Compilers.
Notation succ := Nat_succ.
Notation add := Nat_add.
Notation mul := Nat_mul.
+ Notation max := Nat_max.
End Nat.
Module Export Notations.
@@ -2576,6 +2593,7 @@ Module Compilers.
| Nat_succ : ident nat nat
| Nat_add : ident (nat * nat) nat
| Nat_mul : ident (nat * nat) nat
+ | Nat_max : ident (nat * nat) nat
| nil {t} : ident () (list t)
| cons {t} : ident (t * list t) (list t)
| fst {A B} : ident (A * B) A
@@ -2651,6 +2669,7 @@ Module Compilers.
| Nat_succ => Nat.succ
| Nat_add => curry2 Nat.add
| Nat_mul => curry2 Nat.mul
+ | Nat_max => curry2 Nat.max
| nil t => curry0 (@Datatypes.nil (type.interp t))
| cons t => curry2 (@Datatypes.cons (type.interp t))
| fst A B => @Datatypes.fst (type.interp A) (type.interp B)
@@ -2709,6 +2728,7 @@ Module Compilers.
| Nat.succ ?x => mkAppIdent Nat_succ x
| Nat.add ?x ?y => mkAppIdent Nat_add (x, y)
| Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y)
+ | Nat.max ?x ?y => mkAppIdent Nat_max (x, y)
| S ?x => mkAppIdent Nat_succ x
| @Datatypes.nil ?T
=> let rT := type.reify T in
@@ -2832,6 +2852,7 @@ Module Compilers.
Notation succ := Nat_succ.
Notation add := Nat_add.
Notation mul := Nat_mul.
+ Notation max := Nat_max.
End Nat.
Module Export Notations.
@@ -2897,6 +2918,8 @@ Module Compilers.
=> AppIdent ident.Nat_add
| for_reification.ident.Nat_mul
=> AppIdent ident.Nat_mul
+ | for_reification.ident.Nat_max
+ => AppIdent ident.Nat_max
| for_reification.ident.nil t
=> AppIdent ident.nil
| for_reification.ident.cons t
@@ -3068,6 +3091,28 @@ Module Compilers.
(fun x l' rev_l' => List_app rev_l' [x])
ls) in
let v := app_and_maybe_cancel v in exact v)
+ | for_reification.ident.List_tl A
+ => ltac:(
+ let v := reify
+ (@expr var)
+ (fun ls
+ => list_rect
+ (fun _ => list (type.interp A))
+ nil
+ (fun _ l' _ => l')
+ ls) in
+ let v := app_and_maybe_cancel v in exact v)
+ | for_reification.ident.List_hd A
+ => ltac:(
+ let v := reify
+ (@expr var)
+ (fun (xls : type.interp A * list (type.interp A))
+ => list_rect
+ (fun _ => type.interp A)
+ (fst xls)
+ (fun x _ _ => x)
+ (snd xls)) in
+ let v := app_and_maybe_cancel v in exact v)
| for_reification.ident.List_fold_right A B
=> ltac:(
let v := reify
@@ -3692,6 +3737,7 @@ Module Compilers.
| ident.Nat_succ as idc
| ident.Nat_add as idc
| ident.Nat_mul as idc
+ | ident.Nat_max as idc
| ident.pred as idc
| ident.Z_shiftr _ as idc
| ident.Z_shiftl _ as idc
@@ -3883,6 +3929,7 @@ Module Compilers.
@@ (ident.fst @@ (Var xyk)))
| ident.Nat_add as idc
| ident.Nat_mul as idc
+ | ident.Nat_max as idc
=> λ (xyk :
(* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * type.nat * (type.nat -> R))%ctype) ,
(ident.snd @@ (Var xyk))
@@ -4585,6 +4632,7 @@ Module Compilers.
| ident.Nat_succ
| ident.Nat_add
| ident.Nat_mul
+ | ident.Nat_max
| ident.bool_rect _
| ident.nat_rect _
| ident.pred
@@ -5219,7 +5267,15 @@ Module Compilers.
let result := ident.interp idc (x, y, z, a) in
inr (inr (fst result), inr (snd result))
| inr (inr (inr (inr x, y), z), a)
- => default_interp (ident.Z.add_with_get_carry_concrete x) (inr (inr (y, z), a))
+ => let default := default_interp (ident.Z.add_with_get_carry_concrete x) (inr (inr (y, z), a)) in
+ match (z, a) with
+ | (inr xx, inl e)
+ | (inl e, inr xx)
+ => if Z.eqb xx 0
+ then inr (inl e, inr 0%Z)
+ else default
+ | _ => default
+ end
| _ => default_interp idc x_y_z_a
end
| ident.Z_sub_get_borrow as idc
@@ -5300,9 +5356,10 @@ Module Compilers.
end
| ident.Nat_add as idc
| ident.Nat_mul as idc
+ | ident.Nat_max as idc
+ | ident.Z_pow as idc
| ident.Z_eqb as idc
| ident.Z_leb as idc
- | ident.Z_pow as idc
=> fun (x_y : data (_ * _) * expr (_ * _) + (_ + type.interp _) * (_ + type.interp _))
=> match x_y return _ + type.interp _ with
| inr (inr x, inr y) => inr (ident.interp idc (x, y))
@@ -8079,18 +8136,19 @@ Module Montgomery256.
Set Printing Width 100000.
Print montred256.
- (*montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
+ (*
+montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
expr_let x0 := (uint128)(x₁ >> 128) in
expr_let x1 := ((uint128)(x₁) & 340282366920938463463374607431768211455) in
expr_let x2 := 79228162514264337593543950337 *₂₅₆ x0 in
expr_let x3 := ((uint128)(x2) & 340282366920938463463374607431768211455) in
expr_let x4 := 340282366841710300986003757985643364352 *₂₅₆ x1 in
expr_let x5 := ((uint128)(x4) & 340282366920938463463374607431768211455) in
- expr_let x6 := (uint256)(x5 << 128) in
- expr_let x7 := (uint256)(x3 << 128) in
- expr_let x8 := 79228162514264337593543950337 *₂₅₆ x1 in
- expr_let x9 := ADD_256 (x6, x7) in
- expr_let x10 := ADD_256 (x8, x9₁) in
+ expr_let x6 := (uint256)(x3 << 128) in
+ expr_let x7 := 79228162514264337593543950337 *₂₅₆ x1 in
+ expr_let x8 := ADDC_256 (0, x6, x7) in
+ expr_let x9 := (uint256)(x5 << 128) in
+ expr_let x10 := ADDC_256 (0, x9, x8₁) in
expr_let x11 := (uint128)(x10₁ >> 128) in
expr_let x12 := ((uint128)(x10₁) & 340282366920938463463374607431768211455) in
expr_let x13 := 79228162514264337593543950335 *₂₅₆ x11 in
@@ -8099,23 +8157,21 @@ Module Montgomery256.
expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ x12 in
expr_let x17 := (uint128)(x16 >> 128) in
expr_let x18 := ((uint128)(x16) & 340282366920938463463374607431768211455) in
- expr_let x19 := (uint256)(x18 << 128) in
- expr_let x20 := (uint256)(x15 << 128) in
- expr_let x21 := 79228162514264337593543950335 *₂₅₆ x12 in
- expr_let x22 := ADD_256 (x19, x20) in
- expr_let x23 := ADD_256 (x21, x22₁) in
- expr_let x24 := x23₂ +₁₂₈ x22₂ in
- expr_let x25 := 340282366841710300967557013911933812736 *₂₅₆ x11 in
- expr_let x26 := ADD_256 (x14, x25) in
- expr_let x27 := ADD_256 (x17, x26₁) in
- expr_let x28 := ADD_256 (x24, x27₁) in
- expr_let x29 := ADD_256 (x₁, x23₁) in
- expr_let x30 := ADDC_256 (x29₂, x₂, x28₁) in
- expr_let x31 := SELC (x30₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
- expr_let x32 := Z.cast uint256 @@ (fst @@ SUB_256 (x30₁, x31)) in
- ADDM (x32, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
+ expr_let x19 := (uint256)(x15 << 128) in
+ expr_let x20 := 79228162514264337593543950335 *₂₅₆ x12 in
+ expr_let x21 := ADDC_256 (0, x19, x20) in
+ expr_let x22 := 340282366841710300967557013911933812736 *₂₅₆ x11 in
+ expr_let x23 := ADDC_256 (x21₂, x22, x17) in
+ expr_let x24 := (uint256)(x18 << 128) in
+ expr_let x25 := ADDC_256 (0, x24, x21₁) in
+ expr_let x26 := ADDC_256 (x25₂, x14, x23₁) in
+ expr_let x27 := ADD_256 (x₁, x25₁) in
+ expr_let x28 := ADDC_256 (x27₂, x₂, x26₁) in
+ expr_let x29 := SELC (x28₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
+ expr_let x30 := Z.cast uint256 @@ (fst @@ SUB_256 (x28₁, x29)) in
+ ADDM (x30, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
: Expr (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)
-*)
+ *)
End Montgomery256.
(* Extra-specialized ad-hoc pretty-printing *)
@@ -8219,17 +8275,18 @@ Local Open Scope expr_scope.
Print Montgomery256.montred256.
(*
+<<<<<<< HEAD
c.ShiftR($x0, $x_lo, 128);
c.Lower128($x1, $x_lo);
c.Mul128x128($x2, Lower128{RegPinv}, $x0);
c.Lower128($x3, $x2);
c.Mul128x128($x4, RegPinv >> 128, $x1);
c.Lower128($x5, $x4);
-c.ShiftL($x6, $x5, 128);
-c.ShiftL($x7, $x3, 128);
-c.Mul128x128($x8, Lower128{RegPinv}, $x1);
-c.Add256($x9, $x6, $x7);
-c.Add256($x10, $x8, $x9_lo);
+c.ShiftL($x6, $x3, 128);
+c.Mul128x128($x7, Lower128{RegPinv}, $x1);
+c.Addc($x8, $x6, $x7);
+c.ShiftL($x9, $x5, 128);
+c.Addc($x10, $x9, $x8_lo);
c.ShiftR($x11, $x10_lo, 128);
c.Lower128($x12, $x10_lo);
c.Mul128x128($x13, Lower128{RegMod}, $x11);
@@ -8238,22 +8295,18 @@ c.Lower128($x15, $x13);
c.Mul128x128($x16, RegMod << 128, $x12);
c.ShiftR($x17, $x16, 128);
c.Lower128($x18, $x16);
-c.ShiftL($x19, $x18, 128);
-c.ShiftL($x20, $x15, 128);
-c.Mul128x128($x21, Lower128{RegMod}, $x12);
-c.Add256($x22, $x19, $x20);
-c.Add256($x23, $x21, $x22_lo);
-c.Add64($x24, $x23_hi, $x22_hi);
-c.Mul128x128($x25, RegMod << 128, $x11);
-c.Add256($x26, $x14, $x25);
-c.Add256($x27, $x17, $x26_lo);
-c.Add256($x28, $x24, $x27_lo);
-c.Add256($x29, $x_lo, $x23_lo);
-c.Addc($x30, $x_hi, $x28_lo);
-c.Selc($x31,RegZero, RegMod);
-c.Sub($x32, $x30_lo, $x31);
-c.AddM($ret, $x32, RegZero, RegMod);
- : Expr
- (type.type_primitive type.Z * type.type_primitive type.Z ->
- type.type_primitive type.Z)
+c.ShiftL($x19, $x15, 128);
+c.Mul128x128($x20, Lower128{RegMod}, $x12);
+c.Addc($x21, $x19, $x20);
+c.Mul128x128($x22, RegMod << 128, $x11);
+c.Addc($x23, $x22, $x17);
+c.ShiftL($x24, $x18, 128);
+c.Addc($x25, $x24, $x21_lo);
+c.Addc($x26, $x14, $x23_lo);
+c.Add256($x27, $x_lo, $x25_lo);
+c.Addc($x28, $x_hi, $x26_lo);
+c.Selc($x29,RegZero, RegMod);
+c.Sub($x30, $x28_lo, $x29);
+c.AddM($ret, $x30, RegZero, RegMod);
+ : Expr (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)
*)