aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jade Philipoom <jadep@google.com>2018-05-30 21:33:17 +0200
committerGravatar Jade Philipoom <jadep@google.com>2018-05-31 15:13:41 +0200
commite4651284bb30a664ef4ec190dce4b01b02822f53 (patch)
tree682f01b64e813722b520049ab896d00005b57ab6 /src
parente6119c9595326a910d177488bf44aab3cc275e49 (diff)
Proved pre-fancy barrett reduction correct (except 1 admit for bounds
that are correct but for which bounds relaxation loses necessary information) and add explanatory comments.
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v808
1 files changed, 527 insertions, 281 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index bb2255547..c277a23ab 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -8836,9 +8836,9 @@ Module PreFancy.
match e with
| Scalar _ s => Scalar s
| LetInAppIdentZ _ t r idc x f =>
- of_straightline_ident idc t r x (fun y => of_straightline (f y))
+ of_straightline_ident idc t r[0~>wordmax-1]%zrange x (fun y => of_straightline (f y))
| LetInAppIdentZZ _ t r idc x f =>
- of_straightline_ident idc t r x (fun y => of_straightline (f y))
+ of_straightline_ident idc t (r[0~>wordmax-1], r[0~>1])%zrange x (fun y => of_straightline (f y))
end.
End with_var.
@@ -8988,14 +8988,16 @@ Module PreFancy.
(word_range, flag_range)
(ident.Z.add_get_carry_concrete wordmax)
| ok_addc :
- forall c x y : scalar type.Z,
+ forall (c x y : scalar type.Z) outr,
in_flag_range (get_range c) ->
in_word_range (get_range x) ->
in_word_range (get_range y) ->
+ lower outr = 0 ->
+ (0 <= upper (get_range c) + upper (get_range x) + upper (get_range y) <= upper outr \/ outr = word_range) ->
ok_ident _
(type.prod type.Z type.Z)
(Pair (Pair c x) y)
- (word_range, flag_range)
+ (outr, flag_range)
(ident.Z.add_with_get_carry_concrete wordmax)
| ok_sub :
forall x y : scalar type.Z,
@@ -9017,11 +9019,15 @@ Module PreFancy.
(word_range, flag_range)
(ident.Z.sub_with_get_borrow_concrete wordmax)
| ok_rshi :
- forall (x : scalar (type.prod type.Z type.Z)) n,
+ forall (x : scalar (type.prod type.Z type.Z)) n outr,
in_word_range (fst (get_range x)) ->
in_word_range (snd (get_range x)) ->
- 0 <= n < 2 * log2wordmax ->
- ok_ident (type.prod type.Z type.Z) type.Z x word_range (ident.Z.rshi_concrete wordmax n)
+ (* note : using [outr] rather than [word_range] allows for cases where the result has been put in a smaller word size. *)
+ lower outr = 0 ->
+ 0 <= n ->
+ ((0 <= (upper (snd (get_range x)) + upper (fst (get_range x)) * wordmax) / 2^n <= upper outr)
+ \/ outr = word_range) ->
+ ok_ident (type.prod type.Z type.Z) type.Z x outr (ident.Z.rshi_concrete wordmax n)
| ok_selc :
forall (x : scalar (type.prod type.Z type.Z)) (y z : scalar type.Z),
in_flag_range (snd (get_range x)) ->
@@ -9078,14 +9084,16 @@ Module PreFancy.
| ok_of_scalar : forall t s, ok_scalar s -> @ok_expr t (Scalar s)
| ok_letin_z : forall s d r idc x f,
ok_ident _ type.Z x r idc ->
+ (r <=? word_range)%zrange = true ->
ok_scalar x ->
- (forall y, has_range r y -> ok_expr (f y)) ->
+ (forall y, has_range (t:=type.Z) r y -> ok_expr (f y)) ->
ok_expr (@LetInAppIdentZ _ _ s d r idc x f)
| ok_letin_zz : forall s d r idc x f,
- ok_ident _ (type.prod type.Z type.Z) x r idc ->
+ ok_ident _ (type.prod type.Z type.Z) x (r, flag_range) idc ->
+ (r <=? word_range)%zrange = true ->
ok_scalar x ->
- (forall y, has_range r y -> ok_expr (f y)) ->
- ok_expr (@LetInAppIdentZZ _ _ s d r idc x f)
+ (forall y, has_range (t:=type.Z * type.Z) (r, flag_range) y -> ok_expr (f y)) ->
+ ok_expr (@LetInAppIdentZZ _ _ s d (r, flag_range) idc x f)
.
Ltac invert H :=
@@ -9205,14 +9213,25 @@ Module PreFancy.
auto.
Qed.
- Lemma has_word_range_rshi n x y :
+ Lemma has_range_rshi r n x y :
0 <= n ->
- @has_range type.Z word_range (Z.rshi wordmax x y n).
+ 0 <= x ->
+ 0 <= y ->
+ lower r = 0 ->
+ (0 <= (y + x * wordmax) / 2^n <= upper r \/ r = word_range) ->
+ @has_range type.Z r (Z.rshi wordmax x y n).
Proof.
pose proof wordmax_gt_2.
- intros; rewrite Z.rshi_correct by omega.
- match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
- cbn [has_range lower upper]; lia.
+ intros. cbv [has_range].
+ rewrite Z.rshi_correct by omega.
+ match goal with |- context [?x mod ?m] =>
+ pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
+ split; [lia|].
+ intuition.
+ { destruct (Z_lt_dec (upper r) wordmax); [ | lia].
+ rewrite Z.mod_small by (split; Z.zero_bounds; omega).
+ omega. }
+ { subst r. cbn [upper]. omega. }
Qed.
Lemma in_word_range_spec r :
@@ -9417,7 +9436,15 @@ Module PreFancy.
autorewrite with to_div_mod.
match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
rewrite Z.div_between_0_if by omega.
- split; break_match; lia. }
+ match goal with H : _ \/ _ |- _ => destruct H; subst end.
+ { split; break_match; try lia.
+ destruct (Z_lt_dec (upper outr) wordmax).
+ { match goal with |- _ <= ?y mod _ <= ?u =>
+ assert (y <= u) by nia end.
+ rewrite Z.mod_small by omega. omega. }
+ { match goal with|- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
+ omega. } }
+ { split; break_match; cbn; lia. } }
{
autorewrite with to_div_mod.
match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
@@ -9429,7 +9456,12 @@ Module PreFancy.
match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end.
rewrite Z.div_sub_small by omega.
split; break_match; lia. }
- { apply has_word_range_rshi; omega. }
+ { apply has_range_rshi; try nia; [ ].
+ match goal with H : context [upper ?ra + upper ?rb * wordmax] |- context [?a + ?b * wordmax] =>
+ assert ((a + b * wordmax) / 2^n <= (upper ra + upper rb * wordmax) / 2^n) by (apply Z.div_le_mono; Z.zero_bounds; nia)
+ end.
+ match goal with H : _ \/ ?P |- _ \/ ?P => destruct H; [left|tauto] end.
+ split; Z.zero_bounds; nia. }
{ rewrite Z.zselect_correct. break_match; omega. }
{ cbn [interp_scalar fst snd get_range] in *.
rewrite Z.zselect_correct. break_match; omega. }
@@ -9529,20 +9561,34 @@ Module PreFancy.
rewrite !Z.leb_refl; reflexivity. } }
Qed.
- Lemma of_straightline_ident_mul_correct t x y g :
+ Lemma halved_mul_range x y :
+ ok_scalar (Pair x y) ->
+ is_halved x ->
+ is_halved y ->
+ 0 <= interp_scalar x * interp_scalar y < wordmax.
+ Proof.
+ intro Hok; invert Hok. intros.
+ repeat match goal with H : _ |- _ => apply is_halved_has_range in H; [|assumption] end.
+ cbv [has_range lower upper] in *.
+ pose proof half_bits_squared. nia.
+ Qed.
+
+ Lemma of_straightline_ident_mul_correct r t x y g :
is_halved x ->
is_halved y ->
ok_scalar (Pair x y) ->
+ (word_range <=? r)%zrange = true ->
@has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) ->
- @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t word_range (Pair x y) g) =
+ @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t r (Pair x y) g) =
@interp interp_cast _ (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))).
Proof.
- intros Hx Hy Hok ?; invert Hok; cbn [interp_scalar of_straightline_ident].
+ intros Hx Hy Hok ? ?; invert Hok; cbn [interp_scalar of_straightline_ident];
destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ];
rewrite ?Pxlow, ?Pxhigh;
destruct (is_halved_cases y Hy ltac:(assumption)) as [ [? [Pylow [Pyhigh Pyi] ] ] | [? [Pylow [Pyhigh Pyi] ] ] ];
rewrite ?Pylow, ?Pyhigh;
- cbn; rewrite Pxi, Pyi; rewrite interp_cast_noop by auto; reflexivity.
+ cbn; rewrite Pxi, Pyi; assert (0 <= interp_scalar x * interp_scalar y < wordmax) by (auto using halved_mul_range);
+ rewrite interp_cast_noop by (cbv [is_tighter_than_bool] in *; cbn [has_range upper lower] in *; rewrite andb_true_iff in *; intuition; Z.ltb_to_lt; lia); reflexivity.
Qed.
Lemma has_word_range_mod_small x:
@@ -9629,25 +9675,48 @@ Module PreFancy.
replace y with (fst y, snd y) by (destruct y; reflexivity)
end; autorewrite with to_div_mod; solve [repeat (f_equal; try ring)].
- Lemma of_straightline_ident_correct s d t x r (idc : ident.ident s d) g :
+ Fixpoint is_tighter_than_bool_range_type t : range_type t -> range_type t -> bool :=
+ match t with
+ | type.type_primitive type.Z => (fun r1 r2 => (r1 <=? r2)%zrange)
+ | type.prod a b => fun r1 r2 =>
+ (is_tighter_than_bool_range_type a (fst r1) (fst r2))
+ && (is_tighter_than_bool_range_type b (snd r1) (snd r2))
+ | _ => fun _ _ => true
+ end.
+
+ Definition range_ok {t} : range_type t -> Prop :=
+ match t with
+ | type.type_primitive type.Z => fun r => in_word_range r
+ | type.prod type.Z type.Z => fun r => in_word_range (fst r) /\ snd r = flag_range
+ | _ => fun _ => False
+ end.
+
+ Lemma of_straightline_ident_correct s d t x r r' (idc : ident.ident s d) g :
ok_ident s d x r idc ->
+ range_ok r' ->
+ is_tighter_than_bool_range_type d r r' = true ->
ok_scalar x ->
- @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r x g) =
+ @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r' x g) =
@interp interp_cast _ (g (ident.interp idc (interp_scalar x))).
Proof.
intros.
pose proof wordmax_half_bits_pos.
pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)).
- induction H; try solve [auto using of_straightline_ident_mul_correct];
+ match goal with H : ok_ident _ _ _ _ _ |- _ => induction H end;
+ try solve [auto using of_straightline_ident_mul_correct];
+ cbv [is_tighter_than_bool_range_type is_tighter_than_bool range_ok] in *;
cbn [of_straightline_ident ident.interp ident.gen_interp
invert_selm invert_sell] in *;
intros; rewrite ?Z.eqb_refl; cbn [andb];
try match goal with |- context [invert_shift] => break_match end;
cbn [interp interp_ident]; try destruct_scalar;
- repeat match goal with
+ repeat match goal with
| _ => progress (cbn [fst snd interp_scalar] in * )
| _ => progress break_match; [ ]
| _ => progress autorewrite with zsimplify_fast
+ | _ => progress Z.ltb_to_lt
+ | H : _ /\ _ |- _ => destruct H
+ | _ => rewrite andb_true_iff in *
| _ => rewrite interp_cast_noop with (r:=flag_range) in *
by (apply has_flag_range_cc_m'; auto; extract_ok_scalar)
| _ => rewrite interp_cast_noop with (r:=flag_range) in *
@@ -9662,10 +9731,10 @@ Module PreFancy.
by (eapply has_range_loosen;
[apply has_range_interp_scalar; extract_ok_scalar|];
assumption)
- | _ => rewrite interp_cast_noop by assumption
- | _ => rewrite interp_cast2_noop by assumption
+ | _ => rewrite interp_cast_noop by (cbn [has_range fst snd] in *; split; lia)
+ | _ => rewrite interp_cast2_noop by (cbn [has_range fst snd] in *; split; lia)
| _ => reflexivity
- end.
+ end.
Qed.
Lemma of_straightline_correct {t} (e : expr t) :
@@ -9676,12 +9745,15 @@ Module PreFancy.
induction 1; cbn [of_straightline]; intros;
repeat match goal with
| _ => progress cbn [Straightline.expr.interp]
- | _ => rewrite of_straightline_ident_correct by auto
- | _ => rewrite interp_cast_noop by auto using ident_interp_has_range
- | _ => rewrite interp_cast2_noop by auto using ident_interp_has_range
- | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by auto using ident_interp_has_range
+ | _ => erewrite of_straightline_ident_correct
+ by (cbv [range_ok is_tighter_than_bool_range_type];
+ eauto using in_word_range_word_range;
+ try apply andb_true_iff; auto)
+ | _ => rewrite interp_cast_noop by eauto using has_range_loosen, ident_interp_has_range
+ | _ => rewrite interp_cast2_noop by eauto using has_range_loosen, ident_interp_has_range
+ | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by eauto using has_range_loosen, ident_interp_has_range
| _ => reflexivity
- end.
+ end.
Qed.
End proofs.
@@ -9726,10 +9798,12 @@ Module PreFancy.
Proof.
induction 1; intros; cbn [of_straightline interp].
{ apply replace_interp_cast_scalar; auto. }
- { rewrite !of_straightline_ident_correct by auto.
+ { erewrite !of_straightline_ident_correct by (eauto; cbv [range_ok]; apply in_word_range_word_range).
rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto.
eauto using ident_interp_has_range. }
- { rewrite !of_straightline_ident_correct by auto.
+ { erewrite !of_straightline_ident_correct by
+ (eauto; try solve [cbv [range_ok]; split; auto using in_word_range_word_range];
+ cbv [is_tighter_than_bool_range_type]; apply andb_true_iff; split; auto).
rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto.
eauto using ident_interp_has_range. }
Qed.
@@ -10248,6 +10322,34 @@ Module Prod.
(Instr (SUB 0) lo (hi, y)
(Instr ADDM lo (lo, RegZero, RegMod)
(Ret lo))))))).
+
+ (* Barrett reduction -- this is only the "reduce" part, excluding the initial multiplication. *)
+ Definition MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : @Fancy.expr register :=
+ let q1Bottom256 := scratchp1 in
+ let muSelect := scratchp2 in
+ let q2 := scratchp3 in
+ let q2High := scratchp4 in
+ let q2High2 := scratchp5 in
+ let q3 := scratchp1 in
+ let r2 := scratchp2 in
+ let r2High := scratchp3 in
+ let maybeM := scratchp1 in
+ Instr SELM muSelect (RegMuLow, RegZero)
+ (Instr (RSHI 255) q1Bottom256 (xHigh, x)
+ (Mul256x256 q2 q2High q1Bottom256 RegMuLow scratchp5
+ (Instr (RSHI 255) q2High2 (RegZero, xHigh)
+ (Instr (ADD 0) q2High (q2High, q1Bottom256)
+ (Instr (ADDC 0) q2High2 (q2High2, RegZero)
+ (Instr (ADD 0) q2High (q2High, muSelect)
+ (Instr (ADDC 0) q2High2 (q2High2, RegZero)
+ (Instr (RSHI 1) q3 (q2High2, q2High)
+ (Mul256x256 r2 r2High RegMod q3 scratchp4
+ (Instr (SUB 0) muSelect (x, r2)
+ (Instr (SUBC 0) xHigh (xHigh, r2High)
+ (Instr SELL maybeM (RegMod, RegZero)
+ (Instr (SUB 0) q3 (muSelect, maybeM)
+ (Instr ADDM x (q3, RegZero, RegMod)
+ (Ret x))))))))))))))).
End Prod.
Module ProdEquiv.
@@ -10434,6 +10536,74 @@ Module ProdEquiv.
subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128].
lia. }
Qed.
+
+ Lemma mulll_comm rd x y cont cc ctx :
+ ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx.
+ Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
+
+ Lemma mulhh_comm rd x y cont cc ctx :
+ ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx.
+ Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
+
+ Lemma mullh_mulhl rd x y cont cc ctx :
+ ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx.
+ Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
+
+ Lemma add_comm rd x y cont cc ctx :
+ 0 <= ctx x < 2^256 ->
+ 0 <= ctx y < 2^256 ->
+ ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx.
+ Proof.
+ intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm.
+ rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity.
+ Qed.
+
+ Lemma addc_comm rd x y cont cc ctx :
+ 0 <= ctx x < 2^256 ->
+ 0 <= ctx y < 2^256 ->
+ ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx.
+ Proof.
+ intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)).
+ rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity.
+ Qed.
+
+ Ltac push_value_unused :=
+ repeat match goal with
+ | |- ~ In _ _ => cbn; intuition; congruence
+ | _ => apply ProdEquiv.value_unused_overwrite
+ | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ]
+ | _ => apply ProdEquiv.value_unused_ret; congruence
+ end.
+
+ Ltac remember_single_result :=
+ match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] =>
+ let x := fresh "x" in
+ let y := fresh "y" in
+ let Heqx := fresh "Heqx" in
+ remember (Fancy.spec i args cc) as x eqn:Heqx;
+ remember (x mod w) as y
+ end.
+ Ltac step_both_sides :=
+ match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 =>
+ rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2);
+ cbn - [Fancy.interp Fancy.spec];
+ repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence;
+ remember_single_result;
+ lazymatch goal with
+ | |- context [Fancy.spec i _ _] =>
+ let Heqa1 := fresh in
+ let Heqa2 := fresh in
+ remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1;
+ remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2;
+ cbn in Heqa1; cbn in Heqa2;
+ repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence;
+ repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence;
+ let a1 := match type of Heqa1 with _ = ?a1 => a1 end in
+ let a2 := match type of Heqa2 with _ = ?a2 => a2 end in
+ (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2)
+ | _ => idtac
+ end
+ end.
End ProdEquiv.
(* Lemmas to help prove that a fancy and prefancy expression have the
@@ -11073,7 +11243,7 @@ Ltac solve_rbarrett_red_nocache := solve_rop_nocache BarrettReduction.rbarrett_r
Module Barrett256.
- Definition M := (2^256-2^224+2^192+2^96-1).
+ Definition M := Eval lazy in (2^256-2^224+2^192+2^96-1).
Definition machine_wordsize := 256.
Derive barrett_red256
@@ -11081,6 +11251,213 @@ Module Barrett256.
As barrett_red256_correct.
Proof. Time solve_rbarrett_red machine_wordsize. Time Qed.
+ Definition muLow := Eval lazy in (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize).
+ Definition barrett_red256_prefancy' := PreFancy.of_Expr machine_wordsize [M; muLow] barrett_red256.
+
+ Derive barrett_red256_prefancy
+ SuchThat (barrett_red256_prefancy = barrett_red256_prefancy' type.interp)
+ As barrett_red256_prefancy_eq.
+ Proof. lazy - [type.interp]; reflexivity. Qed.
+
+ Lemma barrett_reduce_correct_specialized :
+ forall (xLow xHigh : Z),
+ 0 <= xLow < 2 ^ machine_wordsize ->
+ 0 <= xHigh < M ->
+ BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M.
+ Proof.
+ intros.
+ apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *;
+ try omega;
+ try match goal with
+ | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega
+ end; lazy; try split; congruence.
+ Qed.
+
+ (* Note: If this is not factored out, then for some reason Qed takes forever in barrett_red256_correct_full. *)
+ Lemma barrett_red256_correct_proj2 :
+ forall xy : type.interp (type.prod type.Z type.Z),
+ ZRange.type.option.is_bounded_by
+ (t:=type.prod type.Z type.Z)
+ (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange)
+ xy = true ->
+ expr.Interp (@ident.interp) barrett_red256 xy = app_curried (t:=type.arrow (type.prod type.Z type.Z) type.Z) (fun xy => BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 (fst xy) (snd xy)) xy.
+ Proof. intros; destruct (barrett_red256_correct xy); assumption. Qed.
+ Lemma barrett_red256_correct_proj2' :
+ forall x y : Z,
+ ZRange.type.option.is_bounded_by
+ (t:=type.prod type.Z type.Z)
+ (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange)
+ (x, y) = true ->
+ expr.Interp (@ident.interp) barrett_red256 (x, y) = BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 x y.
+ Proof. intros; rewrite barrett_red256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed.
+
+ Lemma barrett_red256_correct_full :
+ forall (xLow xHigh : Z),
+ 0 <= xLow < 2 ^ machine_wordsize ->
+ 0 <= xHigh < M ->
+ expr.interp (@ident.interp) (barrett_red256 type.interp) (xLow, xHigh) = (xLow + 2 ^ machine_wordsize * xHigh) mod M.
+ Proof.
+ intros.
+ rewrite <-barrett_reduce_correct_specialized by assumption.
+ rewrite <-barrett_red256_correct_proj2'.
+ { cbv [expr.Interp type.uncurried_domain type.uncurry type.final_codomain].
+ reflexivity. }
+ { cbn. rewrite !andb_true_iff. cbv [machine_wordsize M] in *.
+ cbn in *. repeat split; apply Z.leb_le; omega. }
+ Qed.
+
+ (* TODO : maybe move these ok_expr tactics somewhere else *)
+ Ltac ok_expr_step' :=
+ match goal with
+ | _ => assumption
+ | |- _ <= _ <= _ \/ @eq zrange _ _ =>
+ right; lazy; try split; congruence
+ | |- _ <= _ <= _ \/ @eq zrange _ _ =>
+ left; lazy; try split; congruence
+ | |- context [PreFancy.ok_ident] => constructor
+ | |- context [PreFancy.ok_scalar] => constructor; try omega
+ | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ]
+ | |- context [PreFancy.is_halved] => constructor
+ | |- context [PreFancy.in_word_range] => lazy; reflexivity
+ | |- context [PreFancy.in_flag_range] => lazy; reflexivity
+ | |- context [PreFancy.get_range] =>
+ cbn [PreFancy.get_range lower upper fst snd ZRange.map]
+ | x : type.interp (type.prod _ _) |- _ => destruct x
+ | |- (_ <=? _)%zrange = true =>
+ match goal with
+ | |- context [PreFancy.get_range_var] =>
+ cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower machine_wordsize M muLow] in *; cbn;
+ apply andb_true_iff; split; apply Z.leb_le
+ | _ => lazy
+ end; omega || reflexivity
+ | |- @eq zrange _ _ => lazy; reflexivity
+ | |- _ <= _ => cbv [machine_wordsize]; omega
+ | |- _ <= _ <= _ => cbv [machine_wordsize]; omega
+ end; intros.
+
+ (* TODO : maybe move these ok_expr tactics somewhere else *)
+ Ltac ok_expr_step :=
+ match goal with
+ | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step'
+ end; intros; cbn [Nat.max].
+
+ Lemma barrett_red256_prefancy_correct :
+ forall xLow xHigh dummy_arrow,
+ 0 <= xLow < 2 ^ machine_wordsize ->
+ 0 <= xHigh < M ->
+ @PreFancy.interp machine_wordsize (PreFancy.interp_cast_mod machine_wordsize) type.Z (barrett_red256_prefancy (xLow, xHigh) dummy_arrow) = (xLow + 2 ^ machine_wordsize * xHigh) mod M.
+ Proof.
+ intros. rewrite barrett_red256_prefancy_eq; cbv [barrett_red256_prefancy'].
+ erewrite PreFancy.of_Expr_correct.
+ { apply barrett_red256_correct_full; try assumption; reflexivity. }
+ { reflexivity. }
+ { lazy; reflexivity. }
+ { lazy; reflexivity. }
+ { repeat constructor. }
+ { cbv [In M muLow]; intros; intuition; subst; cbv; congruence. }
+ { let r := (eval compute in (2 ^ machine_wordsize)) in
+ replace (2^machine_wordsize) with r in * by reflexivity.
+ cbv [M] in *.
+ assert (lower r[0~>1] = 0) by reflexivity.
+ repeat (ok_expr_step; [ ]).
+ ok_expr_step. { exact admit. (* TODO: the actual bounds on the second argument are lower, but relax_zrange steps have lost that information. *) }
+ repeat (ok_expr_step; [ ]).
+ ok_expr_step.
+ lazy; congruence.
+ constructor.
+ constructor. }
+ { lazy. omega. }
+ Qed.
+
+ Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) :=
+ Fancy.of_Expr 3%positive
+ (fun z => if z =? muLow then Some RegMuLow else if z =? M then Some RegMod else if z =? 0 then Some RegZero else None)
+ [M; muLow]
+ barrett_red256
+ (xLow, xHigh)%positive
+ (fun _ _ => tt)
+ error.
+ Derive barrett_red256_fancy
+ SuchThat (forall xLow xHigh RegMuLow RegMod RegZero,
+ barrett_red256_fancy xLow xHigh RegMuLow RegMod RegZero = barrett_red256_fancy' xLow xHigh RegMuLow RegMod RegZero)
+ As barrett_red256_fancy_eq.
+ Proof.
+ intros.
+ lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB Fancy.SUBC
+ Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU
+ Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM].
+ reflexivity.
+ Qed.
+
+ Import Fancy.Registers.
+
+ Definition barrett_red256_alloc' xLow xHigh RegMuLow :=
+ fun errorP errorR =>
+ Fancy.allocate register
+ positive Pos.eqb
+ errorR
+ (barrett_red256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP)
+ [r2;r3;r4;r5;r6;r7;r8;r9;r10;r5;r11;r6;r12;r13;r14;r15;r16;r17;r18;r19;r20;r21;r22;r23;r24;r25;r26;r27;r28;r29]
+ (fun n => if n =? 1000 then xLow
+ else if n =? 1001 then xHigh
+ else if n =? 1002 then RegMuLow
+ else if n =? 1003 then RegMod
+ else if n =? 1004 then RegZero
+ else errorR).
+ Derive barrett_red256_alloc
+ SuchThat (barrett_red256_alloc = barrett_red256_alloc')
+ As barrett_red256_alloc_eq.
+ Proof.
+ intros.
+ cbv [barrett_red256_alloc' barrett_red256_fancy].
+ cbn. subst barrett_red256_alloc.
+ reflexivity.
+ Qed.
+
+ Set Printing Depth 1000.
+ Import ProdEquiv.
+
+ Local Ltac solve_bounds :=
+ match goal with
+ | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega
+ | _ => assumption
+ end.
+
+ Lemma barrett_red256_alloc_equivalent errorP errorR cc_start_state start_context :
+ forall x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg,
+ NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] ->
+ 0 <= start_context x < 2^machine_wordsize ->
+ 0 <= start_context xHigh < 2^machine_wordsize ->
+ 0 <= start_context RegMuLow < 2^machine_wordsize ->
+ ProdEquiv.interp256 (barrett_red256_alloc r0 r1 r30 errorP errorR) cc_start_state
+ (fun r => if reg_eqb r r0
+ then start_context x
+ else if reg_eqb r r1
+ then start_context xHigh
+ else if reg_eqb r r30
+ then start_context RegMuLow
+ else start_context r)
+ = ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context.
+ Proof.
+ intros.
+ let r := eval compute in (2^machine_wordsize) in
+ replace (2^machine_wordsize) with r in * by reflexivity.
+ cbv [Prod.MulMod barrett_red256_alloc].
+
+ (* Extract proofs that no registers are equal to each other *)
+ repeat match goal with
+ | H : NoDup _ |- _ => inversion H; subst; clear H
+ | H : ~ In _ _ |- _ => cbv [In] in H
+ | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H
+ | H : ~ False |- _ => clear H
+ end.
+
+ step_both_sides.
+
+ (* TODO: To prove equivalence between these two, we need to either relocate the RSHI instructions so they're in the same places or use instruction commutativity to push them down. *)
+
+ Admitted.
+
Import PrintingNotations.
Set Printing Width 1000.
Open Scope expr_scope.
@@ -11120,10 +11497,6 @@ barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type
: Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z))
*)
- Definition muLow := (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize).
- Definition barrett_red256_prefancy :=
- Eval lazy in (PreFancy.of_Expr machine_wordsize [M;muLow] barrett_red256).
-
Import PreFancy.
Import PreFancy.Notations.
Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951).
@@ -11779,6 +12152,11 @@ Module Montgomery256.
Ltac ok_expr_step' :=
match goal with
| _ => assumption
+ | |- _ <= _ <= _ \/ @eq zrange _ _ =>
+ right; lazy; try split; congruence
+ | |- _ <= _ <= _ \/ @eq zrange _ _ =>
+ left; lazy; try split; congruence
+ | |- lower r[0~>_]%zrange = 0 => reflexivity
| |- context [PreFancy.ok_ident] => constructor
| |- context [PreFancy.ok_scalar] => constructor; try omega
| |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ]
@@ -11873,76 +12251,7 @@ Module Montgomery256.
reflexivity.
Qed.
- (* TODO : move *)
- Lemma mulll_comm rd x y cont cc ctx :
- ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx.
- Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
-
- Lemma mulhh_comm rd x y cont cc ctx :
- ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx.
- Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
-
- Lemma mullh_mulhl rd x y cont cc ctx :
- ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx.
- Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed.
-
- Lemma add_comm rd x y cont cc ctx :
- 0 <= ctx x < 2^256 ->
- 0 <= ctx y < 2^256 ->
- ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx.
- Proof.
- intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm.
- rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity.
- Qed.
-
- Lemma addc_comm rd x y cont cc ctx :
- 0 <= ctx x < 2^256 ->
- 0 <= ctx y < 2^256 ->
- ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx.
- Proof.
- intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)).
- rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity.
- Qed.
-
- SearchAbout reg_eqb.
-
- Local Ltac push_value_unused :=
- repeat match goal with
- | |- ~ In _ _ => cbn; intuition; congruence
- | _ => apply ProdEquiv.value_unused_overwrite
- | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ]
- | _ => apply ProdEquiv.value_unused_ret; congruence
- end.
-
- Local Ltac remember_single_result :=
- match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] =>
- let x := fresh "x" in
- let y := fresh "y" in
- let Heqx := fresh "Heqx" in
- remember (Fancy.spec i args cc) as x eqn:Heqx;
- remember (x mod w) as y
- end.
- Local Ltac step_both_sides :=
- match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 =>
- rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2);
- cbn - [Fancy.interp Fancy.spec];
- repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence;
- remember_single_result;
- lazymatch goal with
- | |- context [Fancy.spec i _ _] =>
- let Heqa1 := fresh in
- let Heqa2 := fresh in
- remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1;
- remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2;
- cbn in Heqa1; cbn in Heqa2;
- repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence;
- repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence;
- let a1 := match type of Heqa1 with _ = ?a1 => a1 end in
- let a2 := match type of Heqa2 with _ = ?a2 => a2 end in
- (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2)
- | _ => idtac
- end
- end.
+ Import ProdEquiv.
Local Ltac solve_bounds :=
match goal with
@@ -12170,143 +12479,19 @@ montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z *
End Montgomery256.
-(* Extra-specialized ad-hoc pretty-printing *)
-Module FancyPrintingNotations.
- Export ident.
- Open Scope expr_scope.
- Open Scope ctype_scope.
- Notation "'RegMod'" :=
- (AppIdent
- (primitive 115792089210356248762697446949407573530086143415290314195533631308867097853951)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'RegMuLow'" :=
- (AppIdent
- (primitive 26959946667150639793205513449348445388433292963828203772348655992835)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'RegPinv'" :=
- (AppIdent
- (primitive 115792089210356248768974548684794254293921932838497980611635986753331132366849)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'RegZero'" :=
- (AppIdent
- (primitive 0)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : expr_scope.
- Notation "'c.Lower(RegMod)'" :=
- (AppIdent
- (primitive 79228162514264337593543950335)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'c.Upper(RegMod)'" :=
- (AppIdent
- (primitive 340282366841710300967557013911933812736)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'c.Lower(RegMuLow)'" :=
- (AppIdent
- (primitive 340282366841710300930663525764514709507)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'c.Upper(RegMuLow)'" :=
- (AppIdent
- (primitive 79228162514264337589248983038)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'c.Lower(RegPinv)'" :=
- (AppIdent
- (primitive 79228162514264337593543950337)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'c.Upper(RegPinv)'" :=
- (AppIdent
- (primitive 340282366841710300986003757985643364352)
- TT) (only printing, at level 9) : expr_scope.
- Notation "'uint256'"
- := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : ctype_scope.
- Notation "'uint128'"
- := (r[0 ~> 340282366920938463463374607431768211455]%zrange) : ctype_scope.
- Notation "$ n" := (Var n) (at level 10, format "$ n") : expr_scope.
- Notation "$ n" := (Z.cast _ @@ Var n) (at level 10, format "$ n") : expr_scope.
- Notation "$ n '_lo'" := (fst @@ (Var n))%expr (at level 10, format "$ n _lo") : expr_scope.
- Notation "$ n '_hi'" := (snd @@ (Var n))%expr (at level 10, format "$ n _hi") : expr_scope.
- Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Var n)))%expr (at level 10, format "$ n _lo") : expr_scope.
- Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Var n)))%expr (at level 10, format "$ n _hi") : expr_scope.
- Notation "$ n '_lo'" := (fst @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _lo") : expr_scope.
- Notation "$ n '_hi'" := (snd @@ (Z.cast2 _ @@ Var n))%expr (at level 10, format "$ n _hi") : expr_scope.
- Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _lo") : expr_scope.
- Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _hi") : expr_scope.
- Notation "'c.Mul128x128(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast uint256 @@ (Z.mul @@ (x, y)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' f" :=
- (expr_let n := Z.cast _ @@ (Z.shiftl count @@ (Z.cast uint256 @@ (Z.mul @@ (x, y)))) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : expr_scope.
- Notation "'c.Add256(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add256(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Add128(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Add64(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast uint128 @@ (Z.add @@ (x, y)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Addc256(' '$' n ',' x ',' y ',' z ');' f" :=
- (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc256(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
- Notation "'c.Addc128(' '$' n ',' x ',' y ',' z ');' f" :=
- (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (x, y, z)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc128(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
- Notation "'c.Selc(' '$' n ',' x ',' y ',' z ');' f" :=
- (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (x , y, z)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
- Notation "'c.Selm(' '$' n ',' x ',' y ',' z ');' f" :=
- (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0 ~>1]) @@ ((Z.cc_m_concrete _) @@ x), y, z)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selm(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
- Notation "'c.Sell(' '$' n ',' x ',' y ',' z ');' f" :=
- (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (Z.cast (r[0~>1]) @@ (Z.land 1 @@ _ x), y, z)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Sell(' '$' n ',' x ',' y ',' z ');' ']' '//' f") : expr_scope.
- Notation "'c.Sub(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in
- f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope.
- Notation "'c.Sub(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)) in
- f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope.
- Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" :=
- (Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope.
- Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast _ @@ (Z.shiftr y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Rshi(' '$' n ',' x ',' y ',' m ');' f" :=
- (expr_let n := Z.cast _ @@ (Z.rshi_concrete $R m @@ (x, y)) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Rshi(' '$' n ',' x ',' y ',' m ');' ']' '//' f") : expr_scope.
- Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" :=
- (expr_let n := Z.cast _ @@ (Z.shiftl y @@ (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x))) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
- Notation "'c.Lower128(' '$' n ',' x ');' f" :=
- (expr_let n := Z.cast _ @@ (Z.land 340282366920938463463374607431768211455 @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$' n ',' x ');' ']' '//' f") : expr_scope.
- Notation "'c.Lower(' x ')'"
- := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455 @@ x))
- (at level 10, only printing, format "c.Lower( x )")
- : expr_scope.
- Notation "'c.Upper(' x ')'"
- := (Z.cast uint128 @@ (Z.shiftr 128 @@ x))
- (at level 10, only printing, format "c.Upper( x )")
- : expr_scope.
- Notation "( v << count )"
- := (Z.cast _ @@ (Z.shiftl count @@ v)%expr)
- (format "( v << count )")
- : expr_scope.
- (*
- Notation "( x >> count )"
- := (Z.cast _ @@ (Z.shiftr count @@ x)%expr)
- (format "( x >> count )")
- : expr_scope.
- Notation "x * y"
- := (Z.cast uint256 @@ (Z.mul @@ (x, y)))
- : expr_scope.
- *)
-End FancyPrintingNotations.
-
-
Local Notation "i rd x y ; cont" := (Fancy.Instr i rd (x, y) cont) (at level 40, cont at level 200, format "i rd x y ; '//' cont").
Local Notation "i rd x y z ; cont" := (Fancy.Instr i rd (x, y, z) cont) (at level 40, cont at level 200, format "i rd x y z ; '//' cont").
Import Fancy.Registers.
Import Fancy.
+
+Import Barrett256 Montgomery256.
+
+(*** Montgomery Reduction ***)
+
+(* Status: Code in final form is proven correct modulo admits in compiler portions. *)
+
+(* Montgomery Code : *)
Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.MontRed256.
(*
= fun lo hi y t1 t2 scratch RegPInv : register =>
@@ -12330,46 +12515,107 @@ Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.M
ADDM lo lo RegZero RegMod;
Ret lo
*)
-Import Montgomery256.
+
+(* Uncomment to see proof statement and remaining admitted statements,
+or search for "prod_montred256_correct" to see comments on the proof
+preconditions. *)
+(*
Check Montgomery256.prod_montred256_correct.
-(* Print Assumptions Montgomery256.prod_montred256_correct. *)
+Print Assumptions Montgomery256.prod_montred256_correct.
+*)
+
+(*** Barrett Reduction ***)
-Import FancyPrintingNotations.
-Local Open Scope expr_scope.
+(* Status : Code in "pre-fancy" (stage before final form) is proven
+correct modulo admits in compiler portions and one additional
+admit (see comment in barrett_red256_prefancy_correct). Code in final
+form ("fancy") is generated, but not yet proven equivalent to the
+"pre-fancy" version.
+
+The next step is to, using the same methodology as for Montgomery,
+prove the generated fancy code equivalent to Prod.MulMod and also
+equivalent to the "pre-fancy" code. Once these equivalences are
+proven, then the "pre-fancy" proof will apply to Prod.MulMod.
+ *)
+
+Import PreFancy.
+Import PreFancy.Notations.
+Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951).
+Local Notation "'RegMuLow'" := (Straightline.expr.Primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835).
+
+(* "Prefancy" form : *)
+Print Barrett256.barrett_red256_prefancy.
+(*
+selm@(y, $x₂, RegZero, RegMuLow);
+rshi@(y0, RegZero, $x₂,255);
+rshi@(y1, $x₂, $x₁,255);
+mulhh@(y2, RegMuLow, $y1);
+mulhl@(y3, RegMuLow, $y1);
+mullh@(y4, RegMuLow, $y1);
+mulll@(y5, RegMuLow, $y1);
+add@(y6, $y5, $y4, 128);
+addc@(y7, carry{$y6}, $y2, $y4, -128);
+add@(y8, $y6, $y3, 128);
+addc@(y9, carry{$y8}, $y7, $y3, -128);
+add@(y10, $y1, $y9, 0);
+addc@(y11, carry{$y10}, RegZero, $y0, 0);
+add@(y12, $y, $y10, 0);
+addc@(y13, carry{$y12}, RegZero, $y11, 0);
+rshi@(y14, $y13, $y12,1);
+mulhh@(y15, RegMod, $y14);
+mullh@(y16, RegMod, $y14);
+mulhl@(y17, RegMod, $y14);
+mulll@(y18, RegMod, $y14);
+add@(y19, $y18, $y17, 128);
+addc@(y20, carry{$y19}, $y15, $y17, -128);
+add@(y21, $y19, $y16, 128);
+addc@(y22, carry{$y21}, $y20, $y16, -128);
+sub@(y23, $x₁, $y21, 0);
+subb@(y24, carry{$y23}, $x₂, $y22, 0);
+sell@(y25, $y24, RegZero, RegMod);
+sub@(y26, $y23, $y25, 0);
+addm@(y27, $y26, RegZero, RegMod);
+ret $y27
+ *)
+
+(* Uncomment to see proof statement and remaining admitted statements. *)
+(*
+Check barrett_red256_prefancy_correct.
+Print Assumptions barrett_red256_prefancy_correct.
+ *)
-Print Barrett256.barrett_red256.
+(* "Fancy" code (NOT proven) *)
+Eval cbv beta iota delta [barrett_red256_alloc] in barrett_red256_alloc.
(*
-c.Selm($x0, $x_hi, RegZero, RegMuLow);
-c.Rshi($x1, RegZero, $x_hi, 255);
-c.Rshi($x2, $x_hi, $x_lo, 255);
-c.Mul128x128($x3, c.Upper(RegMuLow), c.Upper($x2));
-c.Mul128x128($x4, c.Upper(RegMuLow), c.Lower($x2));
-c.Mul128x128($x5, c.Lower(RegMuLow), c.Upper($x2));
-c.Mul128x128($x6, c.Lower(RegMuLow), c.Lower($x2));
-c.Add256($x7, (c.Lower($x4) << 128), $x6);
-c.Addc256($x8, $x7_hi, $x3, c.Upper($x5));
-c.Add256($x9, (c.Lower($x5) << 128), $x7_lo);
-c.Addc256($x10, $x9_hi, c.Upper($x4), $x8_lo);
-c.Add256($x11, $x2, $x10_lo);
-c.Addc128($x12, $x11_hi, RegZero, $x1);
-c.Add256($x13, $x0, $x11_lo);
-c.Addc128($x14, $x13_hi, RegZero, $x12_lo);
-c.Rshi($x15, $x14_lo, $x13_lo, 1);
-c.Mul128x128($x16, c.Upper(RegMod), c.Upper($x15));
-c.Mul128x128($x17, c.Lower(RegMod), c.Upper($x15));
-c.Mul128x128($x18, c.Upper(RegMod), c.Lower($x15));
-c.Mul128x128($x19, c.Lower(RegMod), c.Lower($x15));
-c.Add256($x20, (c.Lower($x17) << 128), $x19);
-c.Addc256($x21, $x20_hi, $x16, c.Upper($x18));
-c.Add256($x22, (c.Lower($x18) << 128), $x20_lo);
-c.Addc256($x23, $x22_hi, c.Upper($x17), $x21_lo);
-expr_let x24 := Z.add_get_carry_concrete
- 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@
- (Z.opp @@ $x22_lo, $x_lo) in
-expr_let x25 := Z.add_with_get_carry_concrete
- 115792089237316195423570985008687907853269984665640564039457584007913129639936 @@
- ($x24_hi, Z.opp @@ $x23_lo, $x_hi) in
-c.Sell($x26, $x25_lo, RegZero, RegMod);
-c.Sub($x27, $x24_lo, $x26);
-c.AddM($ret, $x27, RegZero, RegMod);
+ = fun (xLow xHigh RegMuLow : register) (_ : positive) (_ : register) =>
+ SELM r2 RegMuLow RegZero;
+ RSHI 255 r3 RegZero xHigh;
+ RSHI 255 r4 xHigh xLow;
+ MUL128UU r5 RegMuLow r4;
+ MUL128UL r6 r4 RegMuLow;
+ MUL128LU r7 r4 RegMuLow;
+ MUL128LL r8 RegMuLow r4;
+ ADD 128 r9 r8 r7;
+ ADDC (-128) r10 r5 r7;
+ ADD 128 r5 r9 r6;
+ ADDC (-128) r11 r10 r6;
+ ADD 0 r6 r4 r11;
+ ADDC 0 r12 RegZero r3;
+ ADD 0 r13 r2 r6;
+ ADDC 0 r14 RegZero r12;
+ RSHI 1 r15 r14 r13;
+ MUL128UU r16 RegMod r15;
+ MUL128LU r17 r15 RegMod;
+ MUL128UL r18 r15 RegMod;
+ MUL128LL r19 RegMod r15;
+ ADD 128 r20 r19 r18;
+ ADDC (-128) r21 r16 r18;
+ ADD 128 r22 r20 r17;
+ ADDC (-128) r23 r21 r17;
+ SUB 0 r24 xLow r22;
+ SUBC 0 r25 xHigh r23;
+ SELL r26 RegMod RegZero;
+ SUB 0 r27 r24 r26;
+ ADDM r28 r27 RegZero RegMod;
+ Ret r28
*)