From e4651284bb30a664ef4ec190dce4b01b02822f53 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 30 May 2018 21:33:17 +0200 Subject: Proved pre-fancy barrett reduction correct (except 1 admit for bounds that are correct but for which bounds relaxation loses necessary information) and add explanatory comments. --- src/Experiments/SimplyTypedArithmetic.v | 808 +++++++++++++++++++++----------- 1 file changed, 527 insertions(+), 281 deletions(-) (limited to 'src/Experiments/SimplyTypedArithmetic.v') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index bb2255547..c277a23ab 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -8836,9 +8836,9 @@ Module PreFancy. match e with | Scalar _ s => Scalar s | LetInAppIdentZ _ t r idc x f => - of_straightline_ident idc t r x (fun y => of_straightline (f y)) + of_straightline_ident idc t r[0~>wordmax-1]%zrange x (fun y => of_straightline (f y)) | LetInAppIdentZZ _ t r idc x f => - of_straightline_ident idc t r x (fun y => of_straightline (f y)) + of_straightline_ident idc t (r[0~>wordmax-1], r[0~>1])%zrange x (fun y => of_straightline (f y)) end. End with_var. @@ -8988,14 +8988,16 @@ Module PreFancy. (word_range, flag_range) (ident.Z.add_get_carry_concrete wordmax) | ok_addc : - forall c x y : scalar type.Z, + forall (c x y : scalar type.Z) outr, in_flag_range (get_range c) -> in_word_range (get_range x) -> in_word_range (get_range y) -> + lower outr = 0 -> + (0 <= upper (get_range c) + upper (get_range x) + upper (get_range y) <= upper outr \/ outr = word_range) -> ok_ident _ (type.prod type.Z type.Z) (Pair (Pair c x) y) - (word_range, flag_range) + (outr, flag_range) (ident.Z.add_with_get_carry_concrete wordmax) | ok_sub : forall x y : scalar type.Z, @@ -9017,11 +9019,15 @@ Module PreFancy. (word_range, flag_range) (ident.Z.sub_with_get_borrow_concrete wordmax) | ok_rshi : - forall (x : scalar (type.prod type.Z type.Z)) n, + forall (x : scalar (type.prod type.Z type.Z)) n outr, in_word_range (fst (get_range x)) -> in_word_range (snd (get_range x)) -> - 0 <= n < 2 * log2wordmax -> - ok_ident (type.prod type.Z type.Z) type.Z x word_range (ident.Z.rshi_concrete wordmax n) + (* note : using [outr] rather than [word_range] allows for cases where the result has been put in a smaller word size. *) + lower outr = 0 -> + 0 <= n -> + ((0 <= (upper (snd (get_range x)) + upper (fst (get_range x)) * wordmax) / 2^n <= upper outr) + \/ outr = word_range) -> + ok_ident (type.prod type.Z type.Z) type.Z x outr (ident.Z.rshi_concrete wordmax n) | ok_selc : forall (x : scalar (type.prod type.Z type.Z)) (y z : scalar type.Z), in_flag_range (snd (get_range x)) -> @@ -9078,14 +9084,16 @@ Module PreFancy. | ok_of_scalar : forall t s, ok_scalar s -> @ok_expr t (Scalar s) | ok_letin_z : forall s d r idc x f, ok_ident _ type.Z x r idc -> + (r <=? word_range)%zrange = true -> ok_scalar x -> - (forall y, has_range r y -> ok_expr (f y)) -> + (forall y, has_range (t:=type.Z) r y -> ok_expr (f y)) -> ok_expr (@LetInAppIdentZ _ _ s d r idc x f) | ok_letin_zz : forall s d r idc x f, - ok_ident _ (type.prod type.Z type.Z) x r idc -> + ok_ident _ (type.prod type.Z type.Z) x (r, flag_range) idc -> + (r <=? word_range)%zrange = true -> ok_scalar x -> - (forall y, has_range r y -> ok_expr (f y)) -> - ok_expr (@LetInAppIdentZZ _ _ s d r idc x f) + (forall y, has_range (t:=type.Z * type.Z) (r, flag_range) y -> ok_expr (f y)) -> + ok_expr (@LetInAppIdentZZ _ _ s d (r, flag_range) idc x f) . Ltac invert H := @@ -9205,14 +9213,25 @@ Module PreFancy. auto. Qed. - Lemma has_word_range_rshi n x y : + Lemma has_range_rshi r n x y : 0 <= n -> - @has_range type.Z word_range (Z.rshi wordmax x y n). + 0 <= x -> + 0 <= y -> + lower r = 0 -> + (0 <= (y + x * wordmax) / 2^n <= upper r \/ r = word_range) -> + @has_range type.Z r (Z.rshi wordmax x y n). Proof. pose proof wordmax_gt_2. - intros; rewrite Z.rshi_correct by omega. - match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. - cbn [has_range lower upper]; lia. + intros. cbv [has_range]. + rewrite Z.rshi_correct by omega. + match goal with |- context [?x mod ?m] => + pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + split; [lia|]. + intuition. + { destruct (Z_lt_dec (upper r) wordmax); [ | lia]. + rewrite Z.mod_small by (split; Z.zero_bounds; omega). + omega. } + { subst r. cbn [upper]. omega. } Qed. Lemma in_word_range_spec r : @@ -9417,7 +9436,15 @@ Module PreFancy. autorewrite with to_div_mod. match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. rewrite Z.div_between_0_if by omega. - split; break_match; lia. } + match goal with H : _ \/ _ |- _ => destruct H; subst end. + { split; break_match; try lia. + destruct (Z_lt_dec (upper outr) wordmax). + { match goal with |- _ <= ?y mod _ <= ?u => + assert (y <= u) by nia end. + rewrite Z.mod_small by omega. omega. } + { match goal with|- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + omega. } } + { split; break_match; cbn; lia. } } { autorewrite with to_div_mod. match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. @@ -9429,7 +9456,12 @@ Module PreFancy. match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. rewrite Z.div_sub_small by omega. split; break_match; lia. } - { apply has_word_range_rshi; omega. } + { apply has_range_rshi; try nia; [ ]. + match goal with H : context [upper ?ra + upper ?rb * wordmax] |- context [?a + ?b * wordmax] => + assert ((a + b * wordmax) / 2^n <= (upper ra + upper rb * wordmax) / 2^n) by (apply Z.div_le_mono; Z.zero_bounds; nia) + end. + match goal with H : _ \/ ?P |- _ \/ ?P => destruct H; [left|tauto] end. + split; Z.zero_bounds; nia. } { rewrite Z.zselect_correct. break_match; omega. } { cbn [interp_scalar fst snd get_range] in *. rewrite Z.zselect_correct. break_match; omega. } @@ -9529,20 +9561,34 @@ Module PreFancy. rewrite !Z.leb_refl; reflexivity. } } Qed. - Lemma of_straightline_ident_mul_correct t x y g : + Lemma halved_mul_range x y : + ok_scalar (Pair x y) -> + is_halved x -> + is_halved y -> + 0 <= interp_scalar x * interp_scalar y < wordmax. + Proof. + intro Hok; invert Hok. intros. + repeat match goal with H : _ |- _ => apply is_halved_has_range in H; [|assumption] end. + cbv [has_range lower upper] in *. + pose proof half_bits_squared. nia. + Qed. + + Lemma of_straightline_ident_mul_correct r t x y g : is_halved x -> is_halved y -> ok_scalar (Pair x y) -> + (word_range <=? r)%zrange = true -> @has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) -> - @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) = + @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t r (Pair x y) g) = @interp interp_cast _ (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))). Proof. - intros Hx Hy Hok ?; invert Hok; cbn [interp_scalar of_straightline_ident]. + intros Hx Hy Hok ? ?; invert Hok; cbn [interp_scalar of_straightline_ident]; destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ]; rewrite ?Pxlow, ?Pxhigh; destruct (is_halved_cases y Hy ltac:(assumption)) as [ [? [Pylow [Pyhigh Pyi] ] ] | [? [Pylow [Pyhigh Pyi] ] ] ]; rewrite ?Pylow, ?Pyhigh; - cbn; rewrite Pxi, Pyi; rewrite interp_cast_noop by auto; reflexivity. + cbn; rewrite Pxi, Pyi; assert (0 <= interp_scalar x * interp_scalar y < wordmax) by (auto using halved_mul_range); + rewrite interp_cast_noop by (cbv [is_tighter_than_bool] in *; cbn [has_range upper lower] in *; rewrite andb_true_iff in *; intuition; Z.ltb_to_lt; lia); reflexivity. Qed. Lemma has_word_range_mod_small x: @@ -9629,25 +9675,48 @@ Module PreFancy. replace y with (fst y, snd y) by (destruct y; reflexivity) end; autorewrite with to_div_mod; solve [repeat (f_equal; try ring)]. - Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g : + Fixpoint is_tighter_than_bool_range_type t : range_type t -> range_type t -> bool := + match t with + | type.type_primitive type.Z => (fun r1 r2 => (r1 <=? r2)%zrange) + | type.prod a b => fun r1 r2 => + (is_tighter_than_bool_range_type a (fst r1) (fst r2)) + && (is_tighter_than_bool_range_type b (snd r1) (snd r2)) + | _ => fun _ _ => true + end. + + Definition range_ok {t} : range_type t -> Prop := + match t with + | type.type_primitive type.Z => fun r => in_word_range r + | type.prod type.Z type.Z => fun r => in_word_range (fst r) /\ snd r = flag_range + | _ => fun _ => False + end. + + Lemma of_straightline_ident_correct s d t x r r' (idc : ident.ident s d) g : ok_ident s d x r idc -> + range_ok r' -> + is_tighter_than_bool_range_type d r r' = true -> ok_scalar x -> - @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r x g) = + @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r' x g) = @interp interp_cast _ (g (ident.interp idc (interp_scalar x))). Proof. intros. pose proof wordmax_half_bits_pos. pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)). - induction H; try solve [auto using of_straightline_ident_mul_correct]; + match goal with H : ok_ident _ _ _ _ _ |- _ => induction H end; + try solve [auto using of_straightline_ident_mul_correct]; + cbv [is_tighter_than_bool_range_type is_tighter_than_bool range_ok] in *; cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell] in *; intros; rewrite ?Z.eqb_refl; cbn [andb]; try match goal with |- context [invert_shift] => break_match end; cbn [interp interp_ident]; try destruct_scalar; - repeat match goal with + repeat match goal with | _ => progress (cbn [fst snd interp_scalar] in * ) | _ => progress break_match; [ ] | _ => progress autorewrite with zsimplify_fast + | _ => progress Z.ltb_to_lt + | H : _ /\ _ |- _ => destruct H + | _ => rewrite andb_true_iff in * | _ => rewrite interp_cast_noop with (r:=flag_range) in * by (apply has_flag_range_cc_m'; auto; extract_ok_scalar) | _ => rewrite interp_cast_noop with (r:=flag_range) in * @@ -9662,10 +9731,10 @@ Module PreFancy. by (eapply has_range_loosen; [apply has_range_interp_scalar; extract_ok_scalar|]; assumption) - | _ => rewrite interp_cast_noop by assumption - | _ => rewrite interp_cast2_noop by assumption + | _ => rewrite interp_cast_noop by (cbn [has_range fst snd] in *; split; lia) + | _ => rewrite interp_cast2_noop by (cbn [has_range fst snd] in *; split; lia) | _ => reflexivity - end. + end. Qed. Lemma of_straightline_correct {t} (e : expr t) : @@ -9676,12 +9745,15 @@ Module PreFancy. induction 1; cbn [of_straightline]; intros; repeat match goal with | _ => progress cbn [Straightline.expr.interp] - | _ => rewrite of_straightline_ident_correct by auto - | _ => rewrite interp_cast_noop by auto using ident_interp_has_range - | _ => rewrite interp_cast2_noop by auto using ident_interp_has_range - | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by auto using ident_interp_has_range + | _ => erewrite of_straightline_ident_correct + by (cbv [range_ok is_tighter_than_bool_range_type]; + eauto using in_word_range_word_range; + try apply andb_true_iff; auto) + | _ => rewrite interp_cast_noop by eauto using has_range_loosen, ident_interp_has_range + | _ => rewrite interp_cast2_noop by eauto using has_range_loosen, ident_interp_has_range + | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by eauto using has_range_loosen, ident_interp_has_range | _ => reflexivity - end. + end. Qed. End proofs. @@ -9726,10 +9798,12 @@ Module PreFancy. Proof. induction 1; intros; cbn [of_straightline interp]. { apply replace_interp_cast_scalar; auto. } - { rewrite !of_straightline_ident_correct by auto. + { erewrite !of_straightline_ident_correct by (eauto; cbv [range_ok]; apply in_word_range_word_range). rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. eauto using ident_interp_has_range. } - { rewrite !of_straightline_ident_correct by auto. + { erewrite !of_straightline_ident_correct by + (eauto; try solve [cbv [range_ok]; split; auto using in_word_range_word_range]; + cbv [is_tighter_than_bool_range_type]; apply andb_true_iff; split; auto). rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. eauto using ident_interp_has_range. } Qed. @@ -10248,6 +10322,34 @@ Module Prod. (Instr (SUB 0) lo (hi, y) (Instr ADDM lo (lo, RegZero, RegMod) (Ret lo))))))). + + (* Barrett reduction -- this is only the "reduce" part, excluding the initial multiplication. *) + Definition MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : @Fancy.expr register := + let q1Bottom256 := scratchp1 in + let muSelect := scratchp2 in + let q2 := scratchp3 in + let q2High := scratchp4 in + let q2High2 := scratchp5 in + let q3 := scratchp1 in + let r2 := scratchp2 in + let r2High := scratchp3 in + let maybeM := scratchp1 in + Instr SELM muSelect (RegMuLow, RegZero) + (Instr (RSHI 255) q1Bottom256 (xHigh, x) + (Mul256x256 q2 q2High q1Bottom256 RegMuLow scratchp5 + (Instr (RSHI 255) q2High2 (RegZero, xHigh) + (Instr (ADD 0) q2High (q2High, q1Bottom256) + (Instr (ADDC 0) q2High2 (q2High2, RegZero) + (Instr (ADD 0) q2High (q2High, muSelect) + (Instr (ADDC 0) q2High2 (q2High2, RegZero) + (Instr (RSHI 1) q3 (q2High2, q2High) + (Mul256x256 r2 r2High RegMod q3 scratchp4 + (Instr (SUB 0) muSelect (x, r2) + (Instr (SUBC 0) xHigh (xHigh, r2High) + (Instr SELL maybeM (RegMod, RegZero) + (Instr (SUB 0) q3 (muSelect, maybeM) + (Instr ADDM x (q3, RegZero, RegMod) + (Ret x))))))))))))))). End Prod. Module ProdEquiv. @@ -10434,6 +10536,74 @@ Module ProdEquiv. subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. lia. } Qed. + + Lemma mulll_comm rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma mulhh_comm rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma mullh_mulhl rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma add_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx. + Proof. + intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm. + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. + Qed. + + Lemma addc_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx. + Proof. + intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)). + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. + Qed. + + Ltac push_value_unused := + repeat match goal with + | |- ~ In _ _ => cbn; intuition; congruence + | _ => apply ProdEquiv.value_unused_overwrite + | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ] + | _ => apply ProdEquiv.value_unused_ret; congruence + end. + + Ltac remember_single_result := + match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] => + let x := fresh "x" in + let y := fresh "y" in + let Heqx := fresh "Heqx" in + remember (Fancy.spec i args cc) as x eqn:Heqx; + remember (x mod w) as y + end. + Ltac step_both_sides := + match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 => + rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2); + cbn - [Fancy.interp Fancy.spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + remember_single_result; + lazymatch goal with + | |- context [Fancy.spec i _ _] => + let Heqa1 := fresh in + let Heqa2 := fresh in + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1; + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2; + cbn in Heqa1; cbn in Heqa2; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence; + let a1 := match type of Heqa1 with _ = ?a1 => a1 end in + let a2 := match type of Heqa2 with _ = ?a2 => a2 end in + (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2) + | _ => idtac + end + end. End ProdEquiv. (* Lemmas to help prove that a fancy and prefancy expression have the @@ -11073,7 +11243,7 @@ Ltac solve_rbarrett_red_nocache := solve_rop_nocache BarrettReduction.rbarrett_r Module Barrett256. - Definition M := (2^256-2^224+2^192+2^96-1). + Definition M := Eval lazy in (2^256-2^224+2^192+2^96-1). Definition machine_wordsize := 256. Derive barrett_red256 @@ -11081,6 +11251,213 @@ Module Barrett256. As barrett_red256_correct. Proof. Time solve_rbarrett_red machine_wordsize. Time Qed. + Definition muLow := Eval lazy in (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize). + Definition barrett_red256_prefancy' := PreFancy.of_Expr machine_wordsize [M; muLow] barrett_red256. + + Derive barrett_red256_prefancy + SuchThat (barrett_red256_prefancy = barrett_red256_prefancy' type.interp) + As barrett_red256_prefancy_eq. + Proof. lazy - [type.interp]; reflexivity. Qed. + + Lemma barrett_reduce_correct_specialized : + forall (xLow xHigh : Z), + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. + apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *; + try omega; + try match goal with + | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega + end; lazy; try split; congruence. + Qed. + + (* Note: If this is not factored out, then for some reason Qed takes forever in barrett_red256_correct_full. *) + Lemma barrett_red256_correct_proj2 : + forall xy : type.interp (type.prod type.Z type.Z), + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + xy = true -> + expr.Interp (@ident.interp) barrett_red256 xy = app_curried (t:=type.arrow (type.prod type.Z type.Z) type.Z) (fun xy => BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 (fst xy) (snd xy)) xy. + Proof. intros; destruct (barrett_red256_correct xy); assumption. Qed. + Lemma barrett_red256_correct_proj2' : + forall x y : Z, + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + (x, y) = true -> + expr.Interp (@ident.interp) barrett_red256 (x, y) = BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 x y. + Proof. intros; rewrite barrett_red256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed. + + Lemma barrett_red256_correct_full : + forall (xLow xHigh : Z), + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + expr.interp (@ident.interp) (barrett_red256 type.interp) (xLow, xHigh) = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. + rewrite <-barrett_reduce_correct_specialized by assumption. + rewrite <-barrett_red256_correct_proj2'. + { cbv [expr.Interp type.uncurried_domain type.uncurry type.final_codomain]. + reflexivity. } + { cbn. rewrite !andb_true_iff. cbv [machine_wordsize M] in *. + cbn in *. repeat split; apply Z.leb_le; omega. } + Qed. + + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step' := + match goal with + | _ => assumption + | |- _ <= _ <= _ \/ @eq zrange _ _ => + right; lazy; try split; congruence + | |- _ <= _ <= _ \/ @eq zrange _ _ => + left; lazy; try split; congruence + | |- context [PreFancy.ok_ident] => constructor + | |- context [PreFancy.ok_scalar] => constructor; try omega + | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] + | |- context [PreFancy.is_halved] => constructor + | |- context [PreFancy.in_word_range] => lazy; reflexivity + | |- context [PreFancy.in_flag_range] => lazy; reflexivity + | |- context [PreFancy.get_range] => + cbn [PreFancy.get_range lower upper fst snd ZRange.map] + | x : type.interp (type.prod _ _) |- _ => destruct x + | |- (_ <=? _)%zrange = true => + match goal with + | |- context [PreFancy.get_range_var] => + cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower machine_wordsize M muLow] in *; cbn; + apply andb_true_iff; split; apply Z.leb_le + | _ => lazy + end; omega || reflexivity + | |- @eq zrange _ _ => lazy; reflexivity + | |- _ <= _ => cbv [machine_wordsize]; omega + | |- _ <= _ <= _ => cbv [machine_wordsize]; omega + end; intros. + + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step := + match goal with + | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' + end; intros; cbn [Nat.max]. + + Lemma barrett_red256_prefancy_correct : + forall xLow xHigh dummy_arrow, + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + @PreFancy.interp machine_wordsize (PreFancy.interp_cast_mod machine_wordsize) type.Z (barrett_red256_prefancy (xLow, xHigh) dummy_arrow) = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. rewrite barrett_red256_prefancy_eq; cbv [barrett_red256_prefancy']. + erewrite PreFancy.of_Expr_correct. + { apply barrett_red256_correct_full; try assumption; reflexivity. } + { reflexivity. } + { lazy; reflexivity. } + { lazy; reflexivity. } + { repeat constructor. } + { cbv [In M muLow]; intros; intuition; subst; cbv; congruence. } + { let r := (eval compute in (2 ^ machine_wordsize)) in + replace (2^machine_wordsize) with r in * by reflexivity. + cbv [M] in *. + assert (lower r[0~>1] = 0) by reflexivity. + repeat (ok_expr_step; [ ]). + ok_expr_step. { exact admit. (* TODO: the actual bounds on the second argument are lower, but relax_zrange steps have lost that information. *) } + repeat (ok_expr_step; [ ]). + ok_expr_step. + lazy; congruence. + constructor. + constructor. } + { lazy. omega. } + Qed. + + Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) := + Fancy.of_Expr 3%positive + (fun z => if z =? muLow then Some RegMuLow else if z =? M then Some RegMod else if z =? 0 then Some RegZero else None) + [M; muLow] + barrett_red256 + (xLow, xHigh)%positive + (fun _ _ => tt) + error. + Derive barrett_red256_fancy + SuchThat (forall xLow xHigh RegMuLow RegMod RegZero, + barrett_red256_fancy xLow xHigh RegMuLow RegMod RegZero = barrett_red256_fancy' xLow xHigh RegMuLow RegMod RegZero) + As barrett_red256_fancy_eq. + Proof. + intros. + lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB Fancy.SUBC + Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU + Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. + reflexivity. + Qed. + + Import Fancy.Registers. + + Definition barrett_red256_alloc' xLow xHigh RegMuLow := + fun errorP errorR => + Fancy.allocate register + positive Pos.eqb + errorR + (barrett_red256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) + [r2;r3;r4;r5;r6;r7;r8;r9;r10;r5;r11;r6;r12;r13;r14;r15;r16;r17;r18;r19;r20;r21;r22;r23;r24;r25;r26;r27;r28;r29] + (fun n => if n =? 1000 then xLow + else if n =? 1001 then xHigh + else if n =? 1002 then RegMuLow + else if n =? 1003 then RegMod + else if n =? 1004 then RegZero + else errorR). + Derive barrett_red256_alloc + SuchThat (barrett_red256_alloc = barrett_red256_alloc') + As barrett_red256_alloc_eq. + Proof. + intros. + cbv [barrett_red256_alloc' barrett_red256_fancy]. + cbn. subst barrett_red256_alloc. + reflexivity. + Qed. + + Set Printing Depth 1000. + Import ProdEquiv. + + Local Ltac solve_bounds := + match goal with + | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega + | _ => assumption + end. + + Lemma barrett_red256_alloc_equivalent errorP errorR cc_start_state start_context : + forall x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg, + NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> + 0 <= start_context x < 2^machine_wordsize -> + 0 <= start_context xHigh < 2^machine_wordsize -> + 0 <= start_context RegMuLow < 2^machine_wordsize -> + ProdEquiv.interp256 (barrett_red256_alloc r0 r1 r30 errorP errorR) cc_start_state + (fun r => if reg_eqb r r0 + then start_context x + else if reg_eqb r r1 + then start_context xHigh + else if reg_eqb r r30 + then start_context RegMuLow + else start_context r) + = ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context. + Proof. + intros. + let r := eval compute in (2^machine_wordsize) in + replace (2^machine_wordsize) with r in * by reflexivity. + cbv [Prod.MulMod barrett_red256_alloc]. + + (* Extract proofs that no registers are equal to each other *) + repeat match goal with + | H : NoDup _ |- _ => inversion H; subst; clear H + | H : ~ In _ _ |- _ => cbv [In] in H + | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H + | H : ~ False |- _ => clear H + end. + + step_both_sides. + + (* TODO: To prove equivalence between these two, we need to either relocate the RSHI instructions so they're in the same places or use instruction commutativity to push them down. *) + + Admitted. + Import PrintingNotations. Set Printing Width 1000. Open Scope expr_scope. @@ -11120,10 +11497,6 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) *) - Definition muLow := (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize). - Definition barrett_red256_prefancy := - Eval lazy in (PreFancy.of_Expr machine_wordsize [M;muLow] barrett_red256). - Import PreFancy. Import PreFancy.Notations. Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). @@ -11779,6 +12152,11 @@ Module Montgomery256. Ltac ok_expr_step' := match goal with | _ => assumption + | |- _ <= _ <= _ \/ @eq zrange _ _ => + right; lazy; try split; congruence + | |- _ <= _ <= _ \/ @eq zrange _ _ => + left; lazy; try split; congruence + | |- lower r[0~>_]%zrange = 0 => reflexivity | |- context [PreFancy.ok_ident] => constructor | |- context [PreFancy.ok_scalar] => constructor; try omega | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] @@ -11873,76 +12251,7 @@ Module Montgomery256. reflexivity. Qed. - (* TODO : move *) - Lemma mulll_comm rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma mulhh_comm rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma mullh_mulhl rd x y cont cc ctx : - ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx. - Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. - - Lemma add_comm rd x y cont cc ctx : - 0 <= ctx x < 2^256 -> - 0 <= ctx y < 2^256 -> - ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx. - Proof. - intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm. - rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. - Qed. - - Lemma addc_comm rd x y cont cc ctx : - 0 <= ctx x < 2^256 -> - 0 <= ctx y < 2^256 -> - ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx. - Proof. - intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)). - rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. - Qed. - - SearchAbout reg_eqb. - - Local Ltac push_value_unused := - repeat match goal with - | |- ~ In _ _ => cbn; intuition; congruence - | _ => apply ProdEquiv.value_unused_overwrite - | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ] - | _ => apply ProdEquiv.value_unused_ret; congruence - end. - - Local Ltac remember_single_result := - match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] => - let x := fresh "x" in - let y := fresh "y" in - let Heqx := fresh "Heqx" in - remember (Fancy.spec i args cc) as x eqn:Heqx; - remember (x mod w) as y - end. - Local Ltac step_both_sides := - match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 => - rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2); - cbn - [Fancy.interp Fancy.spec]; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; - remember_single_result; - lazymatch goal with - | |- context [Fancy.spec i _ _] => - let Heqa1 := fresh in - let Heqa2 := fresh in - remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1; - remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2; - cbn in Heqa1; cbn in Heqa2; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence; - repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence; - let a1 := match type of Heqa1 with _ = ?a1 => a1 end in - let a2 := match type of Heqa2 with _ = ?a2 => a2 end in - (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2) - | _ => idtac - end - end. + Import ProdEquiv. Local Ltac solve_bounds := match goal with @@ -12170,143 +12479,19 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * End Montgomery256. -(* Extra-specialized ad-hoc pretty-printing *) -Module FancyPrintingNotations. - Export ident. - Open Scope expr_scope. - Open Scope ctype_scope. - Notation "'RegMod'" := - (AppIdent - (primitive 115792089210356248762697446949407573530086143415290314195533631308867097853951) - TT) (only printing, at level 9) : expr_scope. - Notation "'RegMuLow'" := - (AppIdent - (primitive 26959946667150639793205513449348445388433292963828203772348655992835) - TT) (only printing, at level 9) : expr_scope. - Notation "'RegPinv'" := - (AppIdent - (primitive 115792089210356248768974548684794254293921932838497980611635986753331132366849) - TT) (only printing, at level 9) : expr_scope. - Notation "'RegZero'" := - (AppIdent - (primitive 0) - TT) (only printing, at level 9) : expr_scope. - Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : expr_scope. - Notation "'c.Lower(RegMod)'" := - (AppIdent - (primitive 79228162514264337593543950335) - TT) (only printing, at level 9) : expr_scope. - Notation "'c.Upper(RegMod)'" := - (AppIdent - (primitive 340282366841710300967557013911933812736) - TT) (only printing, at level 9) : expr_scope. - Notation "'c.Lower(RegMuLow)'" := - (AppIdent - (primitive 340282366841710300930663525764514709507) - TT) (only printing, at level 9) : expr_scope. - Notation "'c.Upper(RegMuLow)'" := - (AppIdent - (primitive 79228162514264337589248983038) - TT) (only printing, at level 9) : expr_scope. - Notation "'c.Lower(RegPinv)'" := - (AppIdent - (primitive 79228162514264337593543950337) - TT) (only printing, at level 9) : expr_scope. - Notation "'c.Upper(RegPinv)'" := - (AppIdent - (primitive 340282366841710300986003757985643364352) - TT) (only printing, at level 9) : expr_scope. - Notation "'uint256'" - := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : ctype_scope. - Notation "'uint128'" - := (r[0 ~> 340282366920938463463374607431768211455]%zrange) : ctype_scope. - Notation "$ n" := (Var n) (at level 10, format "$ n") : expr_scope. - Notation "$ n" := (Z.cast _ @@ Var n) (at level 10, format "$ n") : expr_scope. - Notation "$ n '_lo'" := (fst @@ (Var n))%expr (at level 10, format "$ n _lo") : expr_scope. - Notation "$ n '_hi'" := (snd @@ (Var n))%expr (at level 10, format "$ n _hi") : expr_scope. - Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Var n)))%expr (at level 10, format "$ n _lo") : expr_scope. - Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Var n)))%expr (at level 10, format "$ n _hi") : expr_scope. - Notation "$ n '_lo'" := (fst @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _lo") : expr_scope. - Notation "$ n '_hi'" := (snd @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _hi") : expr_scope. - Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _lo") : expr_scope. - Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _hi") : expr_scope. - Notation "'c.Mul128x128(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast uint256 @@ (Z.mul @@ (x, y)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' f" := - (expr_let n := Z.cast _ @@ (Z.shiftl count @@ (Z.cast uint256 @@ (Z.mul @@ (x, y)))) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : expr_scope. - Notation "'c.Add256(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add256(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Add128(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Add64(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast uint128 @@ (Z.add @@ (x, y)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Addc256(' '$' n ',' x ',' y ',' z ');' f" := - (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc256(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. - Notation "'c.Addc128(' '$' n ',' x ',' y ',' z ');' f" := - (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc128(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. - Notation "'c.Selc(' '$' n ',' x ',' y ',' z ');' f" := - (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (x , y, z)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. - Notation "'c.Selm(' '$' n ',' x ',' y ',' z ');' f" := - (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0 ~>1]) @@ ((Z.cc_m_concrete _) @@ x), y, z)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selm(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. - Notation "'c.Sell(' '$' n ',' x ',' y ',' z ');' f" := - (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0~>1]) @@ (Z.land 1 @@ _ x), y, z)) in - f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Sell(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope. - Notation "'c.Sub(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in - f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope. - Notation "'c.Sub(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)) in - f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope. - Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" := - (Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope. - Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast _ @@ (Z.shiftr y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Rshi(' '$' n ',' x ',' y ',' m ');' f" := - (expr_let n := Z.cast _ @@ (Z.rshi_concrete $R m @@ (x, y)) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Rshi(' '$' n ',' x ',' y ',' m ');' ']' '//' f") : expr_scope. - Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" := - (expr_let n := Z.cast _ @@ (Z.shiftl y @@ (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x))) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. - Notation "'c.Lower128(' '$' n ',' x ');' f" := - (expr_let n := Z.cast _ @@ (Z.land 340282366920938463463374607431768211455 @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$' n ',' x ');' ']' '//' f") : expr_scope. - Notation "'c.Lower(' x ')'" - := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x)) - (at level 10, only printing, format "c.Lower( x )") - : expr_scope. - Notation "'c.Upper(' x ')'" - := (Z.cast uint128 @@ (Z.shiftr 128 @@ x)) - (at level 10, only printing, format "c.Upper( x )") - : expr_scope. - Notation "( v << count )" - := (Z.cast _ @@ (Z.shiftl count @@ v)%expr) - (format "( v << count )") - : expr_scope. - (* - Notation "( x >> count )" - := (Z.cast _ @@ (Z.shiftr count @@ x)%expr) - (format "( x >> count )") - : expr_scope. - Notation "x * y" - := (Z.cast uint256 @@ (Z.mul @@ (x, y))) - : expr_scope. - *) -End FancyPrintingNotations. - - Local Notation "i rd x y ; cont" := (Fancy.Instr i rd (x, y) cont) (at level 40, cont at level 200, format "i rd x y ; '//' cont"). Local Notation "i rd x y z ; cont" := (Fancy.Instr i rd (x, y, z) cont) (at level 40, cont at level 200, format "i rd x y z ; '//' cont"). Import Fancy.Registers. Import Fancy. + +Import Barrett256 Montgomery256. + +(*** Montgomery Reduction ***) + +(* Status: Code in final form is proven correct modulo admits in compiler portions. *) + +(* Montgomery Code : *) Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.MontRed256. (* = fun lo hi y t1 t2 scratch RegPInv : register => @@ -12330,46 +12515,107 @@ Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.M ADDM lo lo RegZero RegMod; Ret lo *) -Import Montgomery256. + +(* Uncomment to see proof statement and remaining admitted statements, +or search for "prod_montred256_correct" to see comments on the proof +preconditions. *) +(* Check Montgomery256.prod_montred256_correct. -(* Print Assumptions Montgomery256.prod_montred256_correct. *) +Print Assumptions Montgomery256.prod_montred256_correct. +*) + +(*** Barrett Reduction ***) -Import FancyPrintingNotations. -Local Open Scope expr_scope. +(* Status : Code in "pre-fancy" (stage before final form) is proven +correct modulo admits in compiler portions and one additional +admit (see comment in barrett_red256_prefancy_correct). Code in final +form ("fancy") is generated, but not yet proven equivalent to the +"pre-fancy" version. + +The next step is to, using the same methodology as for Montgomery, +prove the generated fancy code equivalent to Prod.MulMod and also +equivalent to the "pre-fancy" code. Once these equivalences are +proven, then the "pre-fancy" proof will apply to Prod.MulMod. + *) + +Import PreFancy. +Import PreFancy.Notations. +Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). +Local Notation "'RegMuLow'" := (Straightline.expr.Primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835). + +(* "Prefancy" form : *) +Print Barrett256.barrett_red256_prefancy. +(* +selm@(y, $x₂, RegZero, RegMuLow); +rshi@(y0, RegZero, $x₂,255); +rshi@(y1, $x₂, $x₁,255); +mulhh@(y2, RegMuLow, $y1); +mulhl@(y3, RegMuLow, $y1); +mullh@(y4, RegMuLow, $y1); +mulll@(y5, RegMuLow, $y1); +add@(y6, $y5, $y4, 128); +addc@(y7, carry{$y6}, $y2, $y4, -128); +add@(y8, $y6, $y3, 128); +addc@(y9, carry{$y8}, $y7, $y3, -128); +add@(y10, $y1, $y9, 0); +addc@(y11, carry{$y10}, RegZero, $y0, 0); +add@(y12, $y, $y10, 0); +addc@(y13, carry{$y12}, RegZero, $y11, 0); +rshi@(y14, $y13, $y12,1); +mulhh@(y15, RegMod, $y14); +mullh@(y16, RegMod, $y14); +mulhl@(y17, RegMod, $y14); +mulll@(y18, RegMod, $y14); +add@(y19, $y18, $y17, 128); +addc@(y20, carry{$y19}, $y15, $y17, -128); +add@(y21, $y19, $y16, 128); +addc@(y22, carry{$y21}, $y20, $y16, -128); +sub@(y23, $x₁, $y21, 0); +subb@(y24, carry{$y23}, $x₂, $y22, 0); +sell@(y25, $y24, RegZero, RegMod); +sub@(y26, $y23, $y25, 0); +addm@(y27, $y26, RegZero, RegMod); +ret $y27 + *) + +(* Uncomment to see proof statement and remaining admitted statements. *) +(* +Check barrett_red256_prefancy_correct. +Print Assumptions barrett_red256_prefancy_correct. + *) -Print Barrett256.barrett_red256. +(* "Fancy" code (NOT proven) *) +Eval cbv beta iota delta [barrett_red256_alloc] in barrett_red256_alloc. (* -c.Selm($x0, $x_hi, RegZero, RegMuLow); -c.Rshi($x1, RegZero, $x_hi, 255); -c.Rshi($x2, $x_hi, $x_lo, 255); -c.Mul128x128($x3, c.Upper(RegMuLow), c.Upper($x2)); -c.Mul128x128($x4, c.Upper(RegMuLow), c.Lower($x2)); -c.Mul128x128($x5, c.Lower(RegMuLow), c.Upper($x2)); -c.Mul128x128($x6, c.Lower(RegMuLow), c.Lower($x2)); -c.Add256($x7, (c.Lower($x4) << 128), $x6); -c.Addc256($x8, $x7_hi, $x3, c.Upper($x5)); -c.Add256($x9, (c.Lower($x5) << 128), $x7_lo); -c.Addc256($x10, $x9_hi, c.Upper($x4), $x8_lo); -c.Add256($x11, $x2, $x10_lo); -c.Addc128($x12, $x11_hi, RegZero, $x1); -c.Add256($x13, $x0, $x11_lo); -c.Addc128($x14, $x13_hi, RegZero, $x12_lo); -c.Rshi($x15, $x14_lo, $x13_lo, 1); -c.Mul128x128($x16, c.Upper(RegMod), c.Upper($x15)); -c.Mul128x128($x17, c.Lower(RegMod), c.Upper($x15)); -c.Mul128x128($x18, c.Upper(RegMod), c.Lower($x15)); -c.Mul128x128($x19, c.Lower(RegMod), c.Lower($x15)); -c.Add256($x20, (c.Lower($x17) << 128), $x19); -c.Addc256($x21, $x20_hi, $x16, c.Upper($x18)); -c.Add256($x22, (c.Lower($x18) << 128), $x20_lo); -c.Addc256($x23, $x22_hi, c.Upper($x17), $x21_lo); -expr_let x24 := Z.add_get_carry_concrete - 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ - (Z.opp @@ $x22_lo, $x_lo) in -expr_let x25 := Z.add_with_get_carry_concrete - 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@ - ($x24_hi, Z.opp @@ $x23_lo, $x_hi) in -c.Sell($x26, $x25_lo, RegZero, RegMod); -c.Sub($x27, $x24_lo, $x26); -c.AddM($ret, $x27, RegZero, RegMod); + = fun (xLow xHigh RegMuLow : register) (_ : positive) (_ : register) => + SELM r2 RegMuLow RegZero; + RSHI 255 r3 RegZero xHigh; + RSHI 255 r4 xHigh xLow; + MUL128UU r5 RegMuLow r4; + MUL128UL r6 r4 RegMuLow; + MUL128LU r7 r4 RegMuLow; + MUL128LL r8 RegMuLow r4; + ADD 128 r9 r8 r7; + ADDC (-128) r10 r5 r7; + ADD 128 r5 r9 r6; + ADDC (-128) r11 r10 r6; + ADD 0 r6 r4 r11; + ADDC 0 r12 RegZero r3; + ADD 0 r13 r2 r6; + ADDC 0 r14 RegZero r12; + RSHI 1 r15 r14 r13; + MUL128UU r16 RegMod r15; + MUL128LU r17 r15 RegMod; + MUL128UL r18 r15 RegMod; + MUL128LL r19 RegMod r15; + ADD 128 r20 r19 r18; + ADDC (-128) r21 r16 r18; + ADD 128 r22 r20 r17; + ADDC (-128) r23 r21 r17; + SUB 0 r24 xLow r22; + SUBC 0 r25 xHigh r23; + SELL r26 RegMod RegZero; + SUB 0 r27 r24 r26; + ADDM r28 r27 RegZero RegMod; + Ret r28 *) -- cgit v1.2.3