diff options
author | Jade Philipoom <jadep@google.com> | 2018-03-06 15:12:08 +0100 |
---|---|---|
committer | jadephilipoom <jade.philipoom@gmail.com> | 2018-04-03 09:00:55 -0400 |
commit | ee89ae1f28e6e23917c7cd1bd68bdba1063334c1 (patch) | |
tree | 4fe38ead6974089a3f6bcabedf68ec7411809b7f /src | |
parent | d809687fc691725f0347ceb4d76b35c373965980 (diff) |
Make Montgomery example use row-wise flatten (involves adding Nat.max, List.tl, and List.hd to the pipeline)
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 179 |
1 files changed, 116 insertions, 63 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 4d0af6491..96e54bebb 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -987,9 +987,9 @@ Module Rows. Proof. cbv [eval]; autorewrite with push_map push_sum; reflexivity. Qed. Hint Rewrite eval_app : push_eval. - Definition extract_row (inp : cols) : cols * list Z := (map (@tl _) inp, map (hd 0) inp). + Definition extract_row (inp : cols) : cols * list Z := (map (fun c => tl c) inp, map (fun c => hd 0 c) inp). - Definition max_column_size (x:cols) := fold_right Nat.max 0%nat (map (@length Z) x). + Definition max_column_size (x:cols) := fold_right (fun a b => Nat.max a b) 0%nat (map (fun c => length c) x). Definition from_columns' n start_state : cols * rows := fold_right (fun _ (state : cols * rows) => @@ -1158,10 +1158,11 @@ Module Rows. Local Notation fw := (fun i => weight (S i) / weight i) (only parsing). Definition sum_rows' start_state (row1 row2 : list Z) : list Z * Z := - fold_left (fun (state : list Z * Z) next => + fold_right (fun next (state : list Z * Z) => let i := length (fst state) in (* length of output accumulator tells us the index of [next] *) - let sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in - (fst state ++ [fst sum_carry], snd sum_carry)) (combine row1 row2) start_state. + dlet_nd next := next in (* makes the output correctly bind variables *) + dlet_nd sum_carry := Z.add_with_get_carry_full (fw i) (snd state) (fst next) (snd next) in + (fst state ++ [fst sum_carry], snd sum_carry)) start_state (rev (combine row1 row2)). Definition sum_rows := sum_rows' (nil,0). Definition flatten' (start_state : list Z * Z) (inp : rows) : list Z * Z := @@ -1189,8 +1190,8 @@ Module Rows. /\ snd (sum_rows' start_state row1 row2) = (eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) / (weight nm). Proof. - cbv [sum_rows']. - induction row1 as [|x1 row1]; intros; + cbv [sum_rows' Let_In]. + induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *; destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; repeat match goal with | _ => progress autorewrite with natsimplify list @@ -1201,6 +1202,7 @@ Module Rows. specialize (IHrow1 (pred n) (S m)). replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). + rewrite <-fold_left_rev_right. apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; repeat match goal with | H : ?LHS = _ |- _ => @@ -1266,8 +1268,8 @@ Module Rows. nth_default 0 (fst (sum_rows' start_state row1 row2)) i = ((eval nm (row1' ++ row1) + eval nm (row2' ++ row2)) mod (weight (S i))) / (weight i). Proof. - cbv [sum_rows']. - induction row1 as [|x1 row1]; intros; + cbv [sum_rows' Let_In]. + induction row1 as [|x1 row1]; intros; rewrite fold_left_rev_right in *; destruct row2 as [|x2 row2]; distr_length; [ subst n | ]; repeat match goal with | _ => progress cbn [fold_left] @@ -1279,6 +1281,7 @@ Module Rows. specialize (IHrow1 (pred n) (S m)). replace (pred n + S m)%nat with (n + m)%nat in IHrow1 by omega. rewrite (app_cons_app_app _ row1'), (app_cons_app_app _ row2'). + rewrite <-fold_left_rev_right. apply IHrow1; clear IHrow1; autorewrite with cancel_pair distr_length; try omega; repeat match goal with | _ => progress intros @@ -1510,10 +1513,10 @@ Module MulConverted. let p3_a := Associational.mul p1_a p2_a in (* important not to use Positional.carry here; we don't want to accumulate yet *) let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in - fst (Columns.flatten w (Columns.from_associational w n3 p3'_a)). + fst (Rows.flatten w (Rows.from_associational w n3 p3'_a)). Hint Rewrite - @Columns.eval_from_associational + @Rows.eval_from_associational @Associational.eval_carry @Associational.eval_mul @Positional.eval_to_associational @@ -1531,20 +1534,18 @@ Module MulConverted. Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). Proof. cbv [mul_converted]; intros. - rewrite Columns.flatten_mod by auto using Columns.length_from_associational. + rewrite Rows.flatten_mod by eauto using Rows.length_from_associational. autorewrite with push_eval. auto using Z.mod_small. Qed. Hint Rewrite eval_mul_converted : push_eval. - Hint Rewrite @Columns.length_from_associational : distr_length. - Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). Proof. intros; cbv [mul_converted]. - erewrite Columns.flatten_partitions by (auto; distr_length). + rewrite Rows.flatten_partitions with (n:=n3) by (eauto using Rows.length_from_associational; omega). autorewrite with distr_length push_eval natsimplify. rewrite w_0; autorewrite with zsimplify. reflexivity. @@ -1559,7 +1560,7 @@ Module MulConverted. nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). Proof. intros; subst n3; cbv [mul_converted]. - erewrite Columns.flatten_partitions by (auto; distr_length). + rewrite Rows.flatten_partitions with (n:=2%nat) by (eauto using Rows.length_from_associational; omega). autorewrite with distr_length push_eval. rewrite Z.mod_small; omega. Qed. @@ -2245,6 +2246,7 @@ Module Compilers. | primitive {t:type.primitive} (v : interp t) : ident () t | Let_In {tx tC} : ident (tx * (tx -> tC)) tC | Nat_succ : ident nat nat + | Nat_max : ident (nat * nat) nat | Nat_mul : ident (nat * nat) nat | Nat_add : ident (nat * nat) nat | nil {t} : ident () (list t) @@ -2265,6 +2267,8 @@ Module Compilers. | List_partition {A} : ident ((A -> bool) * list A) (list A * list A) | List_app {A} : ident (list A * list A) (list A) | List_rev {A} : ident (list A) (list A) + | List_tl {A} : ident (list A) (list A) + | List_hd {A} : ident (A * list A) A | List_fold_right {A B} : ident ((B * A -> A) * A * list B) A | List_update_nth {T} : ident (nat * (T -> T) * list T) (list T) | List_nth_default {T} : ident (T * list T * nat) T @@ -2313,6 +2317,7 @@ Module Compilers. | Nat_succ => Nat.succ | Nat_add => curry2 Nat.add | Nat_mul => curry2 Nat.mul + | Nat_max => curry2 Nat.max | nil t => curry0 (@Datatypes.nil (type.interp t)) | cons t => curry2 (@Datatypes.cons (type.interp t)) | fst A B => @Datatypes.fst (type.interp A) (type.interp B) @@ -2331,6 +2336,8 @@ Module Compilers. | List_partition A => curry2 (@List.partition (type.interp A)) | List_app A => curry2 (@List.app (type.interp A)) | List_rev A => @List.rev (type.interp A) + | List_tl A => @List.tl (type.interp A) + | List_hd A => curry2 (@List.hd (type.interp A)) | List_fold_right A B => curry3_1 (@List.fold_right (type.interp A) (type.interp B)) | List_update_nth T => curry3 (@update_nth (type.interp T)) | List_nth_default T => curry3 (@List.nth_default (type.interp T)) @@ -2362,6 +2369,7 @@ Module Compilers. | Nat.succ ?x => mkAppIdent Nat_succ x | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) + | Nat.max ?x ?y => mkAppIdent Nat_max (x, y) | S ?x => mkAppIdent Nat_succ x | @Datatypes.nil ?T => let rT := type.reify T in @@ -2451,6 +2459,12 @@ Module Compilers. | @List.rev ?A ?ls => let rA := type.reify A in mkAppIdent (@ident.List_rev rA) ls + | @List.tl ?A ?ls + => let rA := type.reify A in + mkAppIdent (@ident.List_tl rA) ls + | @List.hd ?A ?x ?ls + => let rA := type.reify A in + mkAppIdent (@ident.List_hd rA) (x, ls) | @List.fold_right ?A ?B (fun b a => ?f) ?a0 ?ls => let rA := type.reify A in let rB := type.reify B in @@ -2505,6 +2519,8 @@ Module Compilers. Notation partition := List_partition. Notation app := List_app. Notation rev := List_rev. + Notation tl := List_tl. + Notation hd := List_hd. Notation fold_right := List_fold_right. Notation update_nth := List_update_nth. Notation nth_default := List_nth_default. @@ -2533,6 +2549,7 @@ Module Compilers. Notation succ := Nat_succ. Notation add := Nat_add. Notation mul := Nat_mul. + Notation max := Nat_max. End Nat. Module Export Notations. @@ -2576,6 +2593,7 @@ Module Compilers. | Nat_succ : ident nat nat | Nat_add : ident (nat * nat) nat | Nat_mul : ident (nat * nat) nat + | Nat_max : ident (nat * nat) nat | nil {t} : ident () (list t) | cons {t} : ident (t * list t) (list t) | fst {A B} : ident (A * B) A @@ -2651,6 +2669,7 @@ Module Compilers. | Nat_succ => Nat.succ | Nat_add => curry2 Nat.add | Nat_mul => curry2 Nat.mul + | Nat_max => curry2 Nat.max | nil t => curry0 (@Datatypes.nil (type.interp t)) | cons t => curry2 (@Datatypes.cons (type.interp t)) | fst A B => @Datatypes.fst (type.interp A) (type.interp B) @@ -2709,6 +2728,7 @@ Module Compilers. | Nat.succ ?x => mkAppIdent Nat_succ x | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) + | Nat.max ?x ?y => mkAppIdent Nat_max (x, y) | S ?x => mkAppIdent Nat_succ x | @Datatypes.nil ?T => let rT := type.reify T in @@ -2832,6 +2852,7 @@ Module Compilers. Notation succ := Nat_succ. Notation add := Nat_add. Notation mul := Nat_mul. + Notation max := Nat_max. End Nat. Module Export Notations. @@ -2897,6 +2918,8 @@ Module Compilers. => AppIdent ident.Nat_add | for_reification.ident.Nat_mul => AppIdent ident.Nat_mul + | for_reification.ident.Nat_max + => AppIdent ident.Nat_max | for_reification.ident.nil t => AppIdent ident.nil | for_reification.ident.cons t @@ -3068,6 +3091,28 @@ Module Compilers. (fun x l' rev_l' => List_app rev_l' [x]) ls) in let v := app_and_maybe_cancel v in exact v) + | for_reification.ident.List_tl A + => ltac:( + let v := reify + (@expr var) + (fun ls + => list_rect + (fun _ => list (type.interp A)) + nil + (fun _ l' _ => l') + ls) in + let v := app_and_maybe_cancel v in exact v) + | for_reification.ident.List_hd A + => ltac:( + let v := reify + (@expr var) + (fun (xls : type.interp A * list (type.interp A)) + => list_rect + (fun _ => type.interp A) + (fst xls) + (fun x _ _ => x) + (snd xls)) in + let v := app_and_maybe_cancel v in exact v) | for_reification.ident.List_fold_right A B => ltac:( let v := reify @@ -3692,6 +3737,7 @@ Module Compilers. | ident.Nat_succ as idc | ident.Nat_add as idc | ident.Nat_mul as idc + | ident.Nat_max as idc | ident.pred as idc | ident.Z_shiftr _ as idc | ident.Z_shiftl _ as idc @@ -3883,6 +3929,7 @@ Module Compilers. @@ (ident.fst @@ (Var xyk))) | ident.Nat_add as idc | ident.Nat_mul as idc + | ident.Nat_max as idc => λ (xyk : (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * type.nat * (type.nat -> R))%ctype) , (ident.snd @@ (Var xyk)) @@ -4585,6 +4632,7 @@ Module Compilers. | ident.Nat_succ | ident.Nat_add | ident.Nat_mul + | ident.Nat_max | ident.bool_rect _ | ident.nat_rect _ | ident.pred @@ -5219,7 +5267,15 @@ Module Compilers. let result := ident.interp idc (x, y, z, a) in inr (inr (fst result), inr (snd result)) | inr (inr (inr (inr x, y), z), a) - => default_interp (ident.Z.add_with_get_carry_concrete x) (inr (inr (y, z), a)) + => let default := default_interp (ident.Z.add_with_get_carry_concrete x) (inr (inr (y, z), a)) in + match (z, a) with + | (inr xx, inl e) + | (inl e, inr xx) + => if Z.eqb xx 0 + then inr (inl e, inr 0%Z) + else default + | _ => default + end | _ => default_interp idc x_y_z_a end | ident.Z_sub_get_borrow as idc @@ -5300,9 +5356,10 @@ Module Compilers. end | ident.Nat_add as idc | ident.Nat_mul as idc + | ident.Nat_max as idc + | ident.Z_pow as idc | ident.Z_eqb as idc | ident.Z_leb as idc - | ident.Z_pow as idc => fun (x_y : data (_ * _) * expr (_ * _) + (_ + type.interp _) * (_ + type.interp _)) => match x_y return _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) @@ -8079,18 +8136,19 @@ Module Montgomery256. Set Printing Width 100000. Print montred256. - (*montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, + (* +montred256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, expr_let x0 := (uint128)(x₁ >> 128) in expr_let x1 := ((uint128)(x₁) & 340282366920938463463374607431768211455) in expr_let x2 := 79228162514264337593543950337 *₂₅₆ x0 in expr_let x3 := ((uint128)(x2) & 340282366920938463463374607431768211455) in expr_let x4 := 340282366841710300986003757985643364352 *₂₅₆ x1 in expr_let x5 := ((uint128)(x4) & 340282366920938463463374607431768211455) in - expr_let x6 := (uint256)(x5 << 128) in - expr_let x7 := (uint256)(x3 << 128) in - expr_let x8 := 79228162514264337593543950337 *₂₅₆ x1 in - expr_let x9 := ADD_256 (x6, x7) in - expr_let x10 := ADD_256 (x8, x9₁) in + expr_let x6 := (uint256)(x3 << 128) in + expr_let x7 := 79228162514264337593543950337 *₂₅₆ x1 in + expr_let x8 := ADDC_256 (0, x6, x7) in + expr_let x9 := (uint256)(x5 << 128) in + expr_let x10 := ADDC_256 (0, x9, x8₁) in expr_let x11 := (uint128)(x10₁ >> 128) in expr_let x12 := ((uint128)(x10₁) & 340282366920938463463374607431768211455) in expr_let x13 := 79228162514264337593543950335 *₂₅₆ x11 in @@ -8099,23 +8157,21 @@ Module Montgomery256. expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ x12 in expr_let x17 := (uint128)(x16 >> 128) in expr_let x18 := ((uint128)(x16) & 340282366920938463463374607431768211455) in - expr_let x19 := (uint256)(x18 << 128) in - expr_let x20 := (uint256)(x15 << 128) in - expr_let x21 := 79228162514264337593543950335 *₂₅₆ x12 in - expr_let x22 := ADD_256 (x19, x20) in - expr_let x23 := ADD_256 (x21, x22₁) in - expr_let x24 := x23₂ +₁₂₈ x22₂ in - expr_let x25 := 340282366841710300967557013911933812736 *₂₅₆ x11 in - expr_let x26 := ADD_256 (x14, x25) in - expr_let x27 := ADD_256 (x17, x26₁) in - expr_let x28 := ADD_256 (x24, x27₁) in - expr_let x29 := ADD_256 (x₁, x23₁) in - expr_let x30 := ADDC_256 (x29₂, x₂, x28₁) in - expr_let x31 := SELC (x30₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in - expr_let x32 := Z.cast uint256 @@ (fst @@ SUB_256 (x30₁, x31)) in - ADDM (x32, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) + expr_let x19 := (uint256)(x15 << 128) in + expr_let x20 := 79228162514264337593543950335 *₂₅₆ x12 in + expr_let x21 := ADDC_256 (0, x19, x20) in + expr_let x22 := 340282366841710300967557013911933812736 *₂₅₆ x11 in + expr_let x23 := ADDC_256 (x21₂, x22, x17) in + expr_let x24 := (uint256)(x18 << 128) in + expr_let x25 := ADDC_256 (0, x24, x21₁) in + expr_let x26 := ADDC_256 (x25₂, x14, x23₁) in + expr_let x27 := ADD_256 (x₁, x25₁) in + expr_let x28 := ADDC_256 (x27₂, x₂, x26₁) in + expr_let x29 := SELC (x28₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let x30 := Z.cast uint256 @@ (fst @@ SUB_256 (x28₁, x29)) in + ADDM (x30, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) : Expr (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z) -*) + *) End Montgomery256. (* Extra-specialized ad-hoc pretty-printing *) @@ -8219,17 +8275,18 @@ Local Open Scope expr_scope. Print Montgomery256.montred256. (* +<<<<<<< HEAD c.ShiftR($x0, $x_lo, 128); c.Lower128($x1, $x_lo); c.Mul128x128($x2, Lower128{RegPinv}, $x0); c.Lower128($x3, $x2); c.Mul128x128($x4, RegPinv >> 128, $x1); c.Lower128($x5, $x4); -c.ShiftL($x6, $x5, 128); -c.ShiftL($x7, $x3, 128); -c.Mul128x128($x8, Lower128{RegPinv}, $x1); -c.Add256($x9, $x6, $x7); -c.Add256($x10, $x8, $x9_lo); +c.ShiftL($x6, $x3, 128); +c.Mul128x128($x7, Lower128{RegPinv}, $x1); +c.Addc($x8, $x6, $x7); +c.ShiftL($x9, $x5, 128); +c.Addc($x10, $x9, $x8_lo); c.ShiftR($x11, $x10_lo, 128); c.Lower128($x12, $x10_lo); c.Mul128x128($x13, Lower128{RegMod}, $x11); @@ -8238,22 +8295,18 @@ c.Lower128($x15, $x13); c.Mul128x128($x16, RegMod << 128, $x12); c.ShiftR($x17, $x16, 128); c.Lower128($x18, $x16); -c.ShiftL($x19, $x18, 128); -c.ShiftL($x20, $x15, 128); -c.Mul128x128($x21, Lower128{RegMod}, $x12); -c.Add256($x22, $x19, $x20); -c.Add256($x23, $x21, $x22_lo); -c.Add64($x24, $x23_hi, $x22_hi); -c.Mul128x128($x25, RegMod << 128, $x11); -c.Add256($x26, $x14, $x25); -c.Add256($x27, $x17, $x26_lo); -c.Add256($x28, $x24, $x27_lo); -c.Add256($x29, $x_lo, $x23_lo); -c.Addc($x30, $x_hi, $x28_lo); -c.Selc($x31,RegZero, RegMod); -c.Sub($x32, $x30_lo, $x31); -c.AddM($ret, $x32, RegZero, RegMod); - : Expr - (type.type_primitive type.Z * type.type_primitive type.Z -> - type.type_primitive type.Z) +c.ShiftL($x19, $x15, 128); +c.Mul128x128($x20, Lower128{RegMod}, $x12); +c.Addc($x21, $x19, $x20); +c.Mul128x128($x22, RegMod << 128, $x11); +c.Addc($x23, $x22, $x17); +c.ShiftL($x24, $x18, 128); +c.Addc($x25, $x24, $x21_lo); +c.Addc($x26, $x14, $x23_lo); +c.Add256($x27, $x_lo, $x25_lo); +c.Addc($x28, $x_hi, $x26_lo); +c.Selc($x29,RegZero, RegMod); +c.Sub($x30, $x28_lo, $x29); +c.AddM($ret, $x30, RegZero, RegMod); + : Expr (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z) *) |