aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic
diff options
context:
space:
mode:
Diffstat (limited to 'src/Arithmetic')
-rw-r--r--src/Arithmetic/Core.v81
1 files changed, 57 insertions, 24 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v
index 929aa2260..a33bd394d 100644
--- a/src/Arithmetic/Core.v
+++ b/src/Arithmetic/Core.v
@@ -438,23 +438,32 @@ Module B.
Qed. Hint Rewrite eval_negate_snd : push_basesystem_eval.
Section Carries.
- Context {modulo div:Z->Z->Z}.
+ Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}.
+ Let modulo x y := modulo_cps _ x y id.
+ Let div x y := div_cps _ x y id.
+ Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)}
+ {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}.
Context {div_mod : forall a b:Z, b <> 0 ->
a = b * (div a b) + modulo a b}.
+ Hint Rewrite modulo_cps_id div_cps_id : uncps.
Definition carryterm_cps (w fw:Z) (t:limb) {T} (f:list limb->T) :=
if dec (fst t = w)
then dlet t2 := snd t in
- dlet d2 := div t2 fw in
- dlet m2 := modulo t2 fw in
- f ((w*fw, d2) :: (w, m2) :: @nil limb)
+ div_cps _ t2 fw (fun d2 =>
+ modulo_cps _ t2 fw (fun m2 =>
+ dlet d2 := d2 in
+ dlet m2 := m2 in
+ f ((w*fw, d2) :: (w, m2) :: @nil limb)))
else f [t].
Definition carryterm w fw t := carryterm_cps w fw t id.
Lemma carryterm_cps_id w fw t {T} f :
@carryterm_cps w fw t T f
= f (@carryterm w fw t).
- Proof using Type. cbv [carryterm_cps carryterm Let_In]; prove_id. Qed.
+ Proof using div_cps_id modulo_cps_id.
+ cbv [carryterm_cps carryterm Let_In]; prove_id.
+ Qed.
Hint Opaque carryterm : uncps.
Hint Rewrite carryterm_cps_id : uncps.
@@ -473,7 +482,9 @@ Module B.
Definition carry w fw p := carry_cps w fw p id.
Lemma carry_cps_id w fw p {T} f:
@carry_cps w fw p T f = f (carry w fw p).
- Proof using Type. cbv [carry_cps carry]; prove_id. Qed.
+ Proof using div_cps_id modulo_cps_id.
+ cbv [carry_cps carry]; prove_id.
+ Qed.
Hint Opaque carry : uncps.
Hint Rewrite carry_cps_id : uncps.
@@ -484,12 +495,19 @@ Module B.
End Carries.
End Associational.
+
+ Ltac div_mod_cps_t :=
+ intros; autorewrite with uncps push_id; try reflexivity.
+
Hint Rewrite
- @Associational.carry_cps_id
- @Associational.carryterm_cps_id
@Associational.reduce_cps_id
@Associational.split_cps_id
@Associational.mul_cps_id : uncps.
+ Hint Rewrite
+ @Associational.carry_cps_id
+ @Associational.carryterm_cps_id
+ using div_mod_cps_t : uncps.
+
Module Positional.
Section Positional.
@@ -644,14 +662,20 @@ Module B.
: push_basesystem_eval.
Section Carries.
- Context {modulo div : Z->Z->Z}.
+ Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}.
+ Let modulo x y := modulo_cps _ x y id.
+ Let div x y := div_cps _ x y id.
+ Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)}
+ {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}.
Context {div_mod : forall a b:Z, b <> 0 ->
- a = b * (div a b) + modulo a b}.
+ a = b * (div a b) + modulo a b}.
+ Hint Rewrite modulo_cps_id div_cps_id : uncps.
+
Definition carry_cps {n m} (index:nat) (p:tuple Z n)
{T} (f:tuple Z m->T) :=
to_associational_cps p
(fun P => @Associational.carry_cps
- modulo div
+ modulo_cps div_cps
(weight index)
(weight (S index) / weight index)
P T
@@ -690,7 +714,9 @@ Module B.
Definition chained_carries {n} p idxs := @chained_carries_cps n p idxs _ id.
Lemma chained_carries_id {n} p idxs : forall {T} f,
@chained_carries_cps n p idxs T f = f (chained_carries p idxs).
- Proof using Type. cbv [chained_carries_cps chained_carries]; prove_id. Qed.
+ Proof using modulo_cps_id div_cps_id.
+ cbv [chained_carries_cps chained_carries]; prove_id.
+ Qed.
Hint Opaque chained_carries : uncps.
Hint Rewrite @chained_carries_id : uncps.
@@ -738,10 +764,10 @@ Module B.
(fun P => Associational.reduce_cps s c P
(fun R => from_associational_cps n R f)).
- Definition carry_reduce_cps {n div modulo}
+ Definition carry_reduce_cps {n div_cps modulo_cps}
(s:Z) (c:list limb) (p : tuple Z n)
{T} (f: tuple Z n ->T) :=
- carry_cps (div:=div) (modulo:=modulo) (n:=n) (m:=S n) (pred n) p
+ carry_cps (div_cps:=div_cps) (modulo_cps:=modulo_cps) (n:=n) (m:=S n) (pred n) p
(fun r => reduce_cps (m:=S n) (n:=n) s c r f).
Definition negate_snd_cps {n} (p : tuple Z n)
@@ -820,18 +846,23 @@ Module B.
Section F.
Context {sz:nat} {sz_nonzero : sz<>0%nat} {m :positive}.
Context (weight_divides : forall i : nat, weight (S i) / weight i <> 0).
- Context {modulo div:Z->Z->Z}
- {div_mod : forall a b:Z, b <> 0 ->
+ Context {modulo_cps div_cps:forall {R},Z->Z->(Z->R)->R}.
+ Let modulo x y := modulo_cps _ x y id.
+ Let div x y := div_cps _ x y id.
+ Context {modulo_cps_id : forall R x y f, modulo_cps R x y f = f (modulo x y)}
+ {div_cps_id : forall R x y f, div_cps R x y f = f (div x y)}.
+ Context {div_mod : forall a b:Z, b <> 0 ->
a = b * (div a b) + modulo a b}.
+ Hint Rewrite modulo_cps_id div_cps_id : uncps.
Definition Fencode (x : F m) : tuple Z sz :=
- encode (div:=div) (modulo:=modulo) (F.to_Z x).
+ encode (div_cps:=div_cps) (modulo_cps:=modulo_cps) (F.to_Z x).
Definition Fdecode (x : tuple Z sz) : F m := F.of_Z m (eval x).
Lemma Fdecode_Fencode_id x : Fdecode (Fencode x) = x.
- Proof using div_mod sz_nonzero weight_0 weight_divides weight_nonzero.
- cbv [Fdecode Fencode]; rewrite @eval_encode by auto.
+ Proof using div_mod sz_nonzero weight_0 weight_divides weight_nonzero div_cps_id modulo_cps_id.
+ cbv [Fdecode Fencode]; rewrite @eval_encode by eauto.
apply F.of_Z_to_Z.
Qed.
@@ -935,21 +966,23 @@ Module B.
Positional.opp_cps
.
Hint Rewrite
- @Associational.carry_cps_id
- @Associational.carryterm_cps_id
@Associational.reduce_cps_id
@Associational.split_cps_id
@Associational.mul_cps_id
- @Positional.carry_cps_id
@Positional.from_associational_cps_id
@Positional.place_cps_id
@Positional.add_to_nth_cps_id
@Positional.to_associational_cps_id
- @Positional.chained_carries_id
@Positional.sub_id
@Positional.select_id
: uncps.
Hint Rewrite
+ @Associational.carry_cps_id
+ @Associational.carryterm_cps_id
+ @Positional.carry_cps_id
+ @Positional.chained_carries_id
+ using div_mod_cps_t : uncps.
+ Hint Rewrite
@Associational.eval_mul
@Positional.eval_single
@Positional.eval_unit
@@ -966,7 +999,7 @@ Module B.
@Positional.eval_chained_carries
@Positional.eval_sub
@Positional.eval_select
- using (assumption || vm_decide) : push_basesystem_eval.
+ using (assumption || (div_mod_cps_t; auto) || vm_decide) : push_basesystem_eval.
End B.
(* Modulo and div that do shifts if possible, otherwise normal mod/div *)