aboutsummaryrefslogtreecommitdiff
path: root/src/Arithmetic/Saturated.v
diff options
context:
space:
mode:
Diffstat (limited to 'src/Arithmetic/Saturated.v')
-rw-r--r--src/Arithmetic/Saturated.v285
1 files changed, 285 insertions, 0 deletions
diff --git a/src/Arithmetic/Saturated.v b/src/Arithmetic/Saturated.v
new file mode 100644
index 000000000..cb37fb1f9
--- /dev/null
+++ b/src/Arithmetic/Saturated.v
@@ -0,0 +1,285 @@
+Require Import Coq.Init.Nat.
+Require Import Coq.ZArith.ZArith.
+Require Import Coq.Lists.List.
+Local Open Scope Z_scope.
+
+Require Import Crypto.Algebra.Nsatz.
+Require Import Crypto.Arithmetic.Core.
+Require Import Crypto.Util.LetIn Crypto.Util.CPSUtil.
+Require Import Crypto.Util.Tuple Crypto.Util.ListUtil.
+Require Import Crypto.Util.Tactics.BreakMatch.
+Local Notation "A ^ n" := (tuple A n) : type_scope.
+
+(***
+
+Arithmetic on bignums that handles carry bits; this is useful for
+saturated limbs. Compatible with mixed-radix bases.
+
+ ***)
+
+Module Columns.
+ Section Columns.
+ Context {weight : nat->Z}
+ {weight_0 : weight 0%nat = 1}
+ {weight_nonzero : forall i, weight i <> 0}
+ {weight_multiples : forall i, weight (S i) mod weight i = 0}
+ (* add_get_carry takes in a number at which to split output *)
+ {add_get_carry: Z ->Z -> Z -> (Z * Z)}
+ {add_get_carry_correct : forall s x y,
+ fst (add_get_carry s x y) = x + y - s * snd (add_get_carry s x y)}
+ .
+
+ Definition eval {n} (x : (list Z)^n) : Z :=
+ B.Positional.eval weight (Tuple.map sum x).
+
+ Definition eval_from {n} (offset:nat) (x : (list Z)^n) : Z :=
+ B.Positional.eval (fun i => weight (i+offset)) (Tuple.map sum x).
+
+ Lemma eval_from_0 {n} x : @eval_from n 0 x = eval x.
+ Proof using Type. cbv [eval_from eval]. auto using B.Positional.eval_wt_equiv. Qed.
+
+ Lemma eval_from_S {n}: forall i (inp : (list Z)^(S n)),
+ eval_from i inp = eval_from (S i) (tl inp) + weight i * sum (hd inp).
+ Proof using Type.
+ intros; cbv [eval_from].
+ replace inp with (append (hd inp) (tl inp))
+ by (simpl in *; destruct n; destruct inp; reflexivity).
+ rewrite map_append, B.Positional.eval_step, hd_append, tl_append.
+ autorewrite with natsimplify; ring_simplify; rewrite Group.cancel_left.
+ apply B.Positional.eval_wt_equiv; intros; f_equal; omega.
+ Qed.
+
+ (* Sums a list of integers using carry bits.
+ Output : next index, carry, sum
+ *)
+ Fixpoint compact_digit_cps n (digit : list Z) {T} (f:Z * Z->T) :=
+ match digit with
+ | nil => f (0, 0)
+ | x :: nil => f (0, x)
+ | x :: tl =>
+ compact_digit_cps n tl (fun rec =>
+ dlet sum_carry := add_get_carry (weight (S n) / weight n) x (snd rec) in
+ dlet carry' := (fst rec + snd sum_carry)%RT in
+ f (carry', fst sum_carry))
+ end.
+
+ Definition compact_digit n digit := compact_digit_cps n digit id.
+ Lemma compact_digit_id n digit: forall {T} f,
+ @compact_digit_cps n digit T f = f (compact_digit n digit).
+ Proof using Type.
+ induction digit; intros; cbv [compact_digit]; [reflexivity|];
+ simpl compact_digit_cps; break_match; [reflexivity|].
+ rewrite !IHdigit; reflexivity.
+ Qed.
+ Hint Opaque compact_digit : uncps.
+ Hint Rewrite compact_digit_id : uncps.
+
+ Definition compact_step_cps (index:nat) (carry:Z) (digit: list Z)
+ {T} (f:Z * Z->T) :=
+ compact_digit_cps index (carry::digit) f.
+
+ Definition compact_step i c d := compact_step_cps i c d id.
+ Lemma compact_step_id i c d T f :
+ @compact_step_cps i c d T f = f (compact_step i c d).
+ Proof using Type. cbv [compact_step_cps compact_step]; autorewrite with uncps; reflexivity. Qed.
+ Hint Opaque compact_step : uncps.
+ Hint Rewrite compact_step_id : uncps.
+
+ Definition compact_cps {n} (xs : (list Z)^n) {T} (f:Z * Z^n->T) :=
+ mapi_with_cps compact_step_cps 0 xs f.
+
+ Definition compact {n} xs := @compact_cps n xs _ id.
+ Lemma compact_id {n} xs {T} f : @compact_cps n xs T f = f (compact xs).
+ Proof using Type. cbv [compact_cps compact]; autorewrite with uncps; reflexivity. Qed.
+
+ Lemma compact_digit_correct i (xs : list Z) :
+ snd (compact_digit i xs) = sum xs - (weight (S i) / weight i) * (fst (compact_digit i xs)).
+ Proof using add_get_carry_correct weight_0.
+ induction xs; cbv [compact_digit]; simpl compact_digit_cps;
+ cbv [Let_In];
+ repeat match goal with
+ | _ => rewrite add_get_carry_correct
+ | _ => progress (rewrite ?sum_cons, ?sum_nil in * )
+ | _ => progress (autorewrite with uncps push_id in * )
+ | _ => progress (autorewrite with cancel_pair in * )
+ | _ => progress break_match; try discriminate
+ | _ => progress ring_simplify
+ | _ => reflexivity
+ | _ => nsatz
+ end.
+ Qed.
+
+ Definition compact_invariant n i (starter rem:Z) (inp : tuple (list Z) n) (out : tuple Z n) :=
+ B.Positional.eval_from weight i out + weight (i + n) * (rem)
+ = eval_from i inp + weight i*starter.
+
+ Lemma compact_invariant_holds n i starter rem inp out :
+ compact_invariant n (S i) (fst (compact_step_cps i starter (hd inp) id)) rem (tl inp) out ->
+ compact_invariant (S n) i starter rem inp (append (snd (compact_step_cps i starter (hd inp) id)) out).
+ Proof using Type*.
+ cbv [compact_invariant B.Positional.eval_from]; intros.
+ repeat match goal with
+ | _ => rewrite B.Positional.eval_step
+ | _ => rewrite eval_from_S
+ | _ => rewrite sum_cons
+ | _ => rewrite weight_multiples
+ | _ => rewrite Nat.add_succ_l in *
+ | _ => rewrite Nat.add_succ_r in *
+ | _ => (rewrite fst_fst_compact_step in * )
+ | _ => progress ring_simplify
+ | _ => rewrite ZUtil.Z.mul_div_eq_full by apply weight_nonzero
+ | _ => cbv [compact_step_cps] in *;
+ autorewrite with uncps push_id;
+ rewrite compact_digit_correct
+ | _ => progress (autorewrite with natsimplify in * )
+ end.
+ rewrite B.Positional.eval_wt_equiv with (wtb := fun i0 => weight (i0 + S i)) by (intros; f_equal; try omega).
+ nsatz.
+ Qed.
+
+ Lemma compact_invariant_base i rem : compact_invariant 0 i rem rem tt tt.
+ Proof using Type. cbv [compact_invariant]. simpl. repeat (f_equal; try omega). Qed.
+
+ Lemma compact_invariant_end {n} start (input : (list Z)^n):
+ compact_invariant n 0%nat start (fst (mapi_with_cps compact_step_cps start input id)) input (snd (mapi_with_cps compact_step_cps start input id)).
+ Proof using Type*.
+ autorewrite with uncps push_id.
+ apply (mapi_with_invariant _ compact_invariant
+ compact_invariant_holds compact_invariant_base).
+ Qed.
+
+ Lemma eval_compact {n} (xs : tuple (list Z) n) :
+ B.Positional.eval weight (snd (compact xs)) + (weight n * fst (compact xs)) = eval xs.
+ Proof using Type*.
+ pose proof (compact_invariant_end 0 xs) as Hinv.
+ cbv [compact_invariant] in Hinv.
+ simpl in Hinv. autorewrite with zsimplify natsimplify in Hinv.
+ rewrite eval_from_0, B.Positional.eval_from_0 in Hinv; apply Hinv.
+ Qed.
+
+ Definition cons_to_nth_cps {n} i (x:Z) (t:(list Z)^n)
+ {T} (f:(list Z)^n->T) :=
+ @on_tuple_cps _ _ nil (update_nth_cps i (cons x)) n n t _ f.
+
+ Definition cons_to_nth {n} i x t := @cons_to_nth_cps n i x t _ id.
+ Lemma cons_to_nth_id {n} i x t T f :
+ @cons_to_nth_cps n i x t T f = f (cons_to_nth i x t).
+ Proof using Type.
+ cbv [cons_to_nth_cps cons_to_nth].
+ assert (forall xs : list (list Z), length xs = n ->
+ length (update_nth_cps i (cons x) xs id) = n) as Hlen.
+ { intros. autorewrite with uncps push_id distr_length. assumption. }
+ rewrite !on_tuple_cps_correct with (H:=Hlen)
+ by (intros; autorewrite with uncps push_id; reflexivity). reflexivity.
+ Qed.
+ Hint Opaque cons_to_nth : uncps.
+ Hint Rewrite @cons_to_nth_id : uncps.
+
+ Lemma map_sum_update_nth l : forall i x,
+ List.map sum (update_nth i (cons x) l) =
+ update_nth i (Z.add x) (List.map sum l).
+ Proof using Type.
+ induction l; intros; destruct i; simpl; rewrite ?IHl; reflexivity.
+ Qed.
+
+ Lemma cons_to_nth_add_to_nth n i x t :
+ map sum (@cons_to_nth n i x t) = B.Positional.add_to_nth i x (map sum t).
+ Proof using weight.
+ cbv [B.Positional.add_to_nth B.Positional.add_to_nth_cps cons_to_nth cons_to_nth_cps on_tuple_cps].
+ induction n; [simpl; rewrite !update_nth_cps_correct; reflexivity|].
+ specialize (IHn (tl t)). autorewrite with uncps push_id in *.
+ apply to_list_ext. rewrite <-!map_to_list.
+ erewrite !from_list_default_eq, !to_list_from_list.
+ rewrite map_sum_update_nth. reflexivity.
+ Unshelve.
+ distr_length.
+ distr_length.
+ Qed.
+
+ Lemma eval_cons_to_nth n i x t : (i < n)%nat ->
+ eval (@cons_to_nth n i x t) = weight i * x + eval t.
+ Proof using Type.
+ cbv [eval]; intros. rewrite cons_to_nth_add_to_nth.
+ auto using B.Positional.eval_add_to_nth.
+ Qed.
+ Hint Rewrite eval_cons_to_nth using omega : push_basesystem_eval.
+
+ Definition nils n : (list Z)^n := Tuple.repeat nil n.
+
+ Lemma map_sum_nils n : map sum (nils n) = B.Positional.zeros n.
+ Proof using Type.
+ cbv [nils B.Positional.zeros]; induction n; [reflexivity|].
+ change (repeat nil (S n)) with (@nil Z :: repeat nil n).
+ rewrite map_repeat, sum_nil. reflexivity.
+ Qed.
+
+ Lemma eval_nils n : eval (nils n) = 0.
+ Proof using Type. cbv [eval]. rewrite map_sum_nils, B.Positional.eval_zeros. reflexivity. Qed. Hint Rewrite eval_nils : push_basesystem_eval.
+
+ Definition from_associational_cps n (p:list B.limb)
+ {T} (f:(list Z)^n -> T) :=
+ fold_right_cps
+ (fun t st =>
+ B.Positional.place_cps weight t (pred n)
+ (fun p=> cons_to_nth_cps (fst p) (snd p) st id))
+ (nils n) p f.
+
+ Definition from_associational n p := from_associational_cps n p id.
+ Lemma from_associational_id n p T f :
+ @from_associational_cps n p T f = f (from_associational n p).
+ Proof using Type.
+ cbv [from_associational_cps from_associational].
+ autorewrite with uncps push_id; reflexivity.
+ Qed.
+ Hint Opaque from_associational : uncps.
+ Hint Rewrite from_associational_id : uncps.
+
+ Lemma eval_from_associational n p (n_nonzero:n<>0%nat):
+ eval (from_associational n p) = B.Associational.eval p.
+ Proof using weight_0 weight_nonzero.
+ cbv [from_associational_cps from_associational]; induction p;
+ autorewrite with uncps push_id push_basesystem_eval; [reflexivity|].
+ pose proof (B.Positional.weight_place_cps weight weight_0 weight_nonzero a (pred n)).
+ pose proof (B.Positional.place_cps_in_range weight a (pred n)).
+ rewrite Nat.succ_pred in * by assumption. simpl.
+ autorewrite with uncps push_id push_basesystem_eval in *.
+ rewrite eval_cons_to_nth by omega. nsatz.
+ Qed.
+
+ Definition mul_cps {n m} (p q : Z^n) {T} (f : (list Z)^m->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.to_associational_cps weight q
+ (fun Q => B.Associational.mul_cps P Q
+ (fun PQ => from_associational_cps m PQ f))).
+
+ Definition add_cps {n} (p q : Z^n) {T} (f : (list Z)^n->T) :=
+ B.Positional.to_associational_cps weight p
+ (fun P => B.Positional.to_associational_cps weight q
+ (fun Q => from_associational_cps n (P++Q) f)).
+
+ End Columns.
+End Columns.
+
+(*
+(* Just some pretty-printing *)
+Local Notation "fst~ a" := (let (x,_) := a in x) (at level 40, only printing).
+Local Notation "snd~ a" := (let (_,y) := a in y) (at level 40, only printing).
+
+(* Simple example : base 10, multiply two bignums and compact them *)
+Definition base10 i := Eval compute in 10^(Z.of_nat i).
+Eval cbv -[runtime_add runtime_mul Let_In] in
+ (fun adc a0 a1 a2 b0 b1 b2 =>
+ Columns.mul_cps (weight := base10) (n:=3) (a2,a1,a0) (b2,b1,b0) (fun ab => Columns.compact (n:=5) (add_get_carry:=adc) (weight:=base10) ab)).
+
+(* More complex example : base 2^56, 8 limbs *)
+Definition base2pow56 i := Eval compute in 2^(56*Z.of_nat i).
+Time Eval cbv -[runtime_add runtime_mul Let_In] in
+ (fun adc a0 a1 a2 a3 a4 a5 a6 a7 b0 b1 b2 b3 b4 b5 b6 b7 =>
+ Columns.mul_cps (weight := base2pow56) (n:=8) (a7,a6,a5,a4,a3,a2,a1,a0) (b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=15) (add_get_carry:=adc) (weight:=base2pow56) ab)). (* Finished transaction in 151.392 secs *)
+
+(* Mixed-radix example : base 2^25.5, 10 limbs *)
+Definition base2pow25p5 i := Eval compute in 2^(25*Z.of_nat i + ((Z.of_nat i + 1) / 2)).
+Time Eval cbv -[runtime_add runtime_mul Let_In] in
+ (fun adc a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 =>
+ Columns.mul_cps (weight := base2pow25p5) (n:=10) (a9,a8,a7,a6,a5,a4,a3,a2,a1,a0) (b9,b8,b7,b6,b5,b4,b3,b2,b1,b0) (fun ab => Columns.compact (n:=19) (add_get_carry:=adc) (weight:=base2pow25p5) ab)). (* Finished transaction in 97.341 secs *)
+*)