From d6ea917674ca7475a15a98ecfc1ff7259b8dbba9 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Fri, 11 May 2018 17:22:02 +0200 Subject: end-to-end proof for montgomery --- src/Experiments/SimplyTypedArithmetic.v | 474 +++++++++++++++++++++++++------- 1 file changed, 382 insertions(+), 92 deletions(-) (limited to 'src/Experiments/SimplyTypedArithmetic.v') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 4a5be2505..8f4cfb6b5 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -8226,12 +8226,10 @@ Module Straightline. | ok_scalar_cast : forall r s (idc : ident s _) x, ok_scalar x -> - ok_scalar_ident idc -> @ok_expr type.Z (AppIdent (ident.Z.cast r) (AppIdent idc x)) | ok_scalar_cast2 : forall r s (idc : ident s _) x, ok_scalar x -> - ok_scalar_ident idc -> @ok_expr (type.prod type.Z type.Z) (AppIdent (ident.Z.cast2 r) (AppIdent idc x)) | ok_scalar_nocast : forall t x, @ok_scalar t x -> @ok_expr t x @@ -8385,11 +8383,12 @@ Module Straightline. Admitted. End proofs. End expr. - + Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_arrow: expr.expr d := match invert_Abs (e var) with - | Some f => expr.of_uncurried (dummy_arrow:=dummy_arrow) (expr.depth (fun _ => unit) (fun _ => tt) (e (fun _ => unit))) (f x) + | Some f => + expr.of_uncurried (dummy_arrow:=dummy_arrow) (expr.depth (fun _ => unit) (fun _ => tt) (e (fun _ => unit))) (f x) | None => expr.dummy (dummy_arrow:=dummy_arrow) d end. @@ -8466,7 +8465,7 @@ End StraightlineTest. Module PreFancy. Import Straightline.expr. Section with_wordmax. - Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax). + Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax) (log2wordmax_even : log2wordmax mod 2 = 0). Let wordmax := 2 ^ log2wordmax. Let half_bits := log2wordmax / 2. Let wordmax_half_bits := 2 ^ half_bits. @@ -8487,6 +8486,9 @@ Module PreFancy. Lemma wordmax_half_bits_pos : 0 < wordmax_half_bits. Proof. subst wordmax_half_bits half_bits. Z.zero_bounds. Qed. + Lemma half_bits_nonneg : 0 <= half_bits. + Proof. subst half_bits; Z.zero_bounds. Qed. + Lemma half_bits_squared : (wordmax_half_bits - 1) * (wordmax_half_bits - 1) <= wordmax - 1. Proof. pose proof wordmax_half_bits_pos. @@ -8504,8 +8506,7 @@ Module PreFancy. Qed. Section with_var. - Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) - (constant_to_scalar : Z -> option (@scalar var type.Z)). + Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) (consts : list Z). Local Notation Z := (type.type_primitive type.Z). Inductive ident : type -> type -> Type := @@ -8527,6 +8528,20 @@ Module PreFancy. . Definition dummy t : @expr var ident t := Scalar (dummy_scalar (dummy_arrow:=dummy_arrow) t). + Definition constant_to_scalar_single (const x : BinInt.Z) : option (@scalar var Z) := + if x =? (BinInt.Z.shiftr const half_bits) + then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const))) + else if x =? (BinInt.Z.land const (wordmax_half_bits - 1)) + then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const))) + else None. + + Definition constant_to_scalar (x : BinInt.Z) + : option (Straightline.expr.scalar Z) := + fold_right (fun c res => match res with + | Some s => Some s + | None => constant_to_scalar_single c x + end) None consts. + Definition invert_lower' {t} (e : @scalar var t) : option (@scalar var Z) := match e in scalar t return option (@scalar var Z) with @@ -8669,20 +8684,6 @@ Module PreFancy. | LetInAppIdentZZ _ t r idc x f => of_straightline_ident idc t r x (fun y => of_straightline (f y)) end. - - Definition constant_to_scalar_single (const x : BinInt.Z) : option (@scalar var Z) := - if x =? (BinInt.Z.shiftr const half_bits) - then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const))) - else if x =? (BinInt.Z.land const (wordmax_half_bits - 1)) - then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const))) - else None. - - Definition constant_to_scalar_gen (consts : list BinInt.Z) (x : BinInt.Z) - : option (Straightline.expr.scalar Z) := - fold_right (fun c res => match res with - | Some s => Some s - | None => constant_to_scalar_single c x - end) None consts. End with_var. Section interp. @@ -8719,8 +8720,8 @@ Module PreFancy. End interp. Section proofs. - Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) - (constant_to_scalar : Z -> option (@scalar type.interp type.Z)). + Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z) + (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1). Local Notation word_range := (r[0~>wordmax-1])%zrange. Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange. @@ -8799,7 +8800,22 @@ Module PreFancy. | sc_ok_prim : forall p x, ok_scalar (@Primitive _ p x) . - Print ident. + Inductive is_halved : scalar type.Z -> Prop := + | is_halved_lower : + forall x : scalar type.Z, + in_word_range (get_range x) -> + is_halved (Cast half_word_range (Land (wordmax_half_bits - 1) x)) + | is_halved_upper : + forall x : scalar type.Z, + in_word_range (get_range x) -> + is_halved (Cast half_word_range (Shiftr half_bits x)) + | is_halved_constant : + forall y z, + constant_to_scalar consts z = Some y -> + is_halved y -> + is_halved (Primitive (t:=type.Z) z) + . + Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop := | ok_add : forall x : scalar (type.prod type.Z type.Z), @@ -8835,14 +8851,16 @@ Module PreFancy. 0 <= n < wordmax -> ok_ident type.Z type.Z x (ZRange.map (fun y => Z.land y n) (get_range x)) (ident.Z.land n)*) | ok_shiftr : - forall (x : scalar type.Z) n, + forall (x : scalar type.Z) n r, in_word_range (get_range x) -> 0 <= n <= log2wordmax -> - ok_ident type.Z type.Z x (ZRange.map (fun y => Z.shiftr y n) (get_range x)) (ident.Z.shiftr n) + r = ZRange.map (fun y => Z.shiftr y n) (get_range x) -> + ok_ident type.Z type.Z x r (ident.Z.shiftr n) | ok_shiftl : forall (x : scalar type.Z) n, in_word_range (get_range x) -> - 0 <= n < log2wordmax - Z.log2_up (upper (get_range x)) -> + 0 <= n -> + upper (get_range x) * 2 ^ n <= wordmax - 1 -> ok_ident type.Z type.Z x word_range (ident.Z.shiftl n) | ok_rshi : forall (x : scalar (type.prod type.Z type.Z)) n, @@ -8857,7 +8875,7 @@ Module PreFancy. in_word_range (get_range z) -> ok_ident _ type.Z - (Pair (Pair (Snd x) y) z) + (Pair (Pair (Cast flag_range (Snd x)) y) z) word_range ident.Z.zselect | ok_selm : @@ -8891,6 +8909,16 @@ Module PreFancy. x word_range ident.Z.add_modulo + | ok_mul : + forall x y : scalar type.Z, + is_halved x -> + is_halved y -> + ok_ident (type.prod type.Z type.Z) + type.Z + (Pair x y) + word_range + ident.Z.mul + (* | ok_mulll : forall x y : scalar type.Z, in_word_range (get_range x) -> @@ -8935,6 +8963,7 @@ Module PreFancy. (Cast half_word_range (Shiftr half_bits y))) word_range ident.Z.mul +*) . Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop := @@ -9061,7 +9090,7 @@ Module PreFancy. Qed. Lemma has_word_range_shiftl n r x : - 0 <= n < log2wordmax - Z.log2_up (upper r) -> + 0 <= n -> upper r * 2 ^ n <= wordmax - 1 -> @has_range type.Z r x -> in_word_range r -> @has_range type.Z word_range (x << n). @@ -9073,18 +9102,7 @@ Module PreFancy. apply andb_true_iff; split; apply Z.leb_le; [ apply Z.shiftl_nonneg; solve [auto using in_word_range_nonneg] | ]. rewrite Z.shiftl_mul_pow2 by omega. - destruct (dec (upper r = 0)); - [ match goal with H : _ = 0 |- _ => rewrite H end; pose proof wordmax_gt_2; lia | ]. - match goal with |- ?a * ?b <= _ => - transitivity (2 ^ (Z.log2_up a) * b) end. - { apply Z.mul_le_mono_nonneg_r; auto with zarith; [ ]. - apply Z.log2_log2_up_spec. - apply Z.le_neq; eauto using in_word_range_upper_nonneg with has_range. } - { rewrite <-Z.pow_add_r by auto with zarith. - assert (2 ^ (Z.log2_up (upper r) + n) < 2 ^ log2wordmax) - by (apply Z.pow_lt_mono_r; omega). - replace wordmax with (2 ^ log2wordmax) by reflexivity. - omega. } + auto. Qed. Lemma has_word_range_rshi n x y : @@ -9098,12 +9116,12 @@ Module PreFancy. Qed. Lemma in_word_range_spec r : - 0 <= lower r -> upper r <= wordmax - 1 -> - in_word_range r. + (0 <= lower r /\ upper r <= wordmax - 1) + <-> in_word_range r. Proof. intros; cbv [in_word_range is_tighter_than_bool]. - apply andb_true_iff. - split; apply Z.leb_le; cbn [upper lower]; omega. + rewrite andb_true_iff. + intuition; apply Z.leb_le; cbn [upper lower]; try omega. Qed. Ltac destruct_scalar := @@ -9119,6 +9137,150 @@ Module PreFancy. end end. + Ltac extract_ok_scalar' level x := + match goal with + | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ => + match (eval compute in (4 <=? level)) with + | true => invert H; extract_ok_scalar' 3 x + | _ => fail + end + | H : ok_scalar (Pair (?f (?g x)) _) |- _ => + match (eval compute in (3 <=? level)) with + | true => invert H; extract_ok_scalar' 2 x + | _ => fail + end + | H : ok_scalar (Pair _ (?f (?g x))) |- _ => + match (eval compute in (3 <=? level)) with + | true => invert H; extract_ok_scalar' 2 x + | _ => fail + end + | H : ok_scalar (?f (?g x)) |- _ => + match (eval compute in (2 <=? level)) with + | true => invert H; extract_ok_scalar' 1 x + | _ => fail + end + | H : ok_scalar (?g x) |- _ => invert H + | H : ok_scalar (Pair x _) |- _ => invert H + | H : ok_scalar (Pair _ x) |- _ => invert H + end. + + Ltac extract_ok_scalar := + match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end. + + Lemma has_half_word_range_shiftr r x : + in_word_range r -> + @has_range type.Z r x -> + @has_range type.Z half_word_range (x >> half_bits). + Proof. + cbv [in_word_range is_tighter_than_bool]. + rewrite andb_true_iff. + cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt. + { apply Z.shiftr_nonneg. omega. } + { pose proof half_bits_nonneg. + pose proof half_bits_squared. + assert (x >> half_bits < wordmax_half_bits); [|omega]. + rewrite Z.shiftr_div_pow2 by auto. + apply Z.div_lt_upper_bound; Z.zero_bounds. + subst wordmax_half_bits half_bits. + rewrite <-Z.pow_add_r by omega. + rewrite Z.add_diag, Z.mul_div_eq, log2wordmax_even by omega. + autorewrite with zsimplify_fast. subst wordmax. omega. } + Qed. + + Lemma has_half_word_range_land r x : + in_word_range r -> + @has_range type.Z r x -> + @has_range type.Z half_word_range (x &' (wordmax_half_bits - 1)). + Proof. + pose proof wordmax_half_bits_pos. + cbv [in_word_range is_tighter_than_bool]. + rewrite andb_true_iff. + cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt. + { apply Z.land_nonneg; omega. } + { apply Z.land_upper_bound_r; omega. } + Qed. + + Section constant_to_scalar. + Lemma constant_to_scalar_single_correct s x z : + 0 <= x <= wordmax - 1 -> + constant_to_scalar_single x z = Some s -> interp_scalar s = z. + Proof. + cbv [constant_to_scalar_single]. + break_match; try discriminate; intros; Z.ltb_to_lt; subst; + try match goal with H : Some _ = Some _ |- _ => inversion H; subst end; + cbn [interp_scalar]; apply interp_cast_noop. + { apply has_half_word_range_shiftr with (r:=r[x~>x]%zrange); + cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega. + apply andb_true_iff; split; apply Z.leb_le; omega. } + { apply has_half_word_range_land with (r:=r[x~>x]%zrange); + cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega. + apply andb_true_iff; split; apply Z.leb_le; omega. } + Qed. + + Lemma constant_to_scalar_correct s z : + constant_to_scalar consts z = Some s -> interp_scalar s = z. + Proof. + cbv [constant_to_scalar]. + apply fold_right_invariant; try discriminate. + intros until 2; break_match; eauto using constant_to_scalar_single_correct. + Qed. + + Lemma constant_to_scalar_single_cases x y z : + @constant_to_scalar_single type.interp x z = Some y -> + (y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive (t:=type.Z) x))) + \/ (y = Cast half_word_range (Shiftr half_bits (Primitive (t:=type.Z) x))). + Proof. + cbv [constant_to_scalar_single]. + break_match; try discriminate; intros; Z.ltb_to_lt; subst; + try match goal with H : Some _ = Some _ |- _ => inversion H; subst end; + tauto. + Qed. + + Lemma constant_to_scalar_cases y z : + @constant_to_scalar type.interp consts z = Some y -> + (exists x, + @has_range type.Z word_range x + /\ y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive x))) + \/ (exists x, + @has_range type.Z word_range x + /\ y = Cast half_word_range (Shiftr half_bits (Primitive x))). + Proof. + cbv [constant_to_scalar]. + apply fold_right_invariant; try discriminate. + intros until 2; break_match; eauto; intros. + match goal with H : constant_to_scalar_single _ _ = _ |- _ => + destruct (constant_to_scalar_single_cases _ _ _ H); subst end. + { left; eexists; split; eauto. + apply consts_ok; auto. } + { right; eexists; split; eauto. + apply consts_ok; auto. } + Qed. + + Lemma ok_scalar_constant_to_scalar y z : constant_to_scalar consts z = Some y -> ok_scalar y. + Proof. + pose proof wordmax_half_bits_pos. pose proof half_bits_nonneg. + let H := fresh in + intro H; apply constant_to_scalar_cases in H; destruct H as [ [? ?] | [? ?] ]; intuition; subst; + cbn [has_range lower upper] in *; repeat constructor; cbn [lower get_range]; try apply Z.leb_refl; try omega. + assert (in_word_range r[x~>x]) by (apply in_word_range_spec; cbn [lower upper]; omega). + pose proof (has_half_word_range_shiftr r[x~>x] x ltac:(assumption) ltac:(cbv [has_range lower upper]; omega)). + cbn [has_range ZRange.map is_tighter_than_bool lower upper] in *. + apply andb_true_iff; cbn [lower upper]; split; apply Z.leb_le; omega. + Qed. + End constant_to_scalar. + Hint Resolve ok_scalar_constant_to_scalar. + + Lemma is_halved_has_range x : + ok_scalar x -> + is_halved x -> + @has_range type.Z half_word_range (interp_scalar x). + Proof. + intro; pose proof (has_range_interp_scalar x ltac:(assumption)). + induction 1; cbn [interp_scalar] in *; intros; try assumption; [ ]. + rewrite <-(constant_to_scalar_correct y z) by assumption. + eauto using has_range_interp_scalar. + Qed. + Lemma ident_interp_has_range s d x r idc: ok_scalar x -> ok_ident s d x r idc -> @@ -9131,6 +9293,7 @@ Module PreFancy. repeat match goal with | H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt | H : _ /\ _ |- _ => destruct H + | H : is_halved _ |- _ => apply is_halved_has_range in H; [ | extract_ok_scalar ] | _ => progress subst | _ => progress (cbv [in_word_range in_flag_range is_tighter_than_bool] in * ) | _ => progress (cbn [interp_scalar get_range has_range upper lower fst snd] in * ) @@ -9151,7 +9314,8 @@ Module PreFancy. rewrite Z.div_sub_small by omega. split; break_match; lia. } { apply has_range_shiftr; cbn [has_range]; omega. } - { eapply has_word_range_shiftl; eauto using in_word_range_spec with has_range omega. } + { eapply has_word_range_shiftl; eauto with has_range omega. + apply in_word_range_spec ; omega. } { apply has_word_range_rshi; omega. } { rewrite Z.zselect_correct. break_match; omega. } { cbn [interp_scalar fst snd get_range] in *. @@ -9160,40 +9324,10 @@ Module PreFancy. rewrite Z.zselect_correct. break_match; omega. } { rewrite Z.add_modulo_correct. break_match; Z.ltb_to_lt; omega. } - { cbn [interp_scalar fst snd get_range upper lower] in *. - pose proof half_bits_squared. nia. } - { cbn [interp_scalar fst snd get_range upper lower] in *. - pose proof half_bits_squared. nia. } - { cbn [interp_scalar fst snd get_range upper lower] in *. - pose proof half_bits_squared. nia. } - { cbn [interp_scalar fst snd get_range upper lower] in *. + { cbn [interp_scalar has_range fst snd get_range upper lower] in *. pose proof half_bits_squared. nia. } Qed. - Ltac extract_ok_scalar' level x := - match goal with - | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ => - match (eval compute in (4 <=? level)) with - | true => invert H; extract_ok_scalar' 3 x - | _ => fail - end - | H : ok_scalar (Pair (?f (?g x)) _) |- _ => - match (eval compute in (3 <=? level)) with - | true => invert H; extract_ok_scalar' 2 x - | _ => fail - end - | H : ok_scalar (?f (?g x)) |- _ => - match (eval compute in (2 <=? level)) with - | true => invert H; extract_ok_scalar' 1 x - | _ => fail - end - | H : ok_scalar (?g x) |- _ => invert H - end. - - - Ltac extract_ok_scalar := - match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end. - Lemma has_flag_range_cc_m r x : @has_range type.Z r x -> in_word_range r -> @@ -9239,16 +9373,76 @@ Module PreFancy. | _ => rewrite interp_cast_noop by assumption end. + Lemma is_halved_cases x : + is_halved x -> + ok_scalar x -> + (exists y, + invert_lower consts x = Some y + /\ invert_upper consts x = None + /\ interp_scalar y &' (wordmax_half_bits - 1) = interp_scalar x) + \/ (exists y, + invert_lower consts x = None + /\ invert_upper consts x = Some y + /\ interp_scalar y >> half_bits = interp_scalar x). + Proof. + induction 1; intros; cbn; rewrite ?Z.eqb_refl; cbn. + { left. eexists; repeat split; auto. + rewrite interp_cast_noop; [ reflexivity | ]. + apply has_half_word_range_land with (r:=get_range x); auto. + apply has_range_interp_scalar; extract_ok_scalar. } + { right. eexists; repeat split; auto. + rewrite interp_cast_noop; [ reflexivity | ]. + apply has_half_word_range_shiftr with (r:=get_range x); auto. + apply has_range_interp_scalar; extract_ok_scalar. } + { match goal with H : constant_to_scalar _ _ = Some _ |- _ => + rewrite H; + let P := fresh in + destruct (constant_to_scalar_cases _ _ H) as [ [? [? ?] ] | [? [? ?] ] ]; + subst; cbn; rewrite ?Z.eqb_refl; cbn + end. + { left; eexists; repeat split; auto. + erewrite <-constant_to_scalar_correct by eassumption. + subst. cbn. + rewrite interp_cast_noop; [ reflexivity | ]. + eapply has_half_word_range_land with (r:=word_range); auto. + cbv [in_word_range is_tighter_than_bool]. + rewrite !Z.leb_refl; reflexivity. } + { right; eexists; repeat split; auto. + erewrite <-constant_to_scalar_correct by eassumption. + subst. cbn. + rewrite interp_cast_noop; [ reflexivity | ]. + eapply has_half_word_range_shiftr with (r:=word_range); auto. + cbv [in_word_range is_tighter_than_bool]. + rewrite !Z.leb_refl; reflexivity. } } + Qed. + + Lemma of_straightline_ident_mul_correct t x y g : + is_halved x -> + is_halved y -> + ok_scalar (Pair x y) -> + @has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) -> + interp (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) = + interp (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))). + Proof. + intros Hx Hy Hok ?; invert Hok; cbn [interp_scalar of_straightline_ident]. + destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ]; + rewrite ?Pxlow, ?Pxhigh; + destruct (is_halved_cases y Hy ltac:(assumption)) as [ [? [Pylow [Pyhigh Pyi] ] ] | [? [Pylow [Pyhigh Pyi] ] ] ]; + rewrite ?Pylow, ?Pyhigh; + cbn; rewrite Pxi, Pyi; rewrite interp_cast_noop by auto; reflexivity. + Qed. + Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g : ok_ident s d x r idc -> ok_scalar x -> - interp (of_straightline_ident dummy_arrow constant_to_scalar idc t r x g) = + interp (of_straightline_ident dummy_arrow consts idc t r x g) = interp (g (ident.interp idc (interp_scalar x))). Proof. intros. pose proof wordmax_half_bits_pos. pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)). - induction H; cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell invert_upper invert_lower invert_upper' invert_lower'] in *; intros; + induction H; try solve [auto using of_straightline_ident_mul_correct]; + cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell] in *; intros; rewrite ?Z.eqb_refl; cbn [interp interp_ident andb]; try destruct_scalar; repeat match goal with | _ => progress (cbn [fst snd interp_scalar] in * ) @@ -9261,16 +9455,12 @@ Module PreFancy. | _ => rewrite interp_cast_noop by assumption | _ => rewrite interp_cast2_noop by assumption | _ => reflexivity - end; [ | | | ]; (* leftover cases are the 4 kinds of multiplication *) - match goal with - | H1 : in_word_range (get_range ?x), H2 : in_word_range (get_range ?y) |- _ => - extract_ok_scalar' 4 x; extract_ok_scalar' 4 y - end; rewrite_cast_noop_in_mul; reflexivity. + end. Qed. Lemma of_straightline_correct {t} (e : expr t) : ok_expr e -> - interp (of_straightline dummy_arrow constant_to_scalar e) = Straightline.expr.interp (interp_ident:=@ident.interp) e. + interp (of_straightline dummy_arrow consts e) = Straightline.expr.interp (interp_ident:=@ident.interp) e. Proof. induction 1; cbn [of_straightline]; intros; repeat match goal with @@ -9287,16 +9477,19 @@ Module PreFancy. Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) (var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d := - @of_straightline log2wordmax var dummy_arrow (@constant_to_scalar_gen log2wordmax var consts) _ (Straightline.of_Expr e var x dummy_arrow). + @of_straightline log2wordmax var dummy_arrow consts _ (Straightline.of_Expr e var x dummy_arrow). Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) (e' : (type.interp s -> Uncurried.expr.expr d)) (x : type.interp s) dummy_arrow : e type.interp = Abs e' -> 1 < log2wordmax -> + log2wordmax mod 2 = 0 -> Straightline.expr.ok_expr (e' x) -> - ok_expr log2wordmax (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit))) (e' x)) -> - (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit)))%nat -> + (forall x0 : Z, In x0 consts -> 0 <= x0 <= 2 ^ log2wordmax - 1) -> + ok_expr log2wordmax consts + (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e _)) (e' x)) -> + (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e _))%nat -> interp log2wordmax (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x. Proof. intro He'; intros; cbv [of_Expr Straightline.of_Expr]. @@ -10542,9 +10735,9 @@ Ltac solve_rmontred_nocache := solve_rop_nocache MontgomeryReduction.rmontred_co Module Montgomery256. - Definition N := (2^256-2^224+2^192+2^96-1). + Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1). Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849). - Definition R := (2^256). + Definition R := Eval lazy in (2^256). Definition machine_wordsize := 256. Derive montred256 @@ -10552,6 +10745,104 @@ Module Montgomery256. As montred256_correct. Proof. Time solve_rmontred machine_wordsize. Time Qed. + + Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 (lo, hi) = ((lo + R * hi) * R') mod N. + Proof. + intros. + apply MontgomeryReduction.montred'_correct with (T:=lo + R * hi) (R':=R'); + try match goal with + | |- context[R'] => assumption + | |- context [lo] => + try assumption; progress autorewrite with zsimplify cancel_pair; reflexivity + end; lazy; try split; congruence. + Qed. + + (* Note: If this is not factored out, then for some reason Qed takes forever in montred256_correct_full. *) + Lemma montred256_correct_proj2 : + forall xy : type.interp (type.prod type.Z type.Z), + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + xy = true -> + expr.Interp (@ident.interp) montred256 xy = MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 xy. + Proof. intros; destruct (montred256_correct xy); assumption. Qed. + + Lemma montred256_correct_full R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + expr.interp (@ident.interp) (montred256 type.interp) (lo, hi) = ((lo + R * hi) * R') mod N. + Proof. + intros. + rewrite <-montred'_correct_specialized by assumption. + rewrite <-montred256_correct_proj2. + { cbv [expr.Interp type.uncurried_domain type.uncurry type.final_codomain]. + reflexivity. } + { cbn. rewrite !andb_true_iff. cbv [R N] in *. + repeat split; apply Z.leb_le; omega. } + Qed. + + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step' := + match goal with + | _ => assumption + | |- context [PreFancy.ok_ident] => constructor + | |- context [PreFancy.ok_scalar] => constructor; try omega + | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] + | |- context [PreFancy.is_halved] => constructor + | |- context [PreFancy.in_word_range] => lazy; reflexivity + | |- context [PreFancy.in_flag_range] => lazy; reflexivity + | |- context [PreFancy.get_range] => + cbn [PreFancy.get_range lower upper fst snd ZRange.map] + | x : type.interp (type.prod _ _) |- _ => destruct x + | |- (_ <=? _)%zrange = true => + match goal with + | |- context [PreFancy.get_range_var] => + cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower R N] in *; cbn; + apply andb_true_iff; split; apply Z.leb_le + | _ => lazy + end; omega || reflexivity + | |- @eq zrange _ _ => lazy; reflexivity + | |- _ <= _ => cbv [machine_wordsize]; omega + | |- _ <= _ <= _ => cbv [machine_wordsize]; omega + end; intros. + + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step := + match goal with + | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' + end; intros; cbn [Nat.max]. + + Definition montred256_prefancy' := PreFancy.of_Expr machine_wordsize [N;N'] montred256. + + Derive montred256_prefancy + SuchThat (montred256_prefancy = montred256_prefancy' type.interp) + As montred256_prefancy_eq. + Proof. lazy - [type.interp]; reflexivity. Qed. + + Lemma montred256_prefancy_correct R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + forall (lo hi : Z) dummy_arrow, + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + @PreFancy.interp machine_wordsize type.Z (montred256_prefancy (lo,hi) dummy_arrow) = ((lo + R * hi) * R') mod N. + Proof. + intros. rewrite montred256_prefancy_eq; cbv [montred256_prefancy']. + erewrite PreFancy.of_Expr_correct. + { apply montred256_correct_full; assumption. } + { reflexivity. } + { lazy; reflexivity. } + { lazy; reflexivity. } + { repeat constructor. } + { cbv [In N N']; intros; intuition; subst; cbv; congruence. } + { assert (340282366920938463463374607431768211455 * 2 ^ 128 <= 2 ^ machine_wordsize - 1) as shiftl_128_ok by (lazy; congruence). + repeat (ok_expr_step; [ ]). + ok_expr_step. + lazy; congruence. + constructor. } + { lazy. omega. } + Qed. + Import PrintingNotations. Set Printing Width 10000. @@ -10585,8 +10876,7 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)) *) - Definition montred256_prefancy := - Eval lazy in (PreFancy.of_Expr machine_wordsize [N;N'] montred256). + Import PreFancy. Import PreFancy.Notations. Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). -- cgit v1.2.3