diff options
author | 2015-10-25 17:19:46 -0400 | |
---|---|---|
committer | 2015-10-25 17:19:46 -0400 | |
commit | 59b8aacd7f54aa2bc90f2835182ead7d4641d1fe (patch) | |
tree | b0dd53d9ac5dff6f730edeb67f72e7a4fead23ef /src/Galois/BaseSystem.v | |
parent | beaf00fdf971ee0a8955b5358943562af554a20f (diff) | |
parent | d02084f328a2f7253cfee94ffa01bb845827fd2c (diff) |
Merge https://github.com/mit-plv/fiat-crypto into jade
Diffstat (limited to 'src/Galois/BaseSystem.v')
-rw-r--r-- | src/Galois/BaseSystem.v | 290 |
1 files changed, 218 insertions, 72 deletions
diff --git a/src/Galois/BaseSystem.v b/src/Galois/BaseSystem.v index 3ddd92004..dc41740bc 100644 --- a/src/Galois/BaseSystem.v +++ b/src/Galois/BaseSystem.v @@ -1,27 +1,57 @@ Require Import List. Require Import ZArith.ZArith ZArith.Zdiv. + Require Import Omega. + +Lemma nth_error_map : forall A B (f:A->B) xs i y, + nth_error (map f xs) i = Some y -> + exists x, nth_error xs i = Some x /\ f x = y. +Admitted. + +Lemma nth_error_seq : forall start len i, + nth_error (seq start len) i = + if lt_dec i len + then Some (start + i) + else None. +Admitted. + +Lemma nth_error_length_error : forall A (xs:list A) i, nth_error xs i = None -> + i >= length xs. +Admitted. Local Open Scope Z. -Module Type BaseCoefs. - (* lists coefficients of digits and the digits themselves always have the - * LEAST significant position first. *) - Definition coefs : Type := list Z. +Lemma pos_pow_nat_pos : forall x n, + Z.pos x ^ Z.of_nat n > 0. +Admitted. - Parameter base : coefs. - Axiom bs_good : - forall i j, +Module Type BaseCoefs. + (** [BaseCoefs] represent the weights of each digit in a positional number system, with the weight of least significant digit presented first. The following requirements on the base are preconditions for using it with BaseSystem. *) + Parameter base : list Z. + Axiom base_positive : forall b, In b base -> b > 0. (* nonzero would probably work too... *) + Axiom base_good : + forall i j, (i+j < length base)%nat -> let b := nth_default 0 base in let r := (b i * b j) / b (i+j)%nat in b i * b j = r * b (i+j)%nat. End BaseCoefs. Module BaseSystem (Import B:BaseCoefs). + (** [BaseSystem] implements an constrained positional number system. + A wide variety of bases are supported: the base coefficients are not + required to be powers of 2, and it is NOT necessarily the case that + $b_{i+j} = b_j b_j$. Implementations of addition and multiplication are + provided, with focus on near-optimal multiplication performance on + non-trivial but small operands: maybe 10 32-bit integers or so. This + module does not handle carries automatically: if no restrictions are put + on the use of a [BaseSystem], each digit is unbounded. This has nothing + to do with modular arithmetic either. + *) Definition digits : Type := list Z. Definition accumulate p acc := fst p * snd p + acc. - Definition decode bs u := fold_right accumulate 0 (combine u bs). - Hint Unfold decode accumulate. + Definition decode' bs u := fold_right accumulate 0 (combine u bs). + Definition decode := decode' base. + Hint Unfold decode' accumulate. Fixpoint add (us vs:digits) : digits := match us,vs with @@ -29,15 +59,15 @@ Module BaseSystem (Import B:BaseCoefs). | _, nil => us | _, _ => vs end. - Local Infix ".+" := add (at level 50). + Infix ".+" := add (at level 50). - Lemma add_rep : forall bs us vs, decode bs (add us vs) = decode bs us + decode bs vs. + Lemma add_rep : forall bs us vs, decode' bs (add us vs) = decode' bs us + decode' bs vs. Proof. - unfold decode, accumulate. + unfold decode', accumulate. induction bs; destruct us; destruct vs; auto; simpl; try rewrite IHbs; ring. Qed. - Lemma decode_nil : forall bs, decode bs nil = 0. + Lemma decode_nil : forall bs, decode' bs nil = 0. auto. Qed. @@ -49,9 +79,9 @@ Module BaseSystem (Import B:BaseCoefs). end. Definition mul_each u := map (Z.mul u). - Lemma mul_each_rep : forall bs u vs, decode bs (mul_each u vs) = u * decode bs vs. + Lemma mul_each_rep : forall bs u vs, decode' bs (mul_each u vs) = u * decode' bs vs. Proof. - unfold decode, accumulate. + unfold decode', accumulate. induction bs; destruct vs; auto; simpl; try rewrite IHbs; ring. Qed. @@ -60,8 +90,8 @@ Module BaseSystem (Import B:BaseCoefs). (b(i) * b(j)) / b(i+j)%nat. Fixpoint zeros n := match n with O => nil | S n' => 0::zeros n' end. - Lemma zeros_rep : forall bs n, decode bs (zeros n) = 0. - unfold decode, accumulate. + Lemma zeros_rep : forall bs n, decode' bs (zeros n) = 0. + unfold decode', accumulate. induction bs; destruct n; auto; simpl; try rewrite IHbs; ring. Qed. Lemma length_zeros : forall n, length (zeros n) = n. @@ -119,20 +149,20 @@ Module BaseSystem (Import B:BaseCoefs). *) Lemma decode_single : forall n bs x, - decode bs (zeros n ++ x :: nil) = nth_default 0 bs n * x. + decode' bs (zeros n ++ x :: nil) = nth_default 0 bs n * x. Proof. induction n; intros; simpl. - destruct bs; auto; unfold decode, accumulate, nth_default; simpl; ring. + destruct bs; auto; unfold decode', accumulate, nth_default; simpl; ring. destruct bs; simpl; auto. - unfold decode, accumulate, nth_default in *; simpl in *; auto. + unfold decode', accumulate, nth_default in *; simpl in *; auto. Qed. - Lemma peel_decode : forall xs ys x y, decode (x::xs) (y::ys) = x*y + decode xs ys. + Lemma peel_decode : forall xs ys x y, decode' (x::xs) (y::ys) = x*y + decode' xs ys. intros. - unfold decode, accumulate, nth_default in *; simpl in *; ring_simplify; auto. + unfold decode', accumulate, nth_default in *; simpl in *; ring_simplify; auto. Qed. - Lemma decode_highzeros : forall xs bs n, decode bs (xs ++ zeros n) = decode bs xs. + Lemma decode_highzeros : forall xs bs n, decode' bs (xs ++ zeros n) = decode' bs xs. induction xs; intros; simpl; try rewrite zeros_rep; auto. destruct bs; simpl; auto. repeat (rewrite peel_decode). @@ -140,45 +170,53 @@ Module BaseSystem (Import B:BaseCoefs). Qed. Lemma mul_bi_single : forall m n x, - decode base (mul_bi n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n. + (n + m < length base)%nat -> + decode (mul_bi n (zeros m ++ x :: nil)) = nth_default 0 base m * x * nth_default 0 base n. Proof. - unfold mul_bi. - destruct m; simpl; ssimpl_list; simpl; intros. - rewrite decode_single. - unfold crosscoef; simpl. - rewrite plus_0_r. - ring_simplify. - replace (nth_default 0 base n * nth_default 0 base 0) with (nth_default 0 base 0 * nth_default 0 base n) by ring. - SearchAbout Z.div. - rewrite Z_div_mult; try ring. - - assert (nth_default 0 base n > 0) by admit; auto. - - intros; simpl; ssimpl_list; simpl. - replace (mul_bi' n (rev (zeros m) ++ 0 :: nil)) with (zeros (S m)) by admit. - intros; simpl; ssimpl_list; simpl. - rewrite length_zeros. - rewrite app_cons_app_app. - rewrite rev_zeros. - intros; simpl; ssimpl_list; simpl. - rewrite zeros_app0. - rewrite app_assoc. - rewrite app_zeros_zeros. - rewrite decode_single. - unfold crosscoef; simpl; ring_simplify. - rewrite NPeano.Nat.add_1_r. - rewrite bs_good. - rewrite Z_div_mult. - rewrite <- Z.mul_assoc. - rewrite <- Z.mul_comm. - rewrite <- Z.mul_assoc. - rewrite <- Z.mul_assoc. - destruct (Z.eq_dec x 0); subst; try ring. - rewrite Z.mul_cancel_l by auto. - rewrite <- bs_good. - ring. + unfold mul_bi, decode. + destruct m; simpl; ssimpl_list; simpl; intros. { + rewrite decode_single. + unfold crosscoef; simpl. + rewrite plus_0_r. + ring_simplify. + replace (nth_default 0 base n * nth_default 0 base 0) with (nth_default 0 base 0 * nth_default 0 base n) by ring. + SearchAbout Z.div. + rewrite Z_div_mult; try ring. + + apply base_positive. + rewrite nth_default_eq. + apply nth_In. + rewrite plus_0_r in *. + auto. + } { + simpl; ssimpl_list; simpl. + replace (mul_bi' n (rev (zeros m) ++ 0 :: nil)) with (zeros (S m)) by admit. + intros; simpl; ssimpl_list; simpl. + rewrite length_zeros. + rewrite app_cons_app_app. + rewrite rev_zeros. + intros; simpl; ssimpl_list; simpl. + rewrite zeros_app0. + rewrite app_assoc. + rewrite app_zeros_zeros. + rewrite decode_single. + unfold crosscoef; simpl; ring_simplify. + rewrite NPeano.Nat.add_1_r. + rewrite base_good by auto. + rewrite Z_div_mult. + rewrite <- Z.mul_assoc. + rewrite <- Z.mul_comm. + rewrite <- Z.mul_assoc. + rewrite <- Z.mul_assoc. + destruct (Z.eq_dec x 0); subst; try ring. + rewrite Z.mul_cancel_l by auto. + rewrite <- base_good by auto. + ring. - assert (nth_default 0 base (n + S m) > 0) by admit; auto. + apply base_positive. + rewrite nth_default_eq. + apply nth_In; auto. + } Qed. Lemma set_higher' : forall vs x, vs++x::nil = vs .+ (zeros (length vs) ++ x :: nil). @@ -187,7 +225,7 @@ Module BaseSystem (Import B:BaseCoefs). Qed. Lemma set_higher : forall bs vs x, - decode bs (vs++x::nil) = decode bs vs + nth_default 0 bs (length vs) * x. + decode' bs (vs++x::nil) = decode' bs vs + nth_default 0 bs (length vs) * x. Proof. intros. rewrite set_higher'. @@ -298,17 +336,28 @@ Module BaseSystem (Import B:BaseCoefs). rewrite <- rev_add; auto. Qed. - Lemma mul_bi_rep : forall i vs, decode base (mul_bi i vs) = decode base vs * nth_default 0 base i. + Lemma mul_bi_rep : forall i vs, + (i + length vs < length base)%nat -> + decode (mul_bi i vs) = decode vs * nth_default 0 base i. + Proof. + unfold decode. induction vs using rev_ind; intros; simpl. { - unfold mul_bi. + unfold mul_bi, decode. ssimpl_list; rewrite zeros_rep; simpl. - unfold decode; simpl. + unfold decode'; simpl. ring. } { + assert (i + length vs < length base)%nat as inbounds. { + rewrite app_length in *; simpl in *. + rewrite NPeano.Nat.add_1_r, <- plus_n_Sm in *. + etransitivity; eauto. + } + rewrite set_higher. ring_simplify. - rewrite <- IHvs; clear IHvs. - rewrite <- mul_bi_single. + rewrite <- IHvs by auto; clear IHvs. + simpl in *. + rewrite <- mul_bi_single by auto. rewrite <- add_rep. rewrite <- mul_bi_add. rewrite set_higher'. @@ -324,20 +373,117 @@ Module BaseSystem (Import B:BaseCoefs). | _ => nil end. Definition mul us := mul' (rev us). - Local Infix "#*" := mul (at level 40). + Infix "#*" := mul (at level 40). - Lemma mul'_rep : forall us vs, decode base (mul' (rev us) vs) = decode base us * decode base vs. + Lemma mul'_rep : forall us vs, + (length us + length vs <= length base)%nat -> + decode (mul' (rev us) vs) = decode us * decode vs. + Proof. + unfold decode. induction us using rev_ind; intros; simpl; try apply decode_nil. + + assert (length us + length vs < length base)%nat as inbounds. { + rewrite app_length in *; simpl in *. + rewrite plus_comm in *. + rewrite NPeano.Nat.add_1_r, <- plus_n_Sm in *. + auto. + } + ssimpl_list. rewrite add_rep. - rewrite IHus; clear IHus. + rewrite IHus by (rewrite le_trans; eauto); clear IHus. rewrite set_higher. rewrite mul_each_rep. - rewrite mul_bi_rep. + rewrite mul_bi_rep by auto; unfold decode. ring. Qed. - Lemma mul_rep : forall us vs, decode base (us #* vs) = decode base us * decode base vs. - apply mul'_rep. + Lemma mul_rep : forall us vs, + (length us + length vs <= length base)%nat -> + decode (us #* vs) = decode us * decode vs. + Proof. + exact mul'_rep. Qed. End BaseSystem. + +Module Type PolynomialBaseParams. + Parameter b1 : positive. (* the value at which the polynomial is evaluated *) + Parameter baseLength : nat. (* 1 + degree of the polynomial *) +End PolynomialBaseParams. + +Module PolynomialBaseCoefs (Import P:PolynomialBaseParams) <: BaseCoefs. + (** PolynomialBaseCoeffs generates base vectors for [BaseSystem] using the extra assumption that $b_{i+j} = b_j b_j$. *) + Definition bi i := (Zpos b1)^(Z.of_nat i). + Definition base := map bi (seq 0 baseLength). + + Lemma base_positive : forall b, In b base -> b > 0. + unfold base. + intros until 0; intro H. + rewrite in_map_iff in *. + destruct H; destruct H. + subst. + apply pos_pow_nat_pos. + Qed. + + Lemma base_good: + forall i j, (i + j < length base)%nat -> + let b := nth_default 0 base in + let r := (b i * b j) / b (i+j)%nat in + b i * b j = r * b (i+j)%nat. + Proof. + unfold base, nth_default. + intros; repeat progress (match goal with + | [ |- context[match nth_error ?xs ?i with Some _ => _ | None => _ end ] ] => case_eq (nth_error xs i); intros + | [ H: nth_error (map _ _) _ = Some _ |- _ ] => destruct (nth_error_map _ _ _ _ _ _ H); clear H + | [ H: _ /\ _ |- _ ] => destruct H + | [ H: nth_error (seq _ _) _ = Some _ |- _ ] => rewrite nth_error_seq in H + | [ H: context[if lt_dec ?a ?b then _ else _] |- _ ] => destruct (lt_dec a b) + | [ H: Some _ = Some _ |- _ ] => injection H; clear H; intros; subst + | [ H: None = Some _ |- _ ] => inversion H + | [ H: Some _ = None |- _ ] => inversion H + | [H: nth_error _ _ = None |- _ ] => specialize (nth_error_length_error _ _ _ H); intro; clear H + end); autorewrite with list in *; try omega. + + clear. + unfold bi. + rewrite Nat2Z.inj_add, Zpower_exp by + (replace 0 with (Z.of_nat 0) by auto; rewrite <- Nat2Z.inj_ge; omega). + rewrite Z_div_same_full; try ring. + rewrite <- Z.neq_mul_0. + split; apply Z.pow_nonzero; try apply Zle_0_nat; try solve [intro H; inversion H]. + Qed. +End PolynomialBaseCoefs. + +Module BasePoly2Degree32Params <: PolynomialBaseParams. + Definition b1 := 2%positive. + Definition baseLength := 32%nat. +End BasePoly2Degree32Params. + +Import ListNotations. + +Module BaseSystemExample. + Module BasePoly2Degree32Coefs := PolynomialBaseCoefs BasePoly2Degree32Params. + Module BasePoly2Degree32 := BaseSystem BasePoly2Degree32Coefs. + Import BasePoly2Degree32. + + Example three_times_two : [1;1;0] #* [0;1;0] = [0;1;1;0;0]. + compute; reflexivity. + Qed. + + (* python -c "e = lambda x: '['+''.join(reversed(bin(x)[2:])).replace('1','1;').replace('0','0;')[:-1]+']'; print(e(19259)); print(e(41781))" *) + Definition a := [1;1;0;1;1;1;0;0;1;1;0;1;0;0;1]. + Definition b := [1;0;1;0;1;1;0;0;1;1;0;0;0;1;0;1]. + Example da : decode a = 19259. + compute. reflexivity. + Qed. + Example db : decode b = 41781. + compute. reflexivity. + Qed. + Example encoded_ab : + a #*b =[1;1;1;2;2;4;2;2;4;5;3;3;3;6;4;2;5;3;4;3;2;1;2;2;2;0;1;1;0;1]. + compute. reflexivity. + Qed. + Example dab : decode (a #* b) = 804660279. + compute. reflexivity. + Qed. +End BaseSystemExample. |