diff options
author | 2018-03-10 20:38:19 -0500 | |
---|---|---|
committer | 2018-03-19 14:17:26 -0400 | |
commit | e8c273b1d8f8a4e11c65868946c2461d61638127 (patch) | |
tree | d31b40ca9e3f058eeb02d0981fdbcfc99552d338 /src | |
parent | 5c598ea988a291a2d861d3f626e0ffd61701ff72 (diff) |
Clean and simplify some code
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 1000 |
1 files changed, 320 insertions, 680 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index bddbcee7f..854993ef4 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -10,6 +10,7 @@ Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil. Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. Require Import Crypto.Arithmetic.PrimeFieldTheorems. Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. Require Import Crypto.Util.ZRange. @@ -25,6 +26,7 @@ Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.Tactics.SplitInContext. Require Import Crypto.Util.Notations. Require Import Crypto.Util.ZUtil.Definitions. Import ListNotations. Local Open Scope Z_scope. @@ -1032,6 +1034,13 @@ Module Ring. Local Notation is_bounded_by2 bounds ls := (let '(a, b) := ls in andb (is_bounded_by bounds a) (is_bounded_by bounds b)). + Lemma length_is_bounded_by bounds ls + : is_bounded_by bounds ls = true -> length ls = length bounds. + Proof. + intro H. + apply fold_andb_map_length in H; congruence. + Qed. + Section ring_goal. Context (weight : nat -> Z) (weight_0 : weight 0%nat = 1) @@ -1134,221 +1143,35 @@ Module Ring. -> is_bounded_by tight_bounds (Interp_rencodev arg) = true /\ Interp_rencodev arg = encodemod arg). - Definition encodedT - := { ls : list Z | is_bounded_by tight_bounds ls = true }. + Local Notation T := (list Z) (only parsing). + Local Notation encoded_ok ls + := (is_bounded_by tight_bounds ls = true) (only parsing). + Local Notation encoded_okf := (fun ls => encoded_ok ls) (only parsing). - Definition list_of_encodedT (v : encodedT) : list _ - := proj1_sig v. - - Lemma length_list_of_encodedT v : List.length (list_of_encodedT v) = n. - Proof. - destruct v as [v H]; cbn in H |- *. - apply fold_andb_map_length in H; rewrite <- H. - assumption. - Qed. - - Definition Zdecode (v : encodedT) - := list_of_encodedT v. - Definition Fdecode (v : encodedT) : F m - := F.of_Z m (Positional.eval weight n (Zdecode v)). - Definition encodedT_eq (x y : encodedT) + Definition Fdecode (v : T) : F m + := F.of_Z m (Positional.eval weight n v). + Definition T_eq (x y : T) := Fdecode x = Fdecode y. - Lemma length_Zdecode v : List.length (Zdecode v) = n. - Proof. apply length_list_of_encodedT. Qed. - - - Local Ltac specialize_from_interp _ := - repeat match goal with - | [ H := ?Interp_rv ?arg, Hc : context[?Interp_rv] |- _ ] - => unique pose proof (Hc arg) - end. - Local Ltac specialize_with_bounded _ := - repeat match goal with - | _ => progress cbv [Zdecode list_of_encodedT] in * - | _ => progress cbn [proj1_sig fst snd] in * - | [ H : context[andb ?x ?y = true] |- _ ] - => rewrite (Bool.andb_true_iff x y) in H - | [ H : _ /\ _ -> _ |- _ ] => specialize (fun a b => H (conj a b)) - | _ => progress destruct_head'_and - | [ H : ?f (proj1_sig ?v) = true -> _ |- _ ] - => specialize (H (proj2_sig v)); hnf in H - | [ H : ?x = true -> _, H' : ?x = true |- _ ] - => specialize (H H'); hnf in H - | [ H : ?x = true |- { pf : ?x = true | _ } ] => exists H - end. - Local Ltac rewrite_interp _ := - repeat match goal with - | [ H := _ |- _ ] => subst H - | [ H : ?x = _ |- context[?x] ] => rewrite H - | _ => rewrite expanding_id_id - | [ |- ?x = ?x ] => reflexivity - | [ |- List.length (proj1_sig _) = _ ] => apply length_list_of_encodedT - end. - - Local Ltac solve_encodedT _ := - specialize_from_interp (); - specialize_with_bounded (); - rewrite_interp (). - - Definition ring_mul_sig - : forall (x y : encodedT), - { v : encodedT - | Zdecode v = carry_mulmod (Zdecode x, Zdecode y) }. - Proof. - simple refine - (fun x y - => let x' := Interp_rrelaxv (proj1_sig x) in - let y' := Interp_rrelaxv (proj1_sig y) in - let v' := Interp_rcarry_mulv (x', y') in - let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } - := _ in - exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). - clear dependent HInterp_rzerov; clear dependent HInterp_ronev. - abstract solve_encodedT (). - Defined. - - Definition ring_add_sig - : forall (x y : encodedT), - { v : encodedT - | Zdecode v = carrymod (addmod (Zdecode x, Zdecode y)) }. - Proof. - simple refine - (fun x y - => let v'' := Interp_raddv (proj1_sig x, proj1_sig y) in - let v' := Interp_rcarryv v'' in - let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } - := _ in - exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). - clear dependent HInterp_rzerov; clear dependent HInterp_ronev. - abstract solve_encodedT (). - Defined. - Definition ring_sub_sig - : forall (x y : encodedT), - { v : encodedT - | Zdecode v = carrymod (submod (Zdecode x, Zdecode y)) }. - Proof. - simple refine - (fun x y - => let v'' := Interp_rsubv (proj1_sig x, proj1_sig y) in - let v' := Interp_rcarryv v'' in - let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } - := _ in - exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). - clear dependent HInterp_rzerov; clear dependent HInterp_ronev. - abstract solve_encodedT (). - Defined. - Definition ring_opp_sig - : forall (x : encodedT), - { v : encodedT - | Zdecode v = carrymod (oppmod (Zdecode x)) }. - Proof. - simple refine - (fun x - => let v'' := Interp_roppv (proj1_sig x) in - let v' := Interp_rcarryv v'' in - let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } - := _ in - exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). - clear dependent HInterp_rzerov; clear dependent HInterp_ronev. - abstract solve_encodedT (). - Defined. - Definition ring_zero_sig - : { v : encodedT - | Zdecode v = zeromod }. - Proof. - simple refine - (let v' := Interp_rzerov in - let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } - := _ in - exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). - cbn; clear -HInterp_rzerov. - abstract (destruct HInterp_rzerov; constructor; assumption). - Defined. - Definition ring_one_sig - : { v : encodedT - | Zdecode v = onemod }. - Proof. - simple refine - (let v' := Interp_ronev in - let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } - := _ in - exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). - cbn; clear -HInterp_ronev. - abstract (destruct HInterp_ronev; cbn; constructor; assumption). - Defined. - Arguments Z.mul !_ !_ . - Definition ring_encode_sig - : forall (x : F m), - { v : encodedT - | Zdecode v = encodemod (F.to_Z x) }. - Proof. - simple refine - (fun v - => let pf1 := _ in - let arg := F.to_Z v in - let Hrencodev' := HInterp_rencodev arg pf1 in - let v' := Interp_rencodev arg in - let pf : { pf0 : _ | Zdecode (exist _ v' pf0) = _ } - := _ in - exist _ (exist _ v' (proj1_sig pf)) (proj2_sig pf)). - { generalize m_eq. - generalize sc_pos. - cbn; clear -v. - abstract ( - intros sc_pos m_eq; - destruct v as [v Hv]; cbn; rewrite m_eq in Hv; - pose proof (Z.mod_pos_bound v _ sc_pos); - rewrite Bool.andb_true_iff; split; Z.ltb_to_lt; lia - ). } - { cbv [Zdecode list_of_encodedT]; cbn [proj1_sig]. - subst v' arg; clearbody Hrencodev'; cbv beta zeta in *. - clear -Hrencodev'. - abstract (constructor; apply Hrencodev'). } - Defined. - - Lemma length_addmod x y - : List.length (addmod (Zdecode x, Zdecode y)) = n. - Proof. - destruct (HInterp_raddv (Zdecode x, Zdecode y)) as [H0 H1]; - [ .. | apply fold_andb_map_length in H0; rewrite <- H1, <- H0; assumption ]. - cbv [Zdecode fst snd list_of_encodedT]. - rewrite (proj2_sig x), (proj2_sig y); reflexivity. - Qed. - Lemma length_submod x y - : List.length (submod (Zdecode x, Zdecode y)) = n. - Proof. - destruct (HInterp_rsubv (Zdecode x, Zdecode y)) as [H0 H1]; - [ .. | apply fold_andb_map_length in H0; rewrite <- H1, <- H0; assumption ]. - cbv [Zdecode fst snd list_of_encodedT]. - rewrite (proj2_sig x), (proj2_sig y); reflexivity. - Qed. - Lemma length_oppmod x - : List.length (oppmod (Zdecode x)) = n. - Proof. - destruct (HInterp_roppv (Zdecode x)) as [H0 H1]; - [ .. | apply fold_andb_map_length in H0; rewrite <- H1, <- H0; assumption ]. - cbv [Zdecode fst snd list_of_encodedT]. - rewrite (proj2_sig x); reflexivity. - Qed. + Definition encodedT := sig encoded_okf. - Definition ring_mul x y := proj1_sig (ring_mul_sig x y). - Definition ring_add x y := proj1_sig (ring_add_sig x y). - Definition ring_sub x y := proj1_sig (ring_sub_sig x y). - Definition ring_opp x := proj1_sig (ring_opp_sig x). - Definition ring_zero := proj1_sig ring_zero_sig. - Definition ring_one := proj1_sig ring_one_sig. - Definition ring_encode x := proj1_sig (ring_encode_sig x). + Definition ring_mul (x y : T) : T + := Interp_rcarry_mulv (Interp_rrelaxv x, Interp_rrelaxv y). + Definition ring_add (x y : T) : T := Interp_rcarryv (Interp_raddv (x, y)). + Definition ring_sub (x y : T) : T := Interp_rcarryv (Interp_rsubv (x, y)). + Definition ring_opp (x : T) : T := Interp_rcarryv (Interp_roppv x). + Definition ring_encode (x : F m) : T := Interp_rencodev (F.to_Z x). Definition GoodT : Prop - := @Hierarchy.ring - encodedT encodedT_eq ring_zero ring_one ring_opp ring_add ring_sub ring_mul - /\ @Ring.is_homomorphism - (F m) eq 1%F F.add F.mul - encodedT encodedT_eq ring_one ring_add ring_mul ring_encode - /\ @Ring.is_homomorphism - encodedT encodedT_eq ring_one ring_add ring_mul - (F m) eq 1%F F.add F.mul + := @subsetoid_ring + (list Z) encoded_okf T_eq + Interp_rzerov Interp_ronev ring_opp ring_add ring_sub ring_mul + /\ @is_subsetoid_homomorphism + (F m) (fun _ => True) eq 1%F F.add F.mul + (list Z) encoded_okf T_eq Interp_ronev ring_add ring_mul ring_encode + /\ @is_subsetoid_homomorphism + (list Z) encoded_okf T_eq Interp_ronev ring_add ring_mul + (F m) (fun _ => True) eq 1%F F.add F.mul Fdecode. Hint Rewrite ->@F.to_Z_add : push_FtoZ. @@ -1356,41 +1179,46 @@ Module Ring. Hint Rewrite ->@F.to_Z_opp : push_FtoZ. Hint Rewrite ->@F.to_Z_of_Z : push_FtoZ. - Local Ltac rewrite_proj2_sig _ := - lazymatch goal with - | [ |- context[proj1_sig ?x] ] => rewrite (proj2_sig x) - end. + Lemma Fm_bounded_alt (x : F m) + : (0 <=? F.to_Z x) && (F.to_Z x <=? 'm - 1) = true. + Proof using m_eq. + clear -m_eq. + destruct x as [x H]; cbn [F.to_Z proj1_sig]. + pose proof (Z.mod_pos_bound x ('m)). + rewrite andb_true_iff; split; Z.ltb_to_lt; lia. + Qed. Lemma Good : GoodT. Proof. - eapply ring_by_isomorphism; intros; [ | reflexivity | .. ]; rewrite F.eq_to_Z_iff; - cbv [F.sub Fdecode]; - autorewrite with push_FtoZ; - pull_Zmod; - rewrite ?Z.add_opp_r. - { cbv [ring_encode]; rewrite_proj2_sig (). - erewrite m_eq, Hencodemod, <- m_eq by assumption. - let A := lazymatch goal with A : F _ |- _ => A end in - destruct A as [v Hv]; cbn; congruence. } - { cbv [ring_zero]; rewrite_proj2_sig (). - erewrite m_eq, Hzeromod, Zmod_0_l by eassumption; reflexivity. } - { cbv [ring_one]; rewrite_proj2_sig (). - cbn; erewrite m_eq, Honemod; reflexivity. } - { cbv [ring_opp]; rewrite_proj2_sig (). - erewrite m_eq, Hcarrymod, Hoppmod - by eauto using length_Zdecode, length_oppmod; - reflexivity. } - { cbv [ring_add]; rewrite_proj2_sig (). - erewrite m_eq, Hcarrymod, Haddmod - by eauto using length_Zdecode, length_addmod; - reflexivity. } - { cbv [ring_sub]; rewrite_proj2_sig (). - erewrite m_eq, Hcarrymod, Hsubmod - by eauto using length_Zdecode, length_submod; - reflexivity. } - { cbv [ring_mul]; rewrite_proj2_sig (). - erewrite m_eq, Hcarry_mulmod - by eauto using length_Zdecode; reflexivity. } + split_and. + eapply subsetoid_ring_by_ring_isomorphism; + cbv [ring_opp ring_add ring_sub ring_mul ring_encode F.sub] in *; + repeat match goal with + | _ => solve [ auto using andb_true_intro, conj with nocore ] + | _ => progress intros + | _ => progress cbn [fst snd] + | [ H : _ |- is_bounded_by _ _ = true ] => apply H + | [ |- _ <-> _ ] => reflexivity + | [ |- _ = _ :> Z ] => first [ reflexivity | rewrite <- m_eq; reflexivity ] + | [ H : context[?x] |- Fdecode ?x = _ ] => rewrite H + | [ H : context[?x _] |- Fdecode (?x _) = _ ] => rewrite H + | _ => progress cbv [Fdecode] + | [ |- _ = _ :> F _ ] => apply F.eq_to_Z_iff + | _ => progress autorewrite with push_FtoZ + | _ => rewrite m_eq + | [ H : context[?x _] |- context[eval (?x _)] ] => rewrite H + | [ H : context[?x] |- context[eval ?x] ] => rewrite H + | [ |- context[List.length ?x] ] + => erewrite (length_is_bounded_by _ x) + by eauto using andb_true_intro, conj with nocore + | [ |- _ = _ :> Z ] + => push_Zmod; reflexivity + | _ => pull_Zmod; rewrite Z.add_opp_r + | _ => rewrite expanding_id_id + | [ |- context[F.to_Z _ mod (_ - _)] ] + => rewrite <- m_eq, F.mod_to_Z + | _ => rewrite <- m_eq; apply Fm_bounded_alt + end. Qed. End ring_goal. End Ring. @@ -1487,11 +1315,19 @@ Module Compilers. | Pair {A B} (a : expr A) (b : expr B) : expr (A * B) | Abs {s d} (f : var s -> expr d) : expr (s -> d). + Definition Expr {ident : type -> type -> Type} t := forall var, @expr ident var t. + + Definition APP {ident s d} (f : Expr (s -> d)) (x : Expr s) : Expr d + := fun var => @App ident var s d (f var) (x var). + Module Export Notations. Bind Scope expr_scope with expr. Delimit Scope expr_scope with expr. + Bind Scope Expr_scope with Expr. + Delimit Scope Expr_scope with Expr. Infix "@" := App : expr_scope. + Infix "@" := APP : Expr_scope. Infix "@@" := AppIdent : expr_scope. Notation "( x , y , .. , z )" := (Pair .. (Pair x%expr y%expr) .. z%expr) : expr_scope. Notation "( )" := TT : expr_scope. @@ -1499,8 +1335,6 @@ Module Compilers. Notation "'λ' x .. y , t" := (Abs (fun x => .. (Abs (fun y => t%expr)) ..)) : expr_scope. End Notations. - Definition Expr {ident : type -> type -> Type} t := forall var, @expr ident var t. - Section unexpr. Context {ident : type -> type -> Type} {var : type -> Type}. @@ -1531,6 +1365,10 @@ Module Compilers. end. Definition Interp {t} (e : Expr t) := interp (e _). + + Definition Interp_APP {s d} (f : @Expr ident (s -> d)) (x : @Expr ident s) + : Interp (f @ x)%Expr = Interp f (Interp x) + := eq_refl. End with_ident. Ltac require_primitive_const term := @@ -4891,6 +4729,7 @@ Module Compilers. Definition partial_reduce_with_bounds1 {s d} (e : @expr (partial.value var) (s -> d)) (b : ZRange.type.interp s) := partial.expr.reify (@partial_reduce_with_bounds1' s d e b). + End partial_reduce. Definition PartialReduce {t} (e : Expr t) : Expr t @@ -5556,7 +5395,7 @@ Module Pipeline. end. Definition BoundsPipeline - (with_dead_code_elimination : bool := true) + (with_dead_code_elimination : bool) relax_zrange {s d} (E : Expr (s -> d)) @@ -5564,6 +5403,9 @@ Module Pipeline. out_bounds : ErrorT (Expr (s -> d)) := let E := PartialReduce E in + (* Note that DCE evaluates the expr with two different [var] + arguments, and so will likely result in a pipeline that is + 2x slower *) let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in let E := ReassociateSmallConstants.Reassociate (2^8) E in let E := CheckedPartialReduceWithBounds1 relax_zrange E arg_bounds out_bounds in @@ -5574,7 +5416,7 @@ Module Pipeline. E. Lemma BoundsPipeline_correct - (*with_dead_code_elimination : bool*) + (with_dead_code_elimination : bool) relax_zrange (Hrelax : forall r r' z : zrange, (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) @@ -5583,13 +5425,13 @@ Module Pipeline. arg_bounds out_bounds rv - (Hrv : BoundsPipeline (*with_dead_code_elimination*) relax_zrange E arg_bounds out_bounds = Success rv) + (Hrv : BoundsPipeline with_dead_code_elimination relax_zrange E arg_bounds out_bounds = Success rv) : forall arg (Harg : ZRange.type.is_bounded_by arg_bounds arg = true), ZRange.type.is_bounded_by out_bounds (Interp rv arg) = true /\ Interp rv arg = Interp E arg. Proof. - cbv [BoundsPipeline] in *; edestruct (CheckedPartialReduceWithBounds1 _ _ _ _) eqn:H. + cbv [BoundsPipeline Let_In] in *; edestruct (CheckedPartialReduceWithBounds1 _ _ _ _) eqn:H. inversion Hrv; subst. { intros; eapply CheckedPartialReduceWithBounds1_Correct in H; [ | eassumption.. ]. destruct H as [H0 H1]. @@ -5598,14 +5440,50 @@ Module Pipeline. { congruence. } Qed. + Definition BoundsPipeline_correct_transT + {s d} + (E : Expr (s -> d)) + arg_bounds + out_bounds + (InterpE : type.interp s -> type.interp d) + (rv : Expr (s -> d)) + := forall arg + (Harg : ZRange.type.is_bounded_by arg_bounds arg = true), + ZRange.type.is_bounded_by out_bounds (Interp rv arg) = true + /\ Interp rv arg = InterpE arg. + + Lemma BoundsPipeline_correct_trans + (with_dead_code_elimination : bool) + relax_zrange + (Hrelax + : forall r r' z : zrange, + (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) + {s d} + (E : Expr (s -> d)) + arg_bounds out_bounds + (InterpE : type.interp s -> type.interp d) + (InterpE_correct + : forall arg + (Harg : ZRange.type.is_bounded_by arg_bounds arg = true), + Interp E arg = InterpE arg) + rv (Hrv : BoundsPipeline with_dead_code_elimination relax_zrange E arg_bounds out_bounds = Success rv) + : BoundsPipeline_correct_transT E arg_bounds out_bounds InterpE rv. + Proof. + intros arg Harg; rewrite <- InterpE_correct by assumption. + eapply @BoundsPipeline_correct; eassumption. + Qed. + Definition BoundsPipelineConst - (with_dead_code_elimination : bool := true) + (with_dead_code_elimination : bool) relax_zrange {t} (E : Expr t) bounds : ErrorT (Expr t) := let E := PartialReduce E in + (* Note that DCE evaluates the expr with two different [var] + arguments, and so will likely result in a pipeline that is + 2x slower *) let E := if with_dead_code_elimination then DeadCodeElimination.EliminateDead E else E in let E := ReassociateSmallConstants.Reassociate (2^8) E in let E := CheckedPartialReduceWithBounds0 relax_zrange E bounds in @@ -5616,7 +5494,7 @@ Module Pipeline. E. Lemma BoundsPipelineConst_correct - (*with_dead_code_elimination : bool*) + (with_dead_code_elimination : bool) relax_zrange (Hrelax : forall r r' z : zrange, (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) @@ -5624,11 +5502,11 @@ Module Pipeline. (E : Expr d) bounds rv - (Hrv : BoundsPipelineConst (*with_dead_code_elimination*) relax_zrange E bounds = Success rv) + (Hrv : BoundsPipelineConst with_dead_code_elimination relax_zrange E bounds = Success rv) : ZRange.type.is_bounded_by bounds (Interp rv) = true /\ Interp rv = Interp E. Proof. - cbv [BoundsPipelineConst] in *; edestruct (CheckedPartialReduceWithBounds0 _ _ _) eqn:H. + cbv [BoundsPipelineConst Let_In] in *; edestruct (CheckedPartialReduceWithBounds0 _ _ _) eqn:H. inversion Hrv; subst. { intros; eapply CheckedPartialReduceWithBounds0_Correct in H; [ | eassumption.. ]. destruct H as [H0 H1]. @@ -5636,6 +5514,34 @@ Module Pipeline. exact admit. (* interp correctness *) } { congruence. } Qed. + + Definition BoundsPipelineConst_correct_transT + {t} + (E : Expr t) + out_bounds + (InterpE : type.interp t) + (rv : Expr t) + := ZRange.type.is_bounded_by out_bounds (Interp rv) = true + /\ Interp rv = InterpE. + + Lemma BoundsPipelineConst_correct_trans + (with_dead_code_elimination : bool) + relax_zrange + (Hrelax + : forall r r' z : zrange, + (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true) + {t} + (E : Expr t) + out_bounds + (InterpE : type.interp t) + (InterpE_correct : Interp E = InterpE) + rv + (Hrv : BoundsPipelineConst with_dead_code_elimination relax_zrange E out_bounds = Success rv) + : BoundsPipelineConst_correct_transT E out_bounds InterpE rv. + Proof. + rewrite <- InterpE_correct. + eapply @BoundsPipelineConst_correct; eassumption. + Qed. End Pipeline. Definition round_up_bitwidth_gen (possible_values : list Z) (bitwidth : Z) : option Z @@ -5797,401 +5703,146 @@ Section rcarry_mul. then Pipeline.Error (Pipeline.Value_not_lt "0 < machine_wordsize" 0 machine_wordsize) else res. - Lemma check_args_success_id {T} {rv : T} {res} - : check_args res = Pipeline.Success rv - -> res = Pipeline.Success rv. - Proof. - cbv [check_args]; break_innermost_match; congruence. - Qed. - - Local Ltac solve_correct_gen pipeline_lem gen_correct := - let Hrv := lazymatch goal with H : ?rop = Pipeline.Success _ |- _ => H end in - let rop := lazymatch type of Hrv with ?rop = Pipeline.Success _ => rop end in - hnf; intros; cbv [rop] in Hrv; - eapply pipeline_lem in Hrv; [ | eassumption.. ]; - let Hrv' := fresh "Hrv'" in - destruct Hrv as [Hrv Hrv']; - apply conj; [ exact Hrv | rewrite Hrv' ]; - repeat match goal with H := _ |- _ => subst H end; - erewrite <- gen_correct; - cbv [expr.Interp]; - cbn [expr.interp]; - f_equal; - cbn -[reify_list]; - try (rewrite interp_reify_list, map_map; cbn; - erewrite map_ext with (g:=id), map_id; try reflexivity); - try (intros []; reflexivity). - Local Ltac solve_correct gen_correct := - solve_correct_gen Pipeline.BoundsPipeline_correct gen_correct. - Local Ltac solve_correct_const gen_correct := - solve_correct_gen Pipeline.BoundsPipelineConst_correct gen_correct. - - (* TODO(jgross): open bug about sensitivity of order of arguments on type inference with interp, expr *) - (* TODO(jgross): Make @ apply to Expr, not just expr *) + Local Ltac t_solve_interp := + try solve [ reflexivity + | cbn -[reify_list]; + try (rewrite interp_reify_list, map_map; cbn; + erewrite map_ext with (g:=id), map_id; try reflexivity); + try (intros []; reflexivity) ]. + Lemma Interp_rs : Interp rs = s. Proof. reflexivity. Qed. + Lemma Interp_rc : Interp rc = c. Proof. t_solve_interp. Qed. + Lemma Interp_rn : Interp rn = n. Proof. reflexivity. Qed. + Lemma Interp_ridxs : Interp ridxs = idxs. Proof. t_solve_interp. Qed. + + Local Hint Rewrite @Interp_APP : interp_correct. + Local Hint Rewrite Interp_rs Interp_rc Interp_rn Interp_ridxs : interp_correct. + Local Hint Rewrite carry_mul_gen_correct carry_gen_correct id_gen_correct add_gen_correct sub_gen_correct opp_gen_correct encode_gen_correct zero_gen_correct one_gen_correct : interp_correct. + + Local Ltac do_interp_correct := + intros; repeat autorewrite with interp_correct; reflexivity. + + Notation type_of := ((fun T (_ : T) => T) _). + Notation type_of_strip_arrow := ((fun s (d : Prop) (_ : s -> d) => d) _ _). + + Notation BoundsPipeline rop in_bounds out_bounds + := (Pipeline.BoundsPipeline + false + relax_zrange + rop%Expr in_bounds out_bounds). + + Notation BoundsPipelineConst rop out_bounds + := (Pipeline.BoundsPipelineConst + false + relax_zrange + rop%Expr out_bounds). + + Notation BoundsPipeline_correct rop in_bounds out_bounds op + := (@Pipeline.BoundsPipeline_correct_trans + false + relax_zrange + relax_zrange_good + _ _ + rop%Expr + in_bounds + out_bounds + op + ltac:(do_interp_correct)). + + Notation BoundsPipelineConst_correct rop out_bounds op + := (@Pipeline.BoundsPipelineConst_correct_trans + false + relax_zrange + relax_zrange_good + _ + rop%Expr + out_bounds + op + ltac:(do_interp_correct)). + + (* N.B. We only need [rcarry_mul] if we want to extract the Pipeline; otherwise we can just use [rcarry_mul_correct] *) Definition rcarry_mul - := let res := Pipeline.BoundsPipeline - relax_zrange - (fun var - => (carry_mul_gen _) - @ (rw _) - @ (rs _) - @ (rc _) - @ (rn _) - @ (rlen_c _) - @ (ridxs _) - @ (rlen_idxs _) - )%expr - (loose_bounds, loose_bounds) - tight_bounds in - res. - - Definition rcarry_mul_correctT - (rv : Expr (type.list type.Z * type.list type.Z -> type.list type.Z)) - := (forall - arg - (Harg : ZRange.type.is_bounded_by (t:=type.prod (type.list type.Z) (type.list type.Z)) - (loose_bounds, loose_bounds) arg = true), - ZRange.type.is_bounded_by (t:=type.list type.Z) - tight_bounds - (Interp rv arg) = true /\ - Interp rv arg = - carry_mulmod (Interp rw) s c n (Interp rlen_c) - idxs (Interp rlen_idxs) arg). - Check (@Pipeline.BoundsPipeline_correct - relax_zrange - _ - _ _ - ). - - Lemma rcarry_mul_correct - rv (Hrv : rcarry_mul = Pipeline.Success rv) - : rcarry_mul_correctT rv. - Proof. solve_correct carry_mul_gen_correct. Qed. - - - - Let BoundsPipeline21 in_bounds out_bounds res - := let res := Pipeline.BoundsPipeline - relax_zrange - (s:=(type.list type.Z * type.list type.Z)%ctype) - (d:=(type.list type.Z)%ctype) - res - (in_bounds, in_bounds) - out_bounds in - res. - - Let BoundsPipeline11 in_bounds out_bounds res - := let res := Pipeline.BoundsPipeline - relax_zrange - (s:=(type.list type.Z)%ctype) - (d:=(type.list type.Z)%ctype) - res - (in_bounds) - out_bounds in - res. - - Definition rexpr_1_correctT_Interp - {t} - t_out - (Interp : Expr t -> _) - out_bounds - (f : type.interp t_out) - rv - := (ZRange.type.is_bounded_by out_bounds (Interp rv) = true - /\ Interp rv = f). - - Definition rexpr_n1_correctT - t_in t_out - in_bounds out_bounds - (f : type.interp t_in -> type.interp t_out) - rv - := forall arg - (Harg : ZRange.type.is_bounded_by in_bounds arg = true), - @rexpr_1_correctT_Interp (t_in -> t_out) t_out (fun rv => Interp rv arg) out_bounds (f arg) rv. - - Definition rexpr_1_correctT - t_out - out_bounds - (f : type.interp t_out) - rv - := @rexpr_1_correctT_Interp t_out t_out Interp out_bounds f rv. - - Definition rexpr_21_correctT - in_bounds out_bounds - (f : _ -> type.interp (type.list type.Z)) - rv - := @rexpr_n1_correctT (type.list type.Z * type.list type.Z) - (type.list type.Z) - (in_bounds, in_bounds) out_bounds f rv. - - Definition rexpr_11_correctT - in_bounds out_bounds - (f : _ -> type.interp (type.list type.Z)) - rv - := @rexpr_n1_correctT (type.list type.Z) - (type.list type.Z) - in_bounds out_bounds f rv. - - Definition rexpr_Z1_correctT - in_bounds out_bounds - (f : _ -> type.interp (type.list type.Z)) - rv - := @rexpr_n1_correctT type.Z (type.list type.Z) - in_bounds out_bounds f rv. - - Definition rexpr_01_correctT - out_bounds - (f : type.interp (type.list type.Z)) - rv - := @rexpr_1_correctT (type.list type.Z) out_bounds f rv. - - Definition rcarry_mul_correctT - rv - := Eval hnf in - rexpr_21_correctT - loose_bounds tight_bounds - (carry_mulmod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)) - rv. - Eval cbv [rcarry_mul_correctT rexpr_1_correctT_Interp] in rcarry_mul_correctT. - Print rcarry_mul_correctT. - - Lemma rcarry_mul_correct - rv (Hrv : rcarry_mul = Pipeline.Success rv) - : rcarry_mul_correctT rv. - Proof. solve_correct carry_mul_gen_correct. Qed. - - Definition rcarry - := let res := Pipeline.BoundsPipeline - false - relax_zrange - (s:=(type.list type.Z)%ctype) - (d:=(type.list type.Z)%ctype) - loose_bounds - tight_bounds - (fun var - => (carry_gen _) - @ (rw _) - @ (rs _) - @ (rc _) - @ (rn _) - @ (rlen_c _) - @ (ridxs _) - @ (rlen_idxs _) - )%expr in - res. - - Definition rcarry_correctT - rv - := Eval hnf in - rexpr_11_correctT - loose_bounds tight_bounds - (carrymod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)) - rv. - - Lemma rcarry_correct - rv (Hrv : rcarry = Pipeline.Success rv) - : rcarry_correctT rv. - Proof. solve_correct carry_gen_correct. Qed. - - Definition rrelax - := let res := Pipeline.BoundsPipeline - false - relax_zrange - (s:=(type.list type.Z)%ctype) - (d:=(type.list type.Z)%ctype) - tight_bounds - loose_bounds - (fun var - => (id_gen _) - @ (rn _) - )%expr in - res. - - Definition rrelax_correctT - rv - := Eval hnf in - rexpr_11_correctT - tight_bounds loose_bounds - (expanding_id n) - rv. - - Lemma rrelax_correct - rv (Hrv : rrelax = Pipeline.Success rv) - : rrelax_correctT rv. - Proof. solve_correct id_gen_correct. Qed. - - Definition radd - := let res := BoundsPipeline21 - tight_bounds - loose_bounds - (fun var - => (add_gen _) - @ (rw _) - @ (rn _) - )%expr in - res. - - Definition radd_correctT - rv - := Eval hnf in - rexpr_21_correctT - tight_bounds loose_bounds - (addmod (Interp rw) n) - rv. - - Lemma radd_correct - rv (Hrv : radd = Pipeline.Success rv) - : radd_correctT rv. - Proof. solve_correct add_gen_correct. Qed. - - Definition rsub - := let res := BoundsPipeline21 - tight_bounds - loose_bounds - (fun var - => (sub_gen _) - @ (rw _) - @ (rs _) - @ (rc _) - @ (rn _) - @ (rlen_c _) - @ (rcoef _) - )%expr in - res. - - Definition rsub_correctT - rv - := Eval hnf in - rexpr_21_correctT - tight_bounds loose_bounds - (submod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)) - rv. - - Lemma rsub_correct - rv (Hrv : rsub = Pipeline.Success rv) - : rsub_correctT rv. - Proof. solve_correct sub_gen_correct. Qed. - - Definition ropp - := let res := BoundsPipeline11 - tight_bounds - loose_bounds - (fun var - => (opp_gen _) - @ (rw _) - @ (rs _) - @ (rc _) - @ (rn _) - @ (rlen_c _) - @ (rcoef _) - )%expr in - res. - - Definition ropp_correctT - rv - := Eval hnf in - rexpr_11_correctT - tight_bounds loose_bounds - (oppmod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)) - rv. - - Lemma ropp_correct - rv (Hrv : ropp = Pipeline.Success rv) - : ropp_correctT rv. - Proof. solve_correct opp_gen_correct. Qed. - - Definition rencode - := let res := Pipeline.BoundsPipeline - false - relax_zrange - (s:=type.Z) - (d:=(type.list type.Z)%ctype) - (prime_bound) - tight_bounds - (fun var - => (encode_gen _) - @ (rw _) - @ (rs _) - @ (rc _) - @ (rn _) - @ (rlen_c _) - )%expr in - res. - - Definition rencode_correctT - rv - := Eval hnf in - rexpr_Z1_correctT - prime_bound tight_bounds - (encodemod (Interp rw) s c n (Interp rlen_c)) - rv. - - Lemma rencode_correct - rv (Hrv : rencode = Pipeline.Success rv) - : rencode_correctT rv. - Proof. solve_correct encode_gen_correct. Qed. - - Definition rzero - := let res := Pipeline.BoundsPipelineConst - false - relax_zrange - (t:=(type.list type.Z)%ctype) - tight_bounds - (fun var - => (zero_gen _) - @ (rw _) - @ (rs _) - @ (rc _) - @ (rn _) - @ (rlen_c _) - )%expr in - res. - - Definition rzero_correctT - rv - := Eval hnf in - rexpr_01_correctT - tight_bounds - (zeromod (Interp rw) s c n (Interp rlen_c)) - rv. - - Lemma rzero_correct - rv (Hrv : rzero = Pipeline.Success rv) - : rzero_correctT rv. - Proof. solve_correct_const zero_gen_correct. Qed. - - Definition rone - := let res := Pipeline.BoundsPipelineConst - false - relax_zrange - (t:=(type.list type.Z)%ctype) - tight_bounds - (fun var - => (one_gen _) - @ (rw _) - @ (rs _) - @ (rc _) - @ (rn _) - @ (rlen_c _) - )%expr in - res. - - Definition rone_correctT - rv - := Eval hnf in - rexpr_01_correctT - tight_bounds - (onemod (Interp rw) s c n (Interp rlen_c)) - rv. - - Lemma rone_correct - rv (Hrv : rone = Pipeline.Success rv) - : rone_correctT rv. - Proof. solve_correct_const one_gen_correct. Qed. - - Let m : positive := Z.to_pos (s - Associational.eval c). + := BoundsPipeline + (carry_mul_gen + @ rw @ rs @ rc @ rn @ rlen_c @ ridxs @ rlen_idxs) + (loose_bounds, loose_bounds) + tight_bounds. + + Definition rcarry_mul_correct + := BoundsPipeline_correct + (carry_mul_gen + @ rw @ rs @ rc @ rn @ rlen_c @ ridxs @ rlen_idxs)%Expr + (loose_bounds, loose_bounds) + tight_bounds + (carry_mulmod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)). + + Definition rcarry_correct + := BoundsPipeline_correct + (carry_gen + @ rw @ rs @ rc @ rn @ rlen_c @ ridxs @ rlen_idxs)%Expr + loose_bounds + tight_bounds + (carrymod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)). + + Definition rrelax_correct + := BoundsPipeline_correct + (id_gen @ rn)%Expr + tight_bounds + loose_bounds + (expanding_id n). + + Definition radd_correct + := BoundsPipeline_correct + (add_gen @ rw @ rn)%Expr + (tight_bounds, tight_bounds) + loose_bounds + (addmod (Interp rw) n). + + Definition rsub_correct + := BoundsPipeline_correct + (sub_gen @ rw @ rs @ rc @ rn @ rlen_c @ rcoef)%Expr + (tight_bounds, tight_bounds) + loose_bounds + (submod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)). + + Definition ropp_correct + := BoundsPipeline_correct + (opp_gen @ rw @ rs @ rc @ rn @ rlen_c @ rcoef)%Expr + tight_bounds + loose_bounds + (oppmod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)). + + Definition rencode_correct + := BoundsPipeline_correct + (encode_gen @ rw @ rs @ rc @ rn @ rlen_c)%Expr + prime_bound + tight_bounds + (encodemod (Interp rw) s c n (Interp rlen_c)). + + Definition rzero_correct + := BoundsPipelineConst_correct + (zero_gen @ rw @ rs @ rc @ rn @ rlen_c)%Expr + tight_bounds + (zeromod (Interp rw) s c n (Interp rlen_c)). + + Definition rone_correct + := BoundsPipelineConst_correct + (one_gen @ rw @ rs @ rc @ rn @ rlen_c)%Expr + tight_bounds + (onemod (Interp rw) s c n (Interp rlen_c)). + + (* we need to strip off [Hrv : ... = Pipeline.Success rv] *) + Definition rcarry_mul_correctT rv : Prop := type_of_strip_arrow (@rcarry_mul_correct rv). + Definition rcarry_correctT rv : Prop := type_of_strip_arrow (@rcarry_correct rv). + Definition rrelax_correctT rv : Prop := type_of_strip_arrow (@rrelax_correct rv). + Definition radd_correctT rv : Prop := type_of_strip_arrow (@radd_correct rv). + Definition rsub_correctT rv : Prop := type_of_strip_arrow (@rsub_correct rv). + Definition ropp_correctT rv : Prop := type_of_strip_arrow (@ropp_correct rv). + Definition rencode_correctT rv : Prop := type_of_strip_arrow (@rencode_correct rv). + Definition rzero_correctT rv : Prop := type_of_strip_arrow (@rzero_correct rv). + Definition rone_correctT rv : Prop := type_of_strip_arrow (@rone_correct rv). Section make_ring. + Let m : positive := Z.to_pos (s - Associational.eval c). Context (curve_good : check_args (Pipeline.Success tt) = Pipeline.Success tt) {rcarry_mulv} (Hrmulv : rcarry_mul_correctT rcarry_mulv) {rcarryv} (Hrcarryv : rcarry_correctT rcarryv) @@ -6284,40 +5935,29 @@ Section rcarry_mul. (Interp rw) n s c tight_bounds - length_tight_bounds - loose_bounds - m_eq - sc_pos - (Interp rrelaxv) Hrrelaxv - (carry_mulmod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)) - (Interp rcarry_mulv) Hrmulv - (carrymod (Interp rw) s c n (Interp rlen_c) idxs (Interp rlen_idxs)) - (Interp rcarryv) Hrcarryv - (addmod (Interp rw) n) - (Interp raddv) Hraddv - (submod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)) - (Interp rsubv) Hrsubv - (oppmod (Interp rw) s c n (Interp rlen_c) (Interp rcoef)) - (Interp roppv) Hroppv - (zeromod (Interp rw) s c n (Interp rlen_c)) - (Interp rzerov) Hrzerov - (onemod (Interp rw) s c n (Interp rlen_c)) - (Interp ronev) Hronev - (encodemod (Interp rw) s c n (Interp rlen_c)) - (Interp rencodev) Hrencodev. + (Interp rrelaxv) + (Interp rcarry_mulv) + (Interp rcarryv) + (Interp raddv) + (Interp rsubv) + (Interp roppv) + (Interp rzerov) + (Interp ronev) + (Interp rencodev). Theorem Good : GoodT. Proof. pose proof use_curve_good; destruct_head'_and. - apply Ring.Good; + eapply Ring.Good; repeat first [ assumption + | intros; apply eval_carry_mulmod + | intros; apply eval_carrymod + | intros; apply eval_addmod + | intros; apply eval_submod + | intros; apply eval_oppmod + | intros; apply eval_encodemod + | eassumption | progress intros - | rewrite eval_carry_mulmod - | rewrite eval_carrymod - | rewrite eval_addmod - | rewrite eval_submod - | rewrite eval_oppmod - | rewrite eval_encodemod | progress cbv [onemod zeromod] | match goal with | [ |- ?x = ?x ] => reflexivity @@ -6487,39 +6127,39 @@ Module X25519_64. Definition machine_wordsize := 64. Derive base_51_relax - SuchThat (rrelax_correctT n s c base_51_relax) + SuchThat (rrelax_correctT n s c machine_wordsize base_51_relax) As base_51_relax_correct. Proof. Time solve_rrelax machine_wordsize. Time Qed. Derive base_51_carry_mul - SuchThat (rcarry_mul_correctT n s c base_51_carry_mul) + SuchThat (rcarry_mul_correctT n s c machine_wordsize base_51_carry_mul) As base_51_carry_mul_correct. Proof. Time solve_rcarry_mul machine_wordsize. Time Qed. Derive base_51_carry - SuchThat (rcarry_correctT n s c base_51_carry) + SuchThat (rcarry_correctT n s c machine_wordsize base_51_carry) As base_51_carry_correct. Proof. Time solve_rcarry machine_wordsize. Time Qed. Derive base_51_add - SuchThat (radd_correctT n s c base_51_add) + SuchThat (radd_correctT n s c machine_wordsize base_51_add) As base_51_add_correct. Proof. Time solve_radd machine_wordsize. Time Qed. Derive base_51_sub - SuchThat (rsub_correctT n s c base_51_sub) + SuchThat (rsub_correctT n s c machine_wordsize base_51_sub) As base_51_sub_correct. Proof. Time solve_rsub machine_wordsize. Time Qed. Derive base_51_opp - SuchThat (ropp_correctT n s c base_51_opp) + SuchThat (ropp_correctT n s c machine_wordsize base_51_opp) As base_51_opp_correct. Proof. Time solve_ropp machine_wordsize. Time Qed. Derive base_51_encode - SuchThat (rencode_correctT n s c base_51_encode) + SuchThat (rencode_correctT n s c machine_wordsize base_51_encode) As base_51_encode_correct. Proof. Time solve_rencode machine_wordsize. Time Qed. Derive base_51_zero - SuchThat (rzero_correctT n s c base_51_zero) + SuchThat (rzero_correctT n s c machine_wordsize base_51_zero) As base_51_zero_correct. Proof. Time solve_rzero machine_wordsize. Time Qed. Derive base_51_one - SuchThat (rone_correctT n s c base_51_one) + SuchThat (rone_correctT n s c machine_wordsize base_51_one) As base_51_one_correct. Proof. Time solve_rone machine_wordsize. Time Qed. Lemma base_51_curve_good |