aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Core.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Arithmetic/Core.v')
-rw-r--r--src/Arithmetic/Core.v73
1 files changed, 48 insertions, 25 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v
index 430e1c19a..f2d9ee00b 100644
--- a/src/Arithmetic/Core.v
+++ b/src/Arithmetic/Core.v
@@ -585,9 +585,10 @@ Module B.
Context {T : Type}.
Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) :=
- if dec (fst t mod weight i = 0)
+ Z.eqb_cps (fst t mod weight i) 0 (fun eqb =>
+ if eqb
then f (i, let c := fst t / weight i in (c * snd t)%RT)
- else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end.
+ else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end).
End place_cps.
Definition place t i := place_cps t i id.
@@ -599,12 +600,13 @@ Module B.
Lemma place_cps_in_range (t:limb) (n:nat)
: (fst (place_cps t n id) < S n)%nat.
- Proof using Type. induction n; simpl; break_match; simpl; omega. Qed.
+ Proof using Type. induction n; simpl; cbv [Z.eqb_cps]; break_match; simpl; omega. Qed.
Lemma weight_place_cps t i
: weight (fst (place_cps t i id)) * snd (place_cps t i id)
= fst t * snd t.
Proof using Type*.
- induction i; cbv [id]; simpl place_cps; break_match;
+ induction i; cbv [id]; simpl place_cps; cbv [Z.eqb_cps]; break_match;
+ Z.ltb_to_lt;
autorewrite with cancel_pair;
try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end;
nsatz || auto.
@@ -962,38 +964,59 @@ End B.
(* Modulo and div that do shifts if possible, otherwise normal mod/div *)
Section DivMod.
- Definition modulo (a b : Z) : Z :=
- if dec (2 ^ (Z.log2 b) = b)
- then let x := (Z.ones (Z.log2 b)) in (a &' x)%RT
- else Z.modulo a b.
-
- Definition div (a b : Z) : Z :=
- if dec (2 ^ (Z.log2 b) = b)
- then let x := Z.log2 b in (a >> x)%RT
- else Z.div a b.
-
- Lemma div_correct a b : div a b = Z.div a b.
+ Definition modulo_cps {T} (a b : Z) (f : Z -> T) : T :=
+ Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb =>
+ if eqb
+ then let x := (Z.ones (Z.log2 b)) in f (a &' x)%RT
+ else f (Z.modulo a b)).
+
+ Definition div_cps {T} (a b : Z) (f : Z -> T) : T :=
+ Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb =>
+ if eqb
+ then let x := Z.log2 b in f ((a >> x)%RT)
+ else f (Z.div a b)).
+
+ Definition modulo (a b : Z) : Z := modulo_cps a b id.
+ Definition div (a b : Z) : Z := div_cps a b id.
+
+ Lemma modulo_id {T} a b f
+ : @modulo_cps T a b f = f (modulo a b).
+ Proof. cbv [modulo_cps modulo]; autorewrite with uncps; break_match; reflexivity. Qed.
+ Hint Opaque modulo : uncps.
+ Hint Rewrite @modulo_id : uncps.
+
+ Lemma div_id {T} a b f
+ : @div_cps T a b f = f (div a b).
+ Proof. cbv [div_cps div]; autorewrite with uncps; break_match; reflexivity. Qed.
+ Hint Opaque div : uncps.
+ Hint Rewrite @div_id : uncps.
+
+ Lemma div_cps_correct {T} a b f : @div_cps T a b f = f (Z.div a b).
Proof.
- cbv [div]; intros. break_match; try reflexivity.
+ cbv [div_cps Z.eqb_cps]; intros. break_match; try reflexivity.
rewrite Z.shiftr_div_pow2 by apply Z.log2_nonneg.
- congruence.
+ Z.ltb_to_lt; congruence.
Qed.
- Lemma modulo_correct a b : modulo a b = Z.modulo a b.
+ Lemma modulo_cps_correct {T} a b f : @modulo_cps T a b f = f (Z.modulo a b).
Proof.
- cbv [modulo]; intros. break_match; try reflexivity.
+ cbv [modulo_cps Z.eqb_cps]; intros. break_match; try reflexivity.
rewrite Z.land_ones by apply Z.log2_nonneg.
- congruence.
+ Z.ltb_to_lt; congruence.
Qed.
+ Definition div_correct a b : div a b = Z.div a b := div_cps_correct a b id.
+ Definition modulo_correct a b : modulo a b = Z.modulo a b := modulo_cps_correct a b id.
+
Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b.
Proof.
- cbv [div modulo]; intros. break_match; auto using Z.div_mod.
- rewrite Z.land_ones, Z.shiftr_div_pow2 by apply Z.log2_nonneg.
- pose proof (Z.div_mod a b H). congruence.
+ rewrite div_correct, modulo_correct; auto using Z.div_mod.
Qed.
End DivMod.
+Hint Opaque div modulo : uncps.
+Hint Rewrite @div_id @modulo_id : uncps.
+
Import B.
Create HintDb basesystem_partial_evaluation_unfolder.
@@ -1045,7 +1068,7 @@ Hint Unfold
Positional.eval_from
Positional.select_cps
Positional.select
- modulo div
+ modulo div modulo_cps div_cps
id_tuple_with_alt id_tuple'_with_alt
Z.add_get_carry_full Z.add_get_carry_full_cps
: basesystem_partial_evaluation_unfolder.
@@ -1055,7 +1078,7 @@ Hint Unfold
CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2
Decidable.dec Decidable.dec_eq_Z
id_tuple_with_alt id_tuple'_with_alt
- Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps
+ Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps Z.mul_split_cps'
: basesystem_partial_evaluation_unfolder.