diff options
Diffstat (limited to 'src/Arithmetic/Core.v')
-rw-r--r-- | src/Arithmetic/Core.v | 73 |
1 files changed, 48 insertions, 25 deletions
diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v index 430e1c19a..f2d9ee00b 100644 --- a/src/Arithmetic/Core.v +++ b/src/Arithmetic/Core.v @@ -585,9 +585,10 @@ Module B. Context {T : Type}. Fixpoint place_cps (t:limb) (i:nat) (f:nat * Z->T) := - if dec (fst t mod weight i = 0) + Z.eqb_cps (fst t mod weight i) 0 (fun eqb => + if eqb then f (i, let c := fst t / weight i in (c * snd t)%RT) - else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end. + else match i with S i' => place_cps t i' f | O => f (O, fst t * snd t)%RT end). End place_cps. Definition place t i := place_cps t i id. @@ -599,12 +600,13 @@ Module B. Lemma place_cps_in_range (t:limb) (n:nat) : (fst (place_cps t n id) < S n)%nat. - Proof using Type. induction n; simpl; break_match; simpl; omega. Qed. + Proof using Type. induction n; simpl; cbv [Z.eqb_cps]; break_match; simpl; omega. Qed. Lemma weight_place_cps t i : weight (fst (place_cps t i id)) * snd (place_cps t i id) = fst t * snd t. Proof using Type*. - induction i; cbv [id]; simpl place_cps; break_match; + induction i; cbv [id]; simpl place_cps; cbv [Z.eqb_cps]; break_match; + Z.ltb_to_lt; autorewrite with cancel_pair; try match goal with [H:_|-_] => apply Z_div_exact_full_2 in H end; nsatz || auto. @@ -962,38 +964,59 @@ End B. (* Modulo and div that do shifts if possible, otherwise normal mod/div *) Section DivMod. - Definition modulo (a b : Z) : Z := - if dec (2 ^ (Z.log2 b) = b) - then let x := (Z.ones (Z.log2 b)) in (a &' x)%RT - else Z.modulo a b. - - Definition div (a b : Z) : Z := - if dec (2 ^ (Z.log2 b) = b) - then let x := Z.log2 b in (a >> x)%RT - else Z.div a b. - - Lemma div_correct a b : div a b = Z.div a b. + Definition modulo_cps {T} (a b : Z) (f : Z -> T) : T := + Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => + if eqb + then let x := (Z.ones (Z.log2 b)) in f (a &' x)%RT + else f (Z.modulo a b)). + + Definition div_cps {T} (a b : Z) (f : Z -> T) : T := + Z.eqb_cps (2 ^ (Z.log2 b)) b (fun eqb => + if eqb + then let x := Z.log2 b in f ((a >> x)%RT) + else f (Z.div a b)). + + Definition modulo (a b : Z) : Z := modulo_cps a b id. + Definition div (a b : Z) : Z := div_cps a b id. + + Lemma modulo_id {T} a b f + : @modulo_cps T a b f = f (modulo a b). + Proof. cbv [modulo_cps modulo]; autorewrite with uncps; break_match; reflexivity. Qed. + Hint Opaque modulo : uncps. + Hint Rewrite @modulo_id : uncps. + + Lemma div_id {T} a b f + : @div_cps T a b f = f (div a b). + Proof. cbv [div_cps div]; autorewrite with uncps; break_match; reflexivity. Qed. + Hint Opaque div : uncps. + Hint Rewrite @div_id : uncps. + + Lemma div_cps_correct {T} a b f : @div_cps T a b f = f (Z.div a b). Proof. - cbv [div]; intros. break_match; try reflexivity. + cbv [div_cps Z.eqb_cps]; intros. break_match; try reflexivity. rewrite Z.shiftr_div_pow2 by apply Z.log2_nonneg. - congruence. + Z.ltb_to_lt; congruence. Qed. - Lemma modulo_correct a b : modulo a b = Z.modulo a b. + Lemma modulo_cps_correct {T} a b f : @modulo_cps T a b f = f (Z.modulo a b). Proof. - cbv [modulo]; intros. break_match; try reflexivity. + cbv [modulo_cps Z.eqb_cps]; intros. break_match; try reflexivity. rewrite Z.land_ones by apply Z.log2_nonneg. - congruence. + Z.ltb_to_lt; congruence. Qed. + Definition div_correct a b : div a b = Z.div a b := div_cps_correct a b id. + Definition modulo_correct a b : modulo a b = Z.modulo a b := modulo_cps_correct a b id. + Lemma div_mod a b (H:b <> 0) : a = b * div a b + modulo a b. Proof. - cbv [div modulo]; intros. break_match; auto using Z.div_mod. - rewrite Z.land_ones, Z.shiftr_div_pow2 by apply Z.log2_nonneg. - pose proof (Z.div_mod a b H). congruence. + rewrite div_correct, modulo_correct; auto using Z.div_mod. Qed. End DivMod. +Hint Opaque div modulo : uncps. +Hint Rewrite @div_id @modulo_id : uncps. + Import B. Create HintDb basesystem_partial_evaluation_unfolder. @@ -1045,7 +1068,7 @@ Hint Unfold Positional.eval_from Positional.select_cps Positional.select - modulo div + modulo div modulo_cps div_cps id_tuple_with_alt id_tuple'_with_alt Z.add_get_carry_full Z.add_get_carry_full_cps : basesystem_partial_evaluation_unfolder. @@ -1055,7 +1078,7 @@ Hint Unfold CPSUtil.Tuple.mapi_with_cps CPSUtil.Tuple.mapi_with'_cps CPSUtil.flat_map_cps CPSUtil.on_tuple_cps CPSUtil.fold_right_cps2 Decidable.dec Decidable.dec_eq_Z id_tuple_with_alt id_tuple'_with_alt - Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps + Z.add_get_carry_full Z.add_get_carry_full_cps Z.mul_split Z.mul_split_cps Z.mul_split_cps' : basesystem_partial_evaluation_unfolder. |