aboutsummaryrefslogtreecommitdiff
path: root/src/Experiments/SimplyTypedArithmetic.v
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-05-11 17:22:02 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-05-31 13:46:48 +0200
commitd6ea917674ca7475a15a98ecfc1ff7259b8dbba9 (patch)
tree311290bad6bb503c2401f97bd327b33c02d56f2f /src/Experiments/SimplyTypedArithmetic.v
parente6b25ccf1bb99582c3f83aa8b3b4fe3f0de31870 (diff)
end-to-end proof for montgomery
Diffstat (limited to 'src/Experiments/SimplyTypedArithmetic.v')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v474
1 files changed, 382 insertions, 92 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 4a5be2505..8f4cfb6b5 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -8226,12 +8226,10 @@ Module Straightline.
| ok_scalar_cast :
forall r s (idc : ident s _) x,
ok_scalar x ->
- ok_scalar_ident idc ->
@ok_expr type.Z (AppIdent (ident.Z.cast r) (AppIdent idc x))
| ok_scalar_cast2 :
forall r s (idc : ident s _) x,
ok_scalar x ->
- ok_scalar_ident idc ->
@ok_expr (type.prod type.Z type.Z) (AppIdent (ident.Z.cast2 r) (AppIdent idc x))
| ok_scalar_nocast :
forall t x, @ok_scalar t x -> @ok_expr t x
@@ -8385,11 +8383,12 @@ Module Straightline.
Admitted.
End proofs.
End expr.
-
+
Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_arrow: expr.expr d
:=
match invert_Abs (e var) with
- | Some f => expr.of_uncurried (dummy_arrow:=dummy_arrow) (expr.depth (fun _ => unit) (fun _ => tt) (e (fun _ => unit))) (f x)
+ | Some f =>
+ expr.of_uncurried (dummy_arrow:=dummy_arrow) (expr.depth (fun _ => unit) (fun _ => tt) (e (fun _ => unit))) (f x)
| None => expr.dummy (dummy_arrow:=dummy_arrow) d
end.
@@ -8466,7 +8465,7 @@ End StraightlineTest.
Module PreFancy.
Import Straightline.expr.
Section with_wordmax.
- Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax).
+ Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax) (log2wordmax_even : log2wordmax mod 2 = 0).
Let wordmax := 2 ^ log2wordmax.
Let half_bits := log2wordmax / 2.
Let wordmax_half_bits := 2 ^ half_bits.
@@ -8487,6 +8486,9 @@ Module PreFancy.
Lemma wordmax_half_bits_pos : 0 < wordmax_half_bits.
Proof. subst wordmax_half_bits half_bits. Z.zero_bounds. Qed.
+ Lemma half_bits_nonneg : 0 <= half_bits.
+ Proof. subst half_bits; Z.zero_bounds. Qed.
+
Lemma half_bits_squared : (wordmax_half_bits - 1) * (wordmax_half_bits - 1) <= wordmax - 1.
Proof.
pose proof wordmax_half_bits_pos.
@@ -8504,8 +8506,7 @@ Module PreFancy.
Qed.
Section with_var.
- Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d))
- (constant_to_scalar : Z -> option (@scalar var type.Z)).
+ Context {var : type -> Type} (dummy_arrow : forall s d, var (type.arrow s d)) (consts : list Z).
Local Notation Z := (type.type_primitive type.Z).
Inductive ident : type -> type -> Type :=
@@ -8527,6 +8528,20 @@ Module PreFancy.
.
Definition dummy t : @expr var ident t := Scalar (dummy_scalar (dummy_arrow:=dummy_arrow) t).
+ Definition constant_to_scalar_single (const x : BinInt.Z) : option (@scalar var Z) :=
+ if x =? (BinInt.Z.shiftr const half_bits)
+ then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const)))
+ else if x =? (BinInt.Z.land const (wordmax_half_bits - 1))
+ then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const)))
+ else None.
+
+ Definition constant_to_scalar (x : BinInt.Z)
+ : option (Straightline.expr.scalar Z) :=
+ fold_right (fun c res => match res with
+ | Some s => Some s
+ | None => constant_to_scalar_single c x
+ end) None consts.
+
Definition invert_lower' {t} (e : @scalar var t) :
option (@scalar var Z) :=
match e in scalar t return option (@scalar var Z) with
@@ -8669,20 +8684,6 @@ Module PreFancy.
| LetInAppIdentZZ _ t r idc x f =>
of_straightline_ident idc t r x (fun y => of_straightline (f y))
end.
-
- Definition constant_to_scalar_single (const x : BinInt.Z) : option (@scalar var Z) :=
- if x =? (BinInt.Z.shiftr const half_bits)
- then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Shiftr half_bits (Primitive (t:=type.Z) const)))
- else if x =? (BinInt.Z.land const (wordmax_half_bits - 1))
- then Some (Cast {|lower := 0; upper:=wordmax_half_bits-1|} (Land (wordmax_half_bits-1) (Primitive (t:=type.Z) const)))
- else None.
-
- Definition constant_to_scalar_gen (consts : list BinInt.Z) (x : BinInt.Z)
- : option (Straightline.expr.scalar Z) :=
- fold_right (fun c res => match res with
- | Some s => Some s
- | None => constant_to_scalar_single c x
- end) None consts.
End with_var.
Section interp.
@@ -8719,8 +8720,8 @@ Module PreFancy.
End interp.
Section proofs.
- Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype)
- (constant_to_scalar : Z -> option (@scalar type.interp type.Z)).
+ Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z)
+ (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1).
Local Notation word_range := (r[0~>wordmax-1])%zrange.
Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange.
@@ -8799,7 +8800,22 @@ Module PreFancy.
| sc_ok_prim : forall p x, ok_scalar (@Primitive _ p x)
.
- Print ident.
+ Inductive is_halved : scalar type.Z -> Prop :=
+ | is_halved_lower :
+ forall x : scalar type.Z,
+ in_word_range (get_range x) ->
+ is_halved (Cast half_word_range (Land (wordmax_half_bits - 1) x))
+ | is_halved_upper :
+ forall x : scalar type.Z,
+ in_word_range (get_range x) ->
+ is_halved (Cast half_word_range (Shiftr half_bits x))
+ | is_halved_constant :
+ forall y z,
+ constant_to_scalar consts z = Some y ->
+ is_halved y ->
+ is_halved (Primitive (t:=type.Z) z)
+ .
+
Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop :=
| ok_add :
forall x : scalar (type.prod type.Z type.Z),
@@ -8835,14 +8851,16 @@ Module PreFancy.
0 <= n < wordmax ->
ok_ident type.Z type.Z x (ZRange.map (fun y => Z.land y n) (get_range x)) (ident.Z.land n)*)
| ok_shiftr :
- forall (x : scalar type.Z) n,
+ forall (x : scalar type.Z) n r,
in_word_range (get_range x) ->
0 <= n <= log2wordmax ->
- ok_ident type.Z type.Z x (ZRange.map (fun y => Z.shiftr y n) (get_range x)) (ident.Z.shiftr n)
+ r = ZRange.map (fun y => Z.shiftr y n) (get_range x) ->
+ ok_ident type.Z type.Z x r (ident.Z.shiftr n)
| ok_shiftl :
forall (x : scalar type.Z) n,
in_word_range (get_range x) ->
- 0 <= n < log2wordmax - Z.log2_up (upper (get_range x)) ->
+ 0 <= n ->
+ upper (get_range x) * 2 ^ n <= wordmax - 1 ->
ok_ident type.Z type.Z x word_range (ident.Z.shiftl n)
| ok_rshi :
forall (x : scalar (type.prod type.Z type.Z)) n,
@@ -8857,7 +8875,7 @@ Module PreFancy.
in_word_range (get_range z) ->
ok_ident _
type.Z
- (Pair (Pair (Snd x) y) z)
+ (Pair (Pair (Cast flag_range (Snd x)) y) z)
word_range
ident.Z.zselect
| ok_selm :
@@ -8891,6 +8909,16 @@ Module PreFancy.
x
word_range
ident.Z.add_modulo
+ | ok_mul :
+ forall x y : scalar type.Z,
+ is_halved x ->
+ is_halved y ->
+ ok_ident (type.prod type.Z type.Z)
+ type.Z
+ (Pair x y)
+ word_range
+ ident.Z.mul
+ (*
| ok_mulll :
forall x y : scalar type.Z,
in_word_range (get_range x) ->
@@ -8935,6 +8963,7 @@ Module PreFancy.
(Cast half_word_range (Shiftr half_bits y)))
word_range
ident.Z.mul
+*)
.
Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop :=
@@ -9061,7 +9090,7 @@ Module PreFancy.
Qed.
Lemma has_word_range_shiftl n r x :
- 0 <= n < log2wordmax - Z.log2_up (upper r) ->
+ 0 <= n -> upper r * 2 ^ n <= wordmax - 1 ->
@has_range type.Z r x ->
in_word_range r ->
@has_range type.Z word_range (x << n).
@@ -9073,18 +9102,7 @@ Module PreFancy.
apply andb_true_iff; split; apply Z.leb_le;
[ apply Z.shiftl_nonneg; solve [auto using in_word_range_nonneg] | ].
rewrite Z.shiftl_mul_pow2 by omega.
- destruct (dec (upper r = 0));
- [ match goal with H : _ = 0 |- _ => rewrite H end; pose proof wordmax_gt_2; lia | ].
- match goal with |- ?a * ?b <= _ =>
- transitivity (2 ^ (Z.log2_up a) * b) end.
- { apply Z.mul_le_mono_nonneg_r; auto with zarith; [ ].
- apply Z.log2_log2_up_spec.
- apply Z.le_neq; eauto using in_word_range_upper_nonneg with has_range. }
- { rewrite <-Z.pow_add_r by auto with zarith.
- assert (2 ^ (Z.log2_up (upper r) + n) < 2 ^ log2wordmax)
- by (apply Z.pow_lt_mono_r; omega).
- replace wordmax with (2 ^ log2wordmax) by reflexivity.
- omega. }
+ auto.
Qed.
Lemma has_word_range_rshi n x y :
@@ -9098,12 +9116,12 @@ Module PreFancy.
Qed.
Lemma in_word_range_spec r :
- 0 <= lower r -> upper r <= wordmax - 1 ->
- in_word_range r.
+ (0 <= lower r /\ upper r <= wordmax - 1)
+ <-> in_word_range r.
Proof.
intros; cbv [in_word_range is_tighter_than_bool].
- apply andb_true_iff.
- split; apply Z.leb_le; cbn [upper lower]; omega.
+ rewrite andb_true_iff.
+ intuition; apply Z.leb_le; cbn [upper lower]; try omega.
Qed.
Ltac destruct_scalar :=
@@ -9119,6 +9137,150 @@ Module PreFancy.
end
end.
+ Ltac extract_ok_scalar' level x :=
+ match goal with
+ | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ =>
+ match (eval compute in (4 <=? level)) with
+ | true => invert H; extract_ok_scalar' 3 x
+ | _ => fail
+ end
+ | H : ok_scalar (Pair (?f (?g x)) _) |- _ =>
+ match (eval compute in (3 <=? level)) with
+ | true => invert H; extract_ok_scalar' 2 x
+ | _ => fail
+ end
+ | H : ok_scalar (Pair _ (?f (?g x))) |- _ =>
+ match (eval compute in (3 <=? level)) with
+ | true => invert H; extract_ok_scalar' 2 x
+ | _ => fail
+ end
+ | H : ok_scalar (?f (?g x)) |- _ =>
+ match (eval compute in (2 <=? level)) with
+ | true => invert H; extract_ok_scalar' 1 x
+ | _ => fail
+ end
+ | H : ok_scalar (?g x) |- _ => invert H
+ | H : ok_scalar (Pair x _) |- _ => invert H
+ | H : ok_scalar (Pair _ x) |- _ => invert H
+ end.
+
+ Ltac extract_ok_scalar :=
+ match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end.
+
+ Lemma has_half_word_range_shiftr r x :
+ in_word_range r ->
+ @has_range type.Z r x ->
+ @has_range type.Z half_word_range (x >> half_bits).
+ Proof.
+ cbv [in_word_range is_tighter_than_bool].
+ rewrite andb_true_iff.
+ cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt.
+ { apply Z.shiftr_nonneg. omega. }
+ { pose proof half_bits_nonneg.
+ pose proof half_bits_squared.
+ assert (x >> half_bits < wordmax_half_bits); [|omega].
+ rewrite Z.shiftr_div_pow2 by auto.
+ apply Z.div_lt_upper_bound; Z.zero_bounds.
+ subst wordmax_half_bits half_bits.
+ rewrite <-Z.pow_add_r by omega.
+ rewrite Z.add_diag, Z.mul_div_eq, log2wordmax_even by omega.
+ autorewrite with zsimplify_fast. subst wordmax. omega. }
+ Qed.
+
+ Lemma has_half_word_range_land r x :
+ in_word_range r ->
+ @has_range type.Z r x ->
+ @has_range type.Z half_word_range (x &' (wordmax_half_bits - 1)).
+ Proof.
+ pose proof wordmax_half_bits_pos.
+ cbv [in_word_range is_tighter_than_bool].
+ rewrite andb_true_iff.
+ cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt.
+ { apply Z.land_nonneg; omega. }
+ { apply Z.land_upper_bound_r; omega. }
+ Qed.
+
+ Section constant_to_scalar.
+ Lemma constant_to_scalar_single_correct s x z :
+ 0 <= x <= wordmax - 1 ->
+ constant_to_scalar_single x z = Some s -> interp_scalar s = z.
+ Proof.
+ cbv [constant_to_scalar_single].
+ break_match; try discriminate; intros; Z.ltb_to_lt; subst;
+ try match goal with H : Some _ = Some _ |- _ => inversion H; subst end;
+ cbn [interp_scalar]; apply interp_cast_noop.
+ { apply has_half_word_range_shiftr with (r:=r[x~>x]%zrange);
+ cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega.
+ apply andb_true_iff; split; apply Z.leb_le; omega. }
+ { apply has_half_word_range_land with (r:=r[x~>x]%zrange);
+ cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega.
+ apply andb_true_iff; split; apply Z.leb_le; omega. }
+ Qed.
+
+ Lemma constant_to_scalar_correct s z :
+ constant_to_scalar consts z = Some s -> interp_scalar s = z.
+ Proof.
+ cbv [constant_to_scalar].
+ apply fold_right_invariant; try discriminate.
+ intros until 2; break_match; eauto using constant_to_scalar_single_correct.
+ Qed.
+
+ Lemma constant_to_scalar_single_cases x y z :
+ @constant_to_scalar_single type.interp x z = Some y ->
+ (y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive (t:=type.Z) x)))
+ \/ (y = Cast half_word_range (Shiftr half_bits (Primitive (t:=type.Z) x))).
+ Proof.
+ cbv [constant_to_scalar_single].
+ break_match; try discriminate; intros; Z.ltb_to_lt; subst;
+ try match goal with H : Some _ = Some _ |- _ => inversion H; subst end;
+ tauto.
+ Qed.
+
+ Lemma constant_to_scalar_cases y z :
+ @constant_to_scalar type.interp consts z = Some y ->
+ (exists x,
+ @has_range type.Z word_range x
+ /\ y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive x)))
+ \/ (exists x,
+ @has_range type.Z word_range x
+ /\ y = Cast half_word_range (Shiftr half_bits (Primitive x))).
+ Proof.
+ cbv [constant_to_scalar].
+ apply fold_right_invariant; try discriminate.
+ intros until 2; break_match; eauto; intros.
+ match goal with H : constant_to_scalar_single _ _ = _ |- _ =>
+ destruct (constant_to_scalar_single_cases _ _ _ H); subst end.
+ { left; eexists; split; eauto.
+ apply consts_ok; auto. }
+ { right; eexists; split; eauto.
+ apply consts_ok; auto. }
+ Qed.
+
+ Lemma ok_scalar_constant_to_scalar y z : constant_to_scalar consts z = Some y -> ok_scalar y.
+ Proof.
+ pose proof wordmax_half_bits_pos. pose proof half_bits_nonneg.
+ let H := fresh in
+ intro H; apply constant_to_scalar_cases in H; destruct H as [ [? ?] | [? ?] ]; intuition; subst;
+ cbn [has_range lower upper] in *; repeat constructor; cbn [lower get_range]; try apply Z.leb_refl; try omega.
+ assert (in_word_range r[x~>x]) by (apply in_word_range_spec; cbn [lower upper]; omega).
+ pose proof (has_half_word_range_shiftr r[x~>x] x ltac:(assumption) ltac:(cbv [has_range lower upper]; omega)).
+ cbn [has_range ZRange.map is_tighter_than_bool lower upper] in *.
+ apply andb_true_iff; cbn [lower upper]; split; apply Z.leb_le; omega.
+ Qed.
+ End constant_to_scalar.
+ Hint Resolve ok_scalar_constant_to_scalar.
+
+ Lemma is_halved_has_range x :
+ ok_scalar x ->
+ is_halved x ->
+ @has_range type.Z half_word_range (interp_scalar x).
+ Proof.
+ intro; pose proof (has_range_interp_scalar x ltac:(assumption)).
+ induction 1; cbn [interp_scalar] in *; intros; try assumption; [ ].
+ rewrite <-(constant_to_scalar_correct y z) by assumption.
+ eauto using has_range_interp_scalar.
+ Qed.
+
Lemma ident_interp_has_range s d x r idc:
ok_scalar x ->
ok_ident s d x r idc ->
@@ -9131,6 +9293,7 @@ Module PreFancy.
repeat match goal with
| H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt
| H : _ /\ _ |- _ => destruct H
+ | H : is_halved _ |- _ => apply is_halved_has_range in H; [ | extract_ok_scalar ]
| _ => progress subst
| _ => progress (cbv [in_word_range in_flag_range is_tighter_than_bool] in * )
| _ => progress (cbn [interp_scalar get_range has_range upper lower fst snd] in * )
@@ -9151,7 +9314,8 @@ Module PreFancy.
rewrite Z.div_sub_small by omega.
split; break_match; lia. }
{ apply has_range_shiftr; cbn [has_range]; omega. }
- { eapply has_word_range_shiftl; eauto using in_word_range_spec with has_range omega. }
+ { eapply has_word_range_shiftl; eauto with has_range omega.
+ apply in_word_range_spec ; omega. }
{ apply has_word_range_rshi; omega. }
{ rewrite Z.zselect_correct. break_match; omega. }
{ cbn [interp_scalar fst snd get_range] in *.
@@ -9160,40 +9324,10 @@ Module PreFancy.
rewrite Z.zselect_correct. break_match; omega. }
{ rewrite Z.add_modulo_correct.
break_match; Z.ltb_to_lt; omega. }
- { cbn [interp_scalar fst snd get_range upper lower] in *.
- pose proof half_bits_squared. nia. }
- { cbn [interp_scalar fst snd get_range upper lower] in *.
- pose proof half_bits_squared. nia. }
- { cbn [interp_scalar fst snd get_range upper lower] in *.
- pose proof half_bits_squared. nia. }
- { cbn [interp_scalar fst snd get_range upper lower] in *.
+ { cbn [interp_scalar has_range fst snd get_range upper lower] in *.
pose proof half_bits_squared. nia. }
Qed.
- Ltac extract_ok_scalar' level x :=
- match goal with
- | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ =>
- match (eval compute in (4 <=? level)) with
- | true => invert H; extract_ok_scalar' 3 x
- | _ => fail
- end
- | H : ok_scalar (Pair (?f (?g x)) _) |- _ =>
- match (eval compute in (3 <=? level)) with
- | true => invert H; extract_ok_scalar' 2 x
- | _ => fail
- end
- | H : ok_scalar (?f (?g x)) |- _ =>
- match (eval compute in (2 <=? level)) with
- | true => invert H; extract_ok_scalar' 1 x
- | _ => fail
- end
- | H : ok_scalar (?g x) |- _ => invert H
- end.
-
-
- Ltac extract_ok_scalar :=
- match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end.
-
Lemma has_flag_range_cc_m r x :
@has_range type.Z r x ->
in_word_range r ->
@@ -9239,16 +9373,76 @@ Module PreFancy.
| _ => rewrite interp_cast_noop by assumption
end.
+ Lemma is_halved_cases x :
+ is_halved x ->
+ ok_scalar x ->
+ (exists y,
+ invert_lower consts x = Some y
+ /\ invert_upper consts x = None
+ /\ interp_scalar y &' (wordmax_half_bits - 1) = interp_scalar x)
+ \/ (exists y,
+ invert_lower consts x = None
+ /\ invert_upper consts x = Some y
+ /\ interp_scalar y >> half_bits = interp_scalar x).
+ Proof.
+ induction 1; intros; cbn; rewrite ?Z.eqb_refl; cbn.
+ { left. eexists; repeat split; auto.
+ rewrite interp_cast_noop; [ reflexivity | ].
+ apply has_half_word_range_land with (r:=get_range x); auto.
+ apply has_range_interp_scalar; extract_ok_scalar. }
+ { right. eexists; repeat split; auto.
+ rewrite interp_cast_noop; [ reflexivity | ].
+ apply has_half_word_range_shiftr with (r:=get_range x); auto.
+ apply has_range_interp_scalar; extract_ok_scalar. }
+ { match goal with H : constant_to_scalar _ _ = Some _ |- _ =>
+ rewrite H;
+ let P := fresh in
+ destruct (constant_to_scalar_cases _ _ H) as [ [? [? ?] ] | [? [? ?] ] ];
+ subst; cbn; rewrite ?Z.eqb_refl; cbn
+ end.
+ { left; eexists; repeat split; auto.
+ erewrite <-constant_to_scalar_correct by eassumption.
+ subst. cbn.
+ rewrite interp_cast_noop; [ reflexivity | ].
+ eapply has_half_word_range_land with (r:=word_range); auto.
+ cbv [in_word_range is_tighter_than_bool].
+ rewrite !Z.leb_refl; reflexivity. }
+ { right; eexists; repeat split; auto.
+ erewrite <-constant_to_scalar_correct by eassumption.
+ subst. cbn.
+ rewrite interp_cast_noop; [ reflexivity | ].
+ eapply has_half_word_range_shiftr with (r:=word_range); auto.
+ cbv [in_word_range is_tighter_than_bool].
+ rewrite !Z.leb_refl; reflexivity. } }
+ Qed.
+
+ Lemma of_straightline_ident_mul_correct t x y g :
+ is_halved x ->
+ is_halved y ->
+ ok_scalar (Pair x y) ->
+ @has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) ->
+ interp (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) =
+ interp (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))).
+ Proof.
+ intros Hx Hy Hok ?; invert Hok; cbn [interp_scalar of_straightline_ident].
+ destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ];
+ rewrite ?Pxlow, ?Pxhigh;
+ destruct (is_halved_cases y Hy ltac:(assumption)) as [ [? [Pylow [Pyhigh Pyi] ] ] | [? [Pylow [Pyhigh Pyi] ] ] ];
+ rewrite ?Pylow, ?Pyhigh;
+ cbn; rewrite Pxi, Pyi; rewrite interp_cast_noop by auto; reflexivity.
+ Qed.
+
Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g :
ok_ident s d x r idc ->
ok_scalar x ->
- interp (of_straightline_ident dummy_arrow constant_to_scalar idc t r x g) =
+ interp (of_straightline_ident dummy_arrow consts idc t r x g) =
interp (g (ident.interp idc (interp_scalar x))).
Proof.
intros.
pose proof wordmax_half_bits_pos.
pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)).
- induction H; cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell invert_upper invert_lower invert_upper' invert_lower'] in *; intros;
+ induction H; try solve [auto using of_straightline_ident_mul_correct];
+ cbn [of_straightline_ident ident.interp ident.gen_interp invert_selm invert_sell] in *; intros;
rewrite ?Z.eqb_refl; cbn [interp interp_ident andb]; try destruct_scalar;
repeat match goal with
| _ => progress (cbn [fst snd interp_scalar] in * )
@@ -9261,16 +9455,12 @@ Module PreFancy.
| _ => rewrite interp_cast_noop by assumption
| _ => rewrite interp_cast2_noop by assumption
| _ => reflexivity
- end; [ | | | ]; (* leftover cases are the 4 kinds of multiplication *)
- match goal with
- | H1 : in_word_range (get_range ?x), H2 : in_word_range (get_range ?y) |- _ =>
- extract_ok_scalar' 4 x; extract_ok_scalar' 4 y
- end; rewrite_cast_noop_in_mul; reflexivity.
+ end.
Qed.
Lemma of_straightline_correct {t} (e : expr t) :
ok_expr e ->
- interp (of_straightline dummy_arrow constant_to_scalar e) = Straightline.expr.interp (interp_ident:=@ident.interp) e.
+ interp (of_straightline dummy_arrow consts e) = Straightline.expr.interp (interp_ident:=@ident.interp) e.
Proof.
induction 1; cbn [of_straightline]; intros;
repeat match goal with
@@ -9287,16 +9477,19 @@ Module PreFancy.
Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d))
(var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d :=
- @of_straightline log2wordmax var dummy_arrow (@constant_to_scalar_gen log2wordmax var consts) _ (Straightline.of_Expr e var x dummy_arrow).
+ @of_straightline log2wordmax var dummy_arrow consts _ (Straightline.of_Expr e var x dummy_arrow).
Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d))
(e' : (type.interp s -> Uncurried.expr.expr d))
(x : type.interp s) dummy_arrow :
e type.interp = Abs e' ->
1 < log2wordmax ->
+ log2wordmax mod 2 = 0 ->
Straightline.expr.ok_expr (e' x) ->
- ok_expr log2wordmax (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit))) (e' x)) ->
- (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e (fun _ : type => unit)))%nat ->
+ (forall x0 : Z, In x0 consts -> 0 <= x0 <= 2 ^ log2wordmax - 1) ->
+ ok_expr log2wordmax consts
+ (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e _)) (e' x)) ->
+ (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e _))%nat ->
interp log2wordmax (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x.
Proof.
intro He'; intros; cbv [of_Expr Straightline.of_Expr].
@@ -10542,9 +10735,9 @@ Ltac solve_rmontred_nocache := solve_rop_nocache MontgomeryReduction.rmontred_co
Module Montgomery256.
- Definition N := (2^256-2^224+2^192+2^96-1).
+ Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1).
Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849).
- Definition R := (2^256).
+ Definition R := Eval lazy in (2^256).
Definition machine_wordsize := 256.
Derive montred256
@@ -10552,6 +10745,104 @@ Module Montgomery256.
As montred256_correct.
Proof. Time solve_rmontred machine_wordsize. Time Qed.
+
+ Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) :
+ forall (lo hi : Z),
+ 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N ->
+ MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 (lo, hi) = ((lo + R * hi) * R') mod N.
+ Proof.
+ intros.
+ apply MontgomeryReduction.montred'_correct with (T:=lo + R * hi) (R':=R');
+ try match goal with
+ | |- context[R'] => assumption
+ | |- context [lo] =>
+ try assumption; progress autorewrite with zsimplify cancel_pair; reflexivity
+ end; lazy; try split; congruence.
+ Qed.
+
+ (* Note: If this is not factored out, then for some reason Qed takes forever in montred256_correct_full. *)
+ Lemma montred256_correct_proj2 :
+ forall xy : type.interp (type.prod type.Z type.Z),
+ ZRange.type.option.is_bounded_by
+ (t:=type.prod type.Z type.Z)
+ (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange)
+ xy = true ->
+ expr.Interp (@ident.interp) montred256 xy = MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 xy.
+ Proof. intros; destruct (montred256_correct xy); assumption. Qed.
+
+ Lemma montred256_correct_full R' (R'_correct : Z.equiv_modulo N (R * R') 1) :
+ forall (lo hi : Z),
+ 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N ->
+ expr.interp (@ident.interp) (montred256 type.interp) (lo, hi) = ((lo + R * hi) * R') mod N.
+ Proof.
+ intros.
+ rewrite <-montred'_correct_specialized by assumption.
+ rewrite <-montred256_correct_proj2.
+ { cbv [expr.Interp type.uncurried_domain type.uncurry type.final_codomain].
+ reflexivity. }
+ { cbn. rewrite !andb_true_iff. cbv [R N] in *.
+ repeat split; apply Z.leb_le; omega. }
+ Qed.
+
+ (* TODO : maybe move these ok_expr tactics somewhere else *)
+ Ltac ok_expr_step' :=
+ match goal with
+ | _ => assumption
+ | |- context [PreFancy.ok_ident] => constructor
+ | |- context [PreFancy.ok_scalar] => constructor; try omega
+ | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ]
+ | |- context [PreFancy.is_halved] => constructor
+ | |- context [PreFancy.in_word_range] => lazy; reflexivity
+ | |- context [PreFancy.in_flag_range] => lazy; reflexivity
+ | |- context [PreFancy.get_range] =>
+ cbn [PreFancy.get_range lower upper fst snd ZRange.map]
+ | x : type.interp (type.prod _ _) |- _ => destruct x
+ | |- (_ <=? _)%zrange = true =>
+ match goal with
+ | |- context [PreFancy.get_range_var] =>
+ cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower R N] in *; cbn;
+ apply andb_true_iff; split; apply Z.leb_le
+ | _ => lazy
+ end; omega || reflexivity
+ | |- @eq zrange _ _ => lazy; reflexivity
+ | |- _ <= _ => cbv [machine_wordsize]; omega
+ | |- _ <= _ <= _ => cbv [machine_wordsize]; omega
+ end; intros.
+
+ (* TODO : maybe move these ok_expr tactics somewhere else *)
+ Ltac ok_expr_step :=
+ match goal with
+ | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step'
+ end; intros; cbn [Nat.max].
+
+ Definition montred256_prefancy' := PreFancy.of_Expr machine_wordsize [N;N'] montred256.
+
+ Derive montred256_prefancy
+ SuchThat (montred256_prefancy = montred256_prefancy' type.interp)
+ As montred256_prefancy_eq.
+ Proof. lazy - [type.interp]; reflexivity. Qed.
+
+ Lemma montred256_prefancy_correct R' (R'_correct : Z.equiv_modulo N (R * R') 1) :
+ forall (lo hi : Z) dummy_arrow,
+ 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N ->
+ @PreFancy.interp machine_wordsize type.Z (montred256_prefancy (lo,hi) dummy_arrow) = ((lo + R * hi) * R') mod N.
+ Proof.
+ intros. rewrite montred256_prefancy_eq; cbv [montred256_prefancy'].
+ erewrite PreFancy.of_Expr_correct.
+ { apply montred256_correct_full; assumption. }
+ { reflexivity. }
+ { lazy; reflexivity. }
+ { lazy; reflexivity. }
+ { repeat constructor. }
+ { cbv [In N N']; intros; intuition; subst; cbv; congruence. }
+ { assert (340282366920938463463374607431768211455 * 2 ^ 128 <= 2 ^ machine_wordsize - 1) as shiftl_128_ok by (lazy; congruence).
+ repeat (ok_expr_step; [ ]).
+ ok_expr_step.
+ lazy; congruence.
+ constructor. }
+ { lazy. omega. }
+ Qed.
+
Import PrintingNotations.
Set Printing Width 10000.
@@ -10585,8 +10876,7 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z *
: Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z))
*)
- Definition montred256_prefancy :=
- Eval lazy in (PreFancy.of_Expr machine_wordsize [N;N'] montred256).
+
Import PreFancy.
Import PreFancy.Notations.
Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951).