diff options
author | jadep <jade.philipoom@gmail.com> | 2016-07-11 12:00:49 -0400 |
---|---|---|
committer | jadep <jade.philipoom@gmail.com> | 2016-07-11 12:00:49 -0400 |
commit | bb38344557cddbc64eac0eb5b174d54c0507e08a (patch) | |
tree | da2d447b51b886ab706f21963849f1052accac0e | |
parent | 9a7c5b2a18ce47dbfc2bc3513f36856001499d98 (diff) | |
parent | 762f2a27f9d237050ea5ab342f6e893ab4b4ac25 (diff) |
Merge of fixedlength and master
40 files changed, 3197 insertions, 1421 deletions
diff --git a/.gitignore b/.gitignore index b8cf1d4bc..b22a815ee 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ .#* Makefile.bak Makefile.coq +Makefile.coq.bak csdp.cache lia.cache nlia.cache diff --git a/.travis.yml b/.travis.yml index afbd96080..5144a240d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,8 @@ sudo: required matrix: include: - dist: trusty + env: COQ_VERSION="8.5pl2" COQ_PACKAGE="coq-8.5pl2" COQPRIME="coqprime" PPA="ppa:jgross-h/many-coq-versions" + - dist: trusty env: COQ_VERSION="8.5pl1" COQ_PACKAGE="coq-8.5pl1" COQPRIME="coqprime" PPA="ppa:jgross-h/many-coq-versions" - dist: trusty env: COQ_VERSION="8.5" COQ_PACKAGE="coq-8.5" COQPRIME="coqprime" PPA="ppa:jgross-h/many-coq-versions" diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 000000000..31d71c211 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,16 @@ +# This is the official list of fiat-crypto authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS files. +# See the latter for an explanation. + +# Names should be added to this file as one of +# Organization's name +# Individual's name <submission email address> +# Individual's name <submission email address> <email2> <emailN> +# See CONTRIBUTORS for the meaning of multiple email addresses. + +# Please keep the list sorted. + +Andres Erbsen <andreser@mit.edu> +Google Inc. +Jade Philipoom <jadep@mit.edu> <jade.philipoom@gmail.com> +Massachusetts Institute of Technology diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 000000000..905edafd1 --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,28 @@ +# This is the official list of people have contributed code to the +# fiat-crypto repository. +# +# The AUTHORS file lists the copyright holders; this file +# lists people. For example, Google employees are listed here +# but not in AUTHORS, because Google holds the copyright. +# +# When adding J Random Contributor's name to this file, +# either J's name or J's organization's name should be +# added to the AUTHORS file, depending on who holds the copyright. +# +# Names should be added to this file like so: +# Individual's name <submission email address> +# Individual's name <submission email address> <email2> <emailN> +# +# An entry with multiple email addresses specifies that the +# first address should be used in the submit logs and +# that the other addresses should be recognized as the +# same person. + +# Please keep the list sorted. + +Adam Chlipala <adamc@csail.mit.edu> <adam@chlipala.net> +Andres Erbsen <andreser@mit.edu> +Daniel Ziegler <dmz@mit.edu> +Jade Philipoom <jadep@mit.edu> <jade.philipoom@gmail.com> +Jason Gross <jgross@mit.edu> <jagro@google.com> <jasongross9@gmail.com> +Robert Sloan <rsloan@mit.edu> <varomodt@gmail.com> <rsloan@sumologic.com> @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2015 Programming Languages and Verification Group at MIT CSAIL +Copyright (c) 2015-2016 the fiat-crypto authors (see the AUTHORS file). Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -1,19 +1,12 @@ MOD_NAME := Crypto SRC_DIR := src -.PHONY: coq clean install coqprime-8.4 coqprime-8.5 coqprime update-_CoqProject +.PHONY: coq clean update-_CoqProject cleanall install \ + install-coqprime install-coqprime-8.4 install-coqprime-8.5 \ + clean-coqprime clean-coqprime-8.4 clean-coqprime-8.5 \ + coqprime coqprime-8.4 coqprime-8.5 .DEFAULT_GOAL := coq -VERBOSE = 0 - -SILENCE_COQC_0 = @echo "COQC $<"; # -SILENCE_COQC_1 = -SILENCE_COQC = $(SILENCE_COQC_$(VERBOSE)) - -SILENCE_COQDEP_0 = @echo "COQDEP $<"; # -SILENCE_COQDEP_1 = -SILENCE_COQDEP = $(SILENCE_COQDEP_$(VERBOSE)) - SORT_COQPROJECT = sed 's,[^/]*/,~&,g' | env LC_COLLATE=C sort | sed 's,~,,g' update-_CoqProject:: @@ -27,8 +20,12 @@ COQ_VERSION := $(firstword $(subst $(COQ_VERSION_PREFIX),,$(shell $(COQBIN)coqc ifneq ($(filter 8.4%,$(COQ_VERSION)),) # 8.4 coqprime: coqprime-8.4 +clean-coqprime: clean-coqprime-8.4 +install-coqprime: install-coqprime-8.4 else coqprime: coqprime-8.5 +clean-coqprime: clean-coqprime-8.5 +install-coqprime: install-coqprime-8.5 endif coqprime-8.4: @@ -37,6 +34,18 @@ coqprime-8.4: coqprime-8.5: $(MAKE) -C coqprime +clean-coqprime-8.4: + $(MAKE) -C coqprime-8.4 clean + +clean-coqprime-8.5: + $(MAKE) -C coqprime clean + +install-coqprime-8.4: + $(MAKE) -C coqprime-8.4 install + +install-coqprime-8.5: + $(MAKE) -C coqprime install + Makefile.coq: Makefile _CoqProject $(Q)$(COQBIN)coq_makefile -f _CoqProject -o Makefile.coq @@ -44,6 +53,8 @@ clean: Makefile.coq $(MAKE) -f Makefile.coq clean rm -f Makefile.coq +cleanall: clean clean-coqprime + install: coq Makefile.coq + $(MAKE) install-coqprime $(MAKE) -f Makefile.coq install - $(MAKE) -C coqprime install diff --git a/_CoqProject b/_CoqProject index 49fad12e0..2ff168ac8 100644 --- a/_CoqProject +++ b/_CoqProject @@ -28,6 +28,7 @@ src/CompleteEdwardsCurve/Pre.v src/Encoding/EncodingTheorems.v src/Encoding/ModularWordEncodingPre.v src/Encoding/ModularWordEncodingTheorems.v +src/Encoding/PointEncodingPre.v src/Experiments/DerivationsOptionRectLetInEncoding.v src/Experiments/EdDSARefinement.v src/Experiments/GenericFieldPow.v @@ -38,20 +39,24 @@ src/ModularArithmetic/ModularBaseSystem.v src/ModularArithmetic/ModularBaseSystemInterface.v src/ModularArithmetic/ModularBaseSystemOpt.v src/ModularArithmetic/ModularBaseSystemProofs.v +src/ModularArithmetic/Pow2Base.v +src/ModularArithmetic/Pow2BaseProofs.v src/ModularArithmetic/Pre.v src/ModularArithmetic/PrimeFieldTheorems.v src/ModularArithmetic/PseudoMersenneBaseParamProofs.v src/ModularArithmetic/PseudoMersenneBaseParams.v src/ModularArithmetic/Tutorial.v +src/ModularArithmetic/BarrettReduction/Z.v src/Spec/CompleteEdwardsCurve.v src/Spec/EdDSA.v src/Spec/Encoding.v src/Spec/ModularArithmetic.v src/Spec/ModularWordEncoding.v +src/Spec/WeierstrassCurve.v src/Specific/GF1305.v src/Specific/GF25519.v -src/Tactics/Nsatz.v src/Tactics/VerdiTactics.v +src/Tactics/Algebra_syntax/Nsatz.v src/Util/CaseUtil.v src/Util/Decidable.v src/Util/IterAssocOp.v @@ -66,3 +71,4 @@ src/Util/Tuple.v src/Util/Unit.v src/Util/WordUtil.v src/Util/ZUtil.v +src/WeierstrassCurve/Pre.v diff --git a/src/Algebra.v b/src/Algebra.v index 473571824..f083b06da 100644 --- a/src/Algebra.v +++ b/src/Algebra.v @@ -1,9 +1,10 @@ Require Import Coq.Classes.Morphisms. Require Coq.Setoids.Setoid. -Require Import Crypto.Util.Tactics Crypto.Tactics.Nsatz. +Require Import Crypto.Util.Tactics. Require Import Crypto.Util.Decidable. Require Import Crypto.Util.Notations. Require Coq.Numbers.Natural.Peano.NPeano. Local Close Scope nat_scope. Local Close Scope type_scope. Local Close Scope core_scope. +Require Crypto.Tactics.Algebra_syntax.Nsatz. Module Import ModuloCoq8485. Import NPeano Nat. @@ -345,9 +346,11 @@ Module Group. auto using associative, left_identity, right_identity, left_inverse, right_inverse. Qed. End GroupByHomomorphism. +End Group. - Section ScalarMult. - Context {G eq add zero opp} `{@group G eq add zero opp}. +Module ScalarMult. + Section ScalarMultProperties. + Context {G eq add zero} `{@monoid G eq add zero}. Context {mul:nat->G->G}. Local Infix "=" := eq : type_scope. Local Infix "=" := eq. Local Infix "+" := add. Local Infix "*" := mul. @@ -377,14 +380,8 @@ Module Group. Proof. induction n; intros. { rewrite <-mult_n_O, !scalarmult_0_l. reflexivity. } - { rewrite scalarmult_S_l, <-mult_n_Sm, <-Plus.plus_comm, scalarmult_add_l. apply cancel_left, IHn. } - Qed. - - Lemma opp_mul : forall n P, opp (n * P) = n * (opp P). - induction n; intros. - { rewrite !scalarmult_0_l, inv_id; reflexivity. } - { rewrite <-NPeano.Nat.add_1_l, Plus.plus_comm at 1. - rewrite scalarmult_add_l, scalarmult_1_l, inv_op, scalarmult_S_l, cancel_left; eauto. } + { rewrite scalarmult_S_l, <-mult_n_Sm, <-Plus.plus_comm, scalarmult_add_l. + rewrite IHn. reflexivity. } Qed. Lemma scalarmult_times_order : forall l B, l*B = zero -> forall n, (l * n) * B = zero. @@ -396,8 +393,16 @@ Module Group. rewrite (NPeano.Nat.div_mod n l Hnz) at 2. rewrite scalarmult_add_l, scalarmult_times_order, left_identity by auto. reflexivity. Qed. - End ScalarMult. -End Group. + Context {opp} {group:@group G eq add zero opp}. + + Lemma opp_mul : forall n P, opp (n * P) = n * (opp P). + induction n; intros. + { rewrite !scalarmult_0_l, Group.inv_id; reflexivity. } + { rewrite <-NPeano.Nat.add_1_l, Plus.plus_comm at 1. + rewrite scalarmult_add_l, scalarmult_1_l, Group.inv_op, scalarmult_S_l, Group.cancel_left; eauto. } + Qed. + End ScalarMultProperties. +End ScalarMult. Require Coq.nsatz.Nsatz. @@ -633,6 +638,11 @@ Module Field. End Homomorphism. End Field. +(** Tactics *) + +Ltac nsatz := Algebra_syntax.Nsatz.nsatz; dropRingSyntax. +Ltac nsatz_contradict := Algebra_syntax.Nsatz.nsatz_contradict; dropRingSyntax. + (*** Tactics for manipulating field equations *) Require Import Coq.setoid_ring.Field_tac. @@ -658,7 +668,7 @@ Ltac field_nonzero_mul_split := => apply IntegralDomain.mul_nonzero_nonzero_iff in H; destruct H end. -Ltac common_denominator := +Ltac field_simplify_eq_if_div := let fld := guess_field in lazymatch type of fld with field (div:=?div) => @@ -669,7 +679,7 @@ Ltac common_denominator := end. (** We jump through some hoops to ensure that the side-conditions come late *) -Ltac common_denominator_in_cycled_side_condition_order H := +Ltac field_simplify_eq_if_div_in_cycled_side_condition_order H := let fld := guess_field in lazymatch type of fld with field (div:=?div) => @@ -679,15 +689,11 @@ Ltac common_denominator_in_cycled_side_condition_order H := end end. -Ltac common_denominator_in H := +Ltac field_simplify_eq_if_div_in H := side_conditions_before_to_side_conditions_after - common_denominator_in_cycled_side_condition_order + field_simplify_eq_if_div_in_cycled_side_condition_order H. -Ltac common_denominator_all := - common_denominator; - repeat match goal with [H: _ |- _ _ _ ] => progress common_denominator_in H end. - (** Now we have more conservative versions that don't simplify non-division structure. *) Ltac deduplicate_nonfraction_pieces mul := repeat match goal with @@ -765,11 +771,11 @@ Ltac set_nonfraction_pieces := ltac:(fun T' => change T'); deduplicate_nonfraction_pieces mul end. -Ltac default_conservative_common_denominator_nonzero_tac := +Ltac default_common_denominator_nonzero_tac := repeat apply conj; try first [ assumption | intro; field_nonzero_mul_split; tauto ]. -Ltac conservative_common_denominator_in H := +Ltac common_denominator_in H := idtac; let fld := guess_field in let div := lazymatch type of fld with @@ -779,13 +785,13 @@ Ltac conservative_common_denominator_in H := lazymatch type of H with | appcontext[div] => set_nonfraction_pieces_in H; - common_denominator_in H; + field_simplify_eq_if_div_in H; [ - | default_conservative_common_denominator_nonzero_tac.. ]; + | default_common_denominator_nonzero_tac.. ]; repeat match goal with H := _ |- _ => subst H end | ?T => fail 0 "no division in" H ":" T end. -Ltac conservative_common_denominator := +Ltac common_denominator := idtac; let fld := guess_field in let div := lazymatch type of fld with @@ -795,14 +801,14 @@ Ltac conservative_common_denominator := lazymatch goal with | |- appcontext[div] => set_nonfraction_pieces; - common_denominator; + field_simplify_eq_if_div; [ - | default_conservative_common_denominator_nonzero_tac.. ]; + | default_common_denominator_nonzero_tac.. ]; repeat match goal with H := _ |- _ => subst H end | |- ?G => fail 0 "no division in goal" G end. -Ltac conservative_common_denominator_inequality_in H := +Ltac common_denominator_inequality_in H := let HT := type of H in lazymatch HT with | not (?R _ _) => idtac @@ -817,8 +823,8 @@ Ltac conservative_common_denominator_inequality_in H := cut (not HT'); subst HT'; [ intro H; clear H' | let H'' := fresh in - intro H''; apply H'; conservative_common_denominator; [ eexact H'' | .. ] ]. -Ltac conservative_common_denominator_inequality := + intro H''; apply H'; common_denominator; [ eexact H'' | .. ] ]. +Ltac common_denominator_inequality := let G := get_goal in lazymatch G with | not (?R _ _) => idtac @@ -832,37 +838,37 @@ Ltac conservative_common_denominator_inequality := assert (H' : not HT'); subst HT'; [ | let HG := fresh in - intros HG; apply H'; conservative_common_denominator_in HG; [ eexact HG | .. ] ]. + intros HG; apply H'; common_denominator_in HG; [ eexact HG | .. ] ]. -Ltac conservative_common_denominator_hyps := +Ltac common_denominator_hyps := try match goal with | [H: _ |- _ ] - => progress conservative_common_denominator_in H; - [ conservative_common_denominator_hyps + => progress common_denominator_in H; + [ common_denominator_hyps | .. ] end. -Ltac conservative_common_denominator_inequality_hyps := +Ltac common_denominator_inequality_hyps := try match goal with | [H: _ |- _ ] - => progress conservative_common_denominator_inequality_in H; - [ conservative_common_denominator_inequality_hyps + => progress common_denominator_inequality_in H; + [ common_denominator_inequality_hyps | .. ] end. -Ltac conservative_common_denominator_all := - try conservative_common_denominator; - [ try conservative_common_denominator_hyps +Ltac common_denominator_all := + try common_denominator; + [ try common_denominator_hyps | .. ]. -Ltac conservative_common_denominator_inequality_all := - try conservative_common_denominator_inequality; - [ try conservative_common_denominator_inequality_hyps +Ltac common_denominator_inequality_all := + try common_denominator_inequality; + [ try common_denominator_inequality_hyps | .. ]. -Ltac conservative_common_denominator_equality_inequality_all := - conservative_common_denominator_all; - [ conservative_common_denominator_inequality_all +Ltac common_denominator_equality_inequality_all := + common_denominator_all; + [ common_denominator_inequality_all | .. ]. Inductive field_simplify_done {T} : T -> Type := @@ -974,26 +980,6 @@ Ltac neq01 := |apply one_neq_zero |apply Group.opp_one_neq_zero]. -Ltac conservative_field_algebra := - intros; - conservative_common_denominator_all; - try (nsatz; dropRingSyntax); - repeat (apply conj); - try solve - [neq01 - |trivial - |apply Ring.opp_nonzero_nonzero;trivial]. - -Ltac field_algebra := - intros; - common_denominator_all; - try (nsatz; dropRingSyntax); - repeat (apply conj); - try solve - [neq01 - |trivial - |apply Ring.opp_nonzero_nonzero;trivial]. - Ltac combine_field_inequalities_step := match goal with | [ H : not (?R ?x ?zero), H' : not (?R ?x' ?zero) |- _ ] @@ -1038,32 +1024,49 @@ Ltac super_nsatz_post_clean_inequalities := try assumption; prensatz_contradict; nsatz_inequality_to_equality; try nsatz. +Ltac nsatz_equality_to_inequality_by_decide_equality := + lazymatch goal with + | [ H : not (?R _ _) |- ?R _ _ ] => idtac + | [ H : (?R _ _ -> False)%type |- ?R _ _ ] => idtac + | [ |- ?R _ _ ] => fail 0 "No hypothesis exists which negates the relation" R + | [ |- ?G ] => fail 0 "The goal is not a binary relation:" G + end; + lazymatch goal with + | [ |- ?R ?x ?y ] + => destruct (@dec (R x y) _); [ assumption | exfalso ] + end. (** Handles inequalities and fractions *) -Ltac super_nsatz := +Ltac super_nsatz_internal nsatz_alternative := (* [nsatz] gives anomalies on duplicate hypotheses, so we strip them *) clear_algebraic_duplicates; prensatz_contradict; (* Each goal left over by [prensatz_contradict] is separate (and there might not be any), so we handle them all separately *) - [ try conservative_common_denominator_equality_inequality_all; - [ try nsatz_inequality_to_equality; try nsatz; - (* [nstaz] might leave over side-conditions; we handle them if they are inequalities *) - try super_nsatz_post_clean_inequalities + [ try common_denominator_equality_inequality_all; + [ try nsatz_inequality_to_equality; + try first [ nsatz; + (* [nstaz] might leave over side-conditions; we handle them if they are inequalities *) + try super_nsatz_post_clean_inequalities + | nsatz_alternative ] | super_nsatz_post_clean_inequalities.. ].. ]. +Ltac super_nsatz := + super_nsatz_internal + (* if [nsatz] fails, we try turning the goal equality into an inequality and trying again *) + ltac:(nsatz_equality_to_inequality_by_decide_equality; + super_nsatz_internal idtac). + Section ExtraLemmas. Context {F eq zero one opp add sub mul inv div} `{F_field:field F eq zero one opp add sub mul inv div}. Local Infix "+" := add. Local Infix "*" := mul. Local Infix "-" := sub. Local Infix "/" := div. Local Notation "0" := zero. Local Notation "1" := one. Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Example _only_two_square_roots_test x y : x * x = y * y -> x <> opp y -> x = y. + Proof. intros; super_nsatz. Qed. + Lemma only_two_square_roots' x y : x * x = y * y -> x <> y -> x <> opp y -> False. - Proof. - intros. - canonicalize_field_equalities; canonicalize_field_inequalities. - assert (H' : (x + y) * (x - y) <> 0) by (apply mul_nonzero_nonzero; assumption). - apply H'; nsatz. - Qed. + Proof. intros; super_nsatz. Qed. Lemma only_two_square_roots x y z : x * x = z -> y * y = z -> x <> y -> x <> opp y -> False. Proof. @@ -1157,10 +1160,10 @@ Section Example. Add Field _ExampleField : (Field.field_theory_for_stdlib_tactic (T:=F)). Example _example_nsatz x y : 1+1 <> 0 -> x + y = 0 -> x - y = 0 -> x = 0. - Proof. field_algebra. Qed. + Proof. intros. nsatz. Qed. Example _example_field_nsatz x y z : y <> 0 -> x/y = z -> z*y + y = x + y. - Proof. intros; subst; field_algebra. Qed. + Proof. intros. super_nsatz. Qed. Example _example_nonzero_nsatz_contradict x y : x * y = 1 -> not (x = 0). Proof. intros. intro. nsatz_contradict. Qed. diff --git a/src/BaseSystem.v b/src/BaseSystem.v index 743cdfde8..840713168 100644 --- a/src/BaseSystem.v +++ b/src/BaseSystem.v @@ -9,7 +9,7 @@ Local Open Scope Z. Class BaseVector (base : list Z):= { base_positive : forall b, In b base -> b > 0; (* nonzero would probably work too... *) - b0_1 : forall x, nth_default x base 0 = 1; + b0_1 : forall x, nth_default x base 0 = 1; (** TODO(jadep,jgross): change to [nth_error base 0 = Some 1], then use [nth_error_value_eq_nth_default] to prove a [forall x, nth_default x base 0 = 1] as a lemma *) base_good : forall i j, (i+j < length base)%nat -> let b := nth_default 0 base in @@ -34,8 +34,17 @@ Section BaseSystem. Definition accumulate p acc := fst p * snd p + acc. Definition decode' bs u := fold_right accumulate 0 (combine u bs). Definition decode := decode' base. - (* Does not carry; z becomes the lowest and only digit. *) - Definition encode (z : Z) := z :: nil. + + (* i is current index, counts down *) + Fixpoint encode' z max i : digits := + match i with + | O => nil + | S i' => let b := nth_default max base in + encode' z max i' ++ ((z mod (b i)) / (b i')) :: nil + end. + + (* max must be greater than input; this is used to truncate last digit *) + Definition encode z max := encode' z max (length base). Lemma decode'_truncate : forall bs us, decode' bs us = decode' bs (firstn (length bs) us). Proof. @@ -111,7 +120,7 @@ Section PolynomialBaseCoefs. rewrite in_map_iff in *. destruct H; destruct H. subst. - apply pos_pow_nat_pos. + apply Z.pos_pow_nat_pos. Qed. Lemma poly_base_defn : forall i, (i < length poly_base)%nat -> diff --git a/src/BaseSystemProofs.v b/src/BaseSystemProofs.v index 85835aabe..eb7f31ba6 100644 --- a/src/BaseSystemProofs.v +++ b/src/BaseSystemProofs.v @@ -78,10 +78,24 @@ Section BaseSystemProofs. induction bs; destruct us; destruct vs; boring; ring. Qed. - Lemma encode_rep : forall z, decode base (encode z) = z. + Lemma nth_default_base_nonzero : forall d, d <> 0 -> + forall i, nth_default d base i <> 0. Proof. - pose proof base_eq_1cons. - unfold decode, encode; destruct z; boring. + intros. + rewrite nth_default_eq. + destruct (nth_in_or_default i base d). + + auto using Z.positive_is_nonzero, base_positive. + + congruence. + Qed. + + Lemma nth_default_base_pos : forall d, 0 < d -> + forall i, 0 < nth_default d base i. + Proof. + intros. + rewrite nth_default_eq. + destruct (nth_in_or_default i base d). + + apply Z.gt_lt; auto using base_positive. + + congruence. Qed. Lemma mul_each_base : forall us bs c, @@ -177,7 +191,7 @@ Section BaseSystemProofs. Lemma nth_error_base_nonzero : forall n x, nth_error base n = Some x -> x <> 0. Proof. - eauto using (@nth_error_value_In Z), Zgt0_neq0, base_positive. + eauto using (@nth_error_value_In Z), Z.gt0_neq0, base_positive. Qed. Hint Rewrite plus_0_r. @@ -544,5 +558,52 @@ Section BaseSystemProofs. apply length0_nil; rewrite <-rev_length, rev_nil. reflexivity. Qed. + Definition encode'_zero z max : encode' base z max 0%nat = nil := eq_refl. + Definition encode'_succ z max i : encode' base z max (S i) = + encode' base z max i ++ ((z mod (nth_default max base (S i))) / (nth_default max base i)) :: nil := eq_refl. + Opaque encode'. + Hint Resolve encode'_zero encode'_succ. + + Lemma encode'_length : forall z max i, length (encode' base z max i) = i. + Proof. + induction i; auto. + rewrite encode'_succ, app_length, IHi. + cbv [length]. + omega. + Qed. + + (* States that each element of the base is a positive integer multiple of the previous + element, and that max is a positive integer multiple of the last element. Ideally this + would have a better name. *) + Definition base_max_succ_divide max := forall i, (S i <= length base)%nat -> + Z.divide (nth_default max base i) (nth_default max base (S i)). + + Lemma encode'_spec : forall z max, 0 < max -> + base_max_succ_divide max -> forall i, (i <= length base)%nat -> + decode' base (encode' base z max i) = z mod (nth_default max base i). + Proof. + induction i; intros. + + rewrite encode'_zero, b0_1, Z.mod_1_r. + apply decode_nil. + + rewrite encode'_succ, set_higher. + rewrite IHi by omega. + rewrite encode'_length, (Z.add_comm (z mod nth_default max base i)). + replace (nth_default 0 base i) with (nth_default max base i) by + (rewrite !nth_default_eq; apply nth_indep; omega). + match goal with H1 : base_max_succ_divide _, H2 : (S i <= length base)%nat, H3 : 0 < max |- _ => + specialize (H1 i H2); + rewrite (Znumtheory.Zmod_div_mod _ _ _ (nth_default_base_pos _ H _) + (nth_default_base_pos _ H _) H0) end. + rewrite <-Z.div_mod by (apply Z.positive_is_nonzero, Z.lt_gt; auto using nth_default_base_pos). + reflexivity. + Qed. + + Lemma encode_rep : forall z max, 0 <= z < max -> + base_max_succ_divide max -> decode base (encode base z max) = z. + Proof. + unfold encode; intros. + rewrite encode'_spec, nth_default_out_of_bounds by (omega || auto). + apply Z.mod_small; omega. + Qed. -End BaseSystemProofs. +End BaseSystemProofs.
\ No newline at end of file diff --git a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v index 0afc07c5d..a160d8dab 100644 --- a/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v +++ b/src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v @@ -1,16 +1,15 @@ Require Export Crypto.Spec.CompleteEdwardsCurve. -Require Import Crypto.Algebra Crypto.Tactics.Nsatz. +Require Import Crypto.Algebra Crypto.Algebra. Require Import Crypto.CompleteEdwardsCurve.Pre. Require Import Coq.Logic.Eqdep_dec. Require Import Crypto.Tactics.VerdiTactics. Require Import Coq.Classes.Morphisms. Require Import Relation_Definitions. -Require Import Crypto.Util.Tuple. -Require Import Crypto.Util.Notations. +Require Import Crypto.Util.Tuple Crypto.Util.Notations Crypto.Util.Tactics. Module E. - Import Group Ring Field CompleteEdwardsCurve.E. + Import Group ScalarMult Ring Field CompleteEdwardsCurve.E. Section CompleteEdwardsCurveTheorems. Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv a d} {field:@field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} @@ -28,65 +27,6 @@ Module E. Definition eq (P Q:point) := fieldwise (n:=2) Feq (coordinates P) (coordinates Q). Infix "=" := eq : E_scope. - (* TODO: decide whether we still want something like this, then port - Local Ltac t := - unfold point_eqb; - repeat match goal with - | _ => progress intros - | _ => progress simpl in * - | _ => progress subst - | [P:E.point |- _ ] => destruct P - | [x: (F q * F q)%type |- _ ] => destruct x - | [H: _ /\ _ |- _ ] => destruct H - | [H: _ |- _ ] => rewrite Bool.andb_true_iff in H - | [H: _ |- _ ] => apply F_eqb_eq in H - | _ => rewrite F_eqb_refl - end; eauto. - - Lemma point_eqb_sound : forall p1 p2, point_eqb p1 p2 = true -> p1 = p2. - Proof. - t. - Qed. - - Lemma point_eqb_complete : forall p1 p2, p1 = p2 -> point_eqb p1 p2 = true. - Proof. - t. - Qed. - - Lemma point_eqb_neq : forall p1 p2, point_eqb p1 p2 = false -> p1 <> p2. - Proof. - intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition. - apply point_eqb_complete in H0; congruence. - Qed. - - Lemma point_eqb_neq_complete : forall p1 p2, p1 <> p2 -> point_eqb p1 p2 = false. - Proof. - intros. destruct (point_eqb p1 p2) eqn:Hneq; intuition. - apply point_eqb_sound in Hneq. congruence. - Qed. - - Lemma point_eqb_refl : forall p, point_eqb p p = true. - Proof. - t. - Qed. - - Definition point_eq_dec (p1 p2:E.point) : {p1 = p2} + {p1 <> p2}. - destruct (point_eqb p1 p2) eqn:H; match goal with - | [ H: _ |- _ ] => apply point_eqb_sound in H - | [ H: _ |- _ ] => apply point_eqb_neq in H - end; eauto. - Qed. - - Lemma point_eqb_correct : forall p1 p2, point_eqb p1 p2 = if point_eq_dec p1 p2 then true else false. - Proof. - intros. destruct (point_eq_dec p1 p2); eauto using point_eqb_complete, point_eqb_neq_complete. - Qed. - *) - - (* TODO: move to util *) - Lemma decide_and : forall P Q, {P}+{not P} -> {Q}+{not Q} -> {P/\Q}+{not(P/\Q)}. - Proof. intros; repeat match goal with [H:{_}+{_}|-_] => destruct H end; intuition. Qed. - Ltac destruct_points := repeat match goal with | [ p : point |- _ ] => @@ -96,30 +36,55 @@ Module E. destruct p as [[x y] pf] end. - Ltac expand_opp := - rewrite ?mul_opp_r, ?mul_opp_l, ?ring_sub_definition, ?inv_inv, <-?ring_sub_definition. - - Local Hint Resolve char_gt_2. - Local Hint Resolve nonzero_a. - Local Hint Resolve square_a. - Local Hint Resolve nonsquare_d. - Local Hint Resolve @edwardsAddCompletePlus. - Local Hint Resolve @edwardsAddCompleteMinus. - - Local Obligation Tactic := intros; destruct_points; simpl; field_algebra. + Local Obligation Tactic := intros; destruct_points; simpl; super_nsatz. Program Definition opp (P:point) : point := exist _ (let '(x, y) := coordinates P in (Fopp x, y) ) _. + (* all nonzero-denominator goals here require proofs that are not + trivially implied by field axioms. Posing all such proofs at once + and then solving the nonzero-denominator goal using [super_nsatz] + is too slow because the context contains many assumed nonzero + expressions and the product of all of them is a very large + polynomial. However, we never need to use more than one + nonzero-ness assumption for a given nonzero-denominator goal, so + we can try them separately one-by-one. *) + + Ltac apply_field_nonzero H := + match goal with |- not (Feq _ 0) => idtac | _ => fail "not a nonzero goal" end; + try solve [exact H]; + let Hx := fresh "H" in + intro Hx; + apply H; + try common_denominator; + [rewrite <-Hx; ring | ..]. + Ltac bash_step := + let addCompletePlus := constr:(edwardsAddCompletePlus(char_gt_2:=char_gt_2)(d_nonsquare:=nonsquare_d)(a_square:=square_a)(a_nonzero:=nonzero_a)) in + let addCompleteMinus := constr:(edwardsAddCompleteMinus(char_gt_2:=char_gt_2)(d_nonsquare:=nonsquare_d)(a_square:=square_a)(a_nonzero:=nonzero_a)) in + let addOnCurve := constr:(unifiedAdd'_onCurve(char_gt_2:=char_gt_2)(d_nonsquare:=nonsquare_d)(a_square:=square_a)(a_nonzero:=nonzero_a)) in match goal with | |- _ => progress intros | [H: _ /\ _ |- _ ] => destruct H + | [H: ?a = ?b |- _ ] => is_var a; is_var b; repeat rewrite <-H in *; clear H b (* fast path *) | |- _ => progress destruct_points | |- _ => progress cbv [fst snd coordinates proj1_sig eq fieldwise fieldwise' add zero opp] in * | |- _ => split - | |- Feq _ _ => field_algebra - | |- _ <> 0 => expand_opp; solve [nsatz_nonzero|eauto 6] - | |- Decidable.Decidable _ => solve [ typeclasses eauto ] + | [H:Feq (a*_^2+_^2) (1 + d*_^2*_^2) |- _ <> 0] + => apply_field_nonzero (addCompletePlus _ _ _ _ H H) || + apply_field_nonzero (addCompleteMinus _ _ _ _ H H) + | [A:Feq (a*_^2+_^2) (1 + d*_^2*_^2), + B:Feq (a*_^2+_^2) (1 + d*_^2*_^2) |- _ <> 0] + => apply_field_nonzero (addCompletePlus _ _ _ _ A B) || + apply_field_nonzero (addCompleteMinus _ _ _ _ A B) + | [A:Feq (a*_^2+_^2) (1 + d*_^2*_^2), + B:Feq (a*_^2+_^2) (1 + d*_^2*_^2), + C:Feq (a*_^2+_^2) (1 + d*_^2*_^2) |- _ <> 0] + => apply_field_nonzero (addCompleteMinus _ _ _ _ A (addOnCurve (_, _) (_, _) B C)) || + apply_field_nonzero (addCompletePlus _ _ _ _ A (addOnCurve (_, _) (_, _) B C)) + | |- ?x <> 0 => let H := fresh "H" in assert (x = 1) as H by ring; rewrite H; exact one_neq_zero + | |- Feq _ _ => progress common_denominator + | |- Feq _ _ => nsatz + | |- _ => exact _ (* typeclass instances *) end. Ltac bash := repeat bash_step. @@ -127,25 +92,19 @@ Module E. Global Instance Proper_add : Proper (eq==>eq==>eq) add. Proof. bash. Qed. Global Instance Proper_opp : Proper (eq==>eq) opp. Proof. bash. Qed. Global Instance Proper_coordinates : Proper (eq==>fieldwise (n:=2) Feq) coordinates. Proof. bash. Qed. - Global Instance edwards_acurve_abelian_group : abelian_group (eq:=eq)(op:=add)(id:=zero)(inv:=opp). Proof. bash. - (* TODO: port denominator-nonzero proofs for associativity *) - match goal with | |- _ <> 0 => admit end. - match goal with | |- _ <> 0 => admit end. - match goal with | |- _ <> 0 => admit end. - match goal with | |- _ <> 0 => admit end. - Admitted. + Qed. Global Instance Proper_mul : Proper (Logic.eq==>eq==>eq) mul. Proof. - intros n m Hnm P Q HPQ. rewrite <-Hnm; clear Hnm m. - induction n; simpl; rewrite ?IHn, ?HPQ; reflexivity. + intros n n'; repeat intro; subst n'. + induction n; (reflexivity || eapply Proper_add; eauto). Qed. Global Instance mul_is_scalarmult : @is_scalarmult point eq add zero mul. - Proof. split; intros; reflexivity || typeclasses eauto. Qed. + Proof. unfold mul; split; intros; (reflexivity || exact _). Qed. Section PointCompression. Local Notation "x ^ 2" := (x*x). @@ -155,12 +114,13 @@ Module E. Proof. intros ? eq_zero. destruct square_a as [sqrt_a sqrt_a_id]; rewrite <- sqrt_a_id in eq_zero. - destruct (eq_dec y 0); [apply nonzero_a|apply nonsquare_d with (sqrt_a/y)]; field_algebra. + destruct (eq_dec y 0); [apply nonzero_a | apply nonsquare_d with (sqrt_a/y)]; super_nsatz. Qed. Lemma solve_correct : forall x y, onCurve (x, y) <-> (x^2 = solve_for_x2 y). Proof. - unfold solve_for_x2; simpl; split; intros; field_algebra; auto using a_d_y2_nonzero. + unfold solve_for_x2; simpl; split; intros; + (common_denominator_all; [nsatz | auto using a_d_y2_nonzero]). Qed. End PointCompression. End CompleteEdwardsCurveTheorems. diff --git a/src/CompleteEdwardsCurve/ExtendedCoordinates.v b/src/CompleteEdwardsCurve/ExtendedCoordinates.v index 364d7f9ec..d8efb82f3 100644 --- a/src/CompleteEdwardsCurve/ExtendedCoordinates.v +++ b/src/CompleteEdwardsCurve/ExtendedCoordinates.v @@ -1,6 +1,6 @@ Require Export Crypto.Spec.CompleteEdwardsCurve. -Require Import Crypto.Algebra Crypto.Tactics.Nsatz. +Require Import Crypto.Algebra Crypto.Algebra. Require Import Crypto.CompleteEdwardsCurve.Pre Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. Require Import Coq.Logic.Eqdep_dec. Require Import Crypto.Tactics.VerdiTactics. @@ -33,6 +33,7 @@ Module Extended. Create HintDb bash discriminated. Local Hint Unfold E.eq fst snd fieldwise fieldwise' coordinates E.coordinates proj1_sig Pre.onCurve : bash. Ltac bash := + pose proof E.char_gt_2; repeat match goal with | |- Proper _ _ => intro | _ => progress intros @@ -43,15 +44,11 @@ Module Extended. | |- _ /\ _ => split | _ => solve [neq01] | _ => solve [eauto] - | _ => solve [intuition] + | _ => solve [intuition eauto] | _ => solve [etransitivity; eauto] - | |- Feq _ _ => field_algebra - | |- _ <> 0 => apply mul_nonzero_nonzero - | [ H : _ <> 0 |- _ <> 0 ] => - intro; apply H; - field_algebra; - solve [ apply Ring.opp_nonzero_nonzero, E.char_gt_2 - | apply E.char_gt_2] + | |- _*_ <> 0 => apply mul_nonzero_nonzero + | [H: _ |- _ ] => solve [intro; apply H; super_nsatz] + | |- Feq _ _ => super_nsatz end. Obligation Tactic := bash. @@ -63,8 +60,7 @@ Module Extended. (let '(X,Y,Z,T) := coordinates P in ((X/Z), (Y/Z))) _. Definition eq (P Q:point) := E.eq (to_twisted P) (to_twisted Q). - Global Instance DecidableRel_eq : Decidable.DecidableRel eq. - Proof. typeclasses eauto. Qed. + Global Instance DecidableRel_eq : Decidable.DecidableRel eq := _. Local Hint Unfold from_twisted to_twisted eq : bash. diff --git a/src/CompleteEdwardsCurve/Pre.v b/src/CompleteEdwardsCurve/Pre.v index be423c05c..c74e9a321 100644 --- a/src/CompleteEdwardsCurve/Pre.v +++ b/src/CompleteEdwardsCurve/Pre.v @@ -1,5 +1,5 @@ Require Import Coq.Classes.Morphisms. Require Coq.Setoids.Setoid. -Require Import Crypto.Algebra Crypto.Tactics.Nsatz. +Require Import Crypto.Algebra Crypto.Algebra. Require Import Crypto.Util.Notations. Generalizable All Variables. @@ -29,6 +29,15 @@ Section Pre. Ltac use_sqrt_a := destruct a_square as [sqrt_a a_square']; rewrite <-a_square' in *. + Lemma onCurve_subst : forall x1 x2 y1 y2, and (eq x1 y1) (eq x2 y2) -> onCurve (x1, x2) -> + onCurve (y1, y2). + Proof. + unfold onCurve. + intros ? ? ? ? [eq_1 eq_2] ?. + rewrite eq_1, eq_2 in *. + assumption. + Qed. + Lemma edwardsAddComplete' x1 y1 x2 y2 : onCurve (pair x1 y1) -> onCurve (pair x2 y2) -> @@ -41,24 +50,25 @@ Section Pre. => apply d_nonsquare with (sqrt_d:= (f (sqrt_a * x1) (d * x1 * x2 * y1 * y2 * y1)) /(f (sqrt_a * x2) y2 * x1 * y1 )) | _ => apply a_nonzero - end; field_algebra; auto using Ring.opp_nonzero_nonzero; nsatz_contradict. + end; super_nsatz. Qed. Lemma edwardsAddCompletePlus x1 y1 x2 y2 : onCurve (x1, y1) -> onCurve (x2, y2) -> (1 + d*x1*x2*y1*y2) <> 0. - Proof. intros H1 H2 ?. apply (edwardsAddComplete' _ _ _ _ H1 H2); field_algebra. Qed. + Proof. intros H1 H2 ?. apply (edwardsAddComplete' _ _ _ _ H1 H2); super_nsatz. Qed. Lemma edwardsAddCompleteMinus x1 y1 x2 y2 : onCurve (x1, y1) -> onCurve (x2, y2) -> (1 - d*x1*x2*y1*y2) <> 0. - Proof. intros H1 H2 ?. apply (edwardsAddComplete' _ _ _ _ H1 H2); field_algebra. Qed. + Proof. intros H1 H2 ?. apply (edwardsAddComplete' _ _ _ _ H1 H2); super_nsatz. Qed. - Lemma zeroOnCurve : onCurve (0, 1). Proof. simpl. field_algebra. Qed. + Lemma zeroOnCurve : onCurve (0, 1). Proof. simpl. super_nsatz. Qed. Lemma unifiedAdd'_onCurve : forall P1 P2, onCurve P1 -> onCurve P2 -> onCurve (unifiedAdd' P1 P2). Proof. - unfold onCurve, unifiedAdd'; intros [x1 y1] [x2 y2] H1 H2. - field_algebra; auto using edwardsAddCompleteMinus, edwardsAddCompletePlus. + unfold onCurve, unifiedAdd'; intros [x1 y1] [x2 y2] ? ?. + common_denominator; [ | auto using edwardsAddCompleteMinus, edwardsAddCompletePlus..]. + nsatz. Qed. End Pre. diff --git a/src/Encoding/ModularWordEncodingTheorems.v b/src/Encoding/ModularWordEncodingTheorems.v index 41b75e216..033e99665 100644 --- a/src/Encoding/ModularWordEncodingTheorems.v +++ b/src/Encoding/ModularWordEncodingTheorems.v @@ -43,12 +43,12 @@ Section SignBit. pose proof (F_opp_spec x) as opp_spec_x. apply F_eq in opp_spec_x. rewrite FieldToZ_add in opp_spec_x. - rewrite <-opp_spec_x, Z_odd_mod in sign_zero by (pose proof prime_ge_2 m prime_m; omega). - replace (Z.odd m) with true in sign_zero by (destruct (ZUtil.prime_odd_or_2 m prime_m); auto || omega). + rewrite <-opp_spec_x, Z.odd_mod in sign_zero by (pose proof prime_ge_2 m prime_m; omega). + replace (Z.odd m) with true in sign_zero by (destruct (Z.prime_odd_or_2 m prime_m); auto || omega). rewrite Z.odd_add, F_FieldToZ_add_opp, Z.div_same, Bool.xorb_true_r in sign_zero by assumption || omega. apply Bool.xorb_eq. rewrite <-Bool.negb_xorb_l. assumption. Qed. -End SignBit.
\ No newline at end of file +End SignBit. diff --git a/src/Encoding/PointEncodingPre.v b/src/Encoding/PointEncodingPre.v new file mode 100644 index 000000000..2ad567c92 --- /dev/null +++ b/src/Encoding/PointEncodingPre.v @@ -0,0 +1,395 @@ +Require Import Coq.ZArith.ZArith Coq.ZArith.Znumtheory. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Coq.Program.Equality. +Require Import Crypto.CompleteEdwardsCurve.Pre. +Require Import Crypto.CompleteEdwardsCurve.CompleteEdwardsCurveTheorems. +Require Import Bedrock.Word. +Require Import Crypto.Encoding.ModularWordEncodingTheorems. +Require Import Crypto.Tactics.VerdiTactics. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Algebra. + +Require Import Crypto.Spec.Encoding Crypto.Spec.ModularWordEncoding Crypto.Spec.ModularArithmetic. + +Generalizable All Variables. +Section PointEncodingPre. + Context {F eq zero one opp add sub mul inv div} `{field F eq zero one opp add sub mul inv div}. + Local Infix "==" := eq (at level 30) : type_scope. + Local Notation "a !== b" := (not (a == b)) (at level 30): type_scope. + Local Notation "0" := zero. Local Notation "1" := one. + Local Infix "+" := add. Local Infix "*" := mul. + Local Infix "-" := sub. Local Infix "/" := div. + Local Notation "x '^' 2" := (x*x) (at level 30). + + Add Field EdwardsCurveField : (Field.field_theory_for_stdlib_tactic (T:=F)). + + Context {eq_dec:forall x y : F, {x==y}+{x==y->False}}. + Definition F_eqb x y := if eq_dec x y then true else false. + Lemma F_eqb_iff : forall x y, F_eqb x y = true <-> x == y. + Proof. + unfold F_eqb; intros; destruct (eq_dec x y); split; auto; discriminate. + Qed. + + Context {a d:F} {prm:@E.twisted_edwards_params F eq zero one add mul a d}. + Local Notation point := (@E.point F eq one add mul a d). + Local Notation onCurve := (@onCurve F eq one add mul a d). + Local Notation solve_for_x2 := (@E.solve_for_x2 F one sub mul div a d). + + Context {sz : nat} (sz_nonzero : (0 < sz)%nat). + Context {sqrt : F -> F} (sqrt_square : forall x root, x == (root ^2) -> sqrt x == root) + (sqrt_subst : forall x y, x == y -> sqrt x == sqrt y). + Context (FEncoding : canonical encoding of F as (word sz)). + Context {sign_bit : F -> bool} (sign_bit_zero : forall x, x == 0 -> Logic.eq (sign_bit x) false) + (sign_bit_opp : forall x, x !== 0 -> Logic.eq (negb (sign_bit x)) (sign_bit (opp x))) + (sign_bit_subst : forall x y, x == y -> sign_bit x = sign_bit y). + + Definition sqrt_ok (a : F) := (sqrt a) ^ 2 == a. + + Lemma square_sqrt : forall y root, y == (root ^2) -> + sqrt_ok y. + Proof. + unfold sqrt_ok; intros ? ? root2_y. + pose proof root2_y. + apply sqrt_square in root2_y. + rewrite root2_y. + symmetry; assumption. + Qed. + + Lemma solve_onCurve: forall x y : F, onCurve (x,y) -> + onCurve (sqrt (solve_for_x2 y), y). + Proof. + intros. + apply E.solve_correct. + eapply square_sqrt. + symmetry. + apply E.solve_correct; eassumption. + Qed. + + (* TODO : move? *) + Lemma square_opp : forall x : F, (opp x ^2) == (x ^2). + Proof. + intros. ring. + Qed. + + Lemma solve_opp_onCurve: forall x y : F, onCurve (x,y) -> + onCurve (opp (sqrt (solve_for_x2 y)), y). + Proof. + intros. + apply E.solve_correct. + etransitivity; [ apply square_opp | ]. + eapply square_sqrt. + symmetry. + apply E.solve_correct; eassumption. + Qed. + + Definition point_enc_coordinates (p : (F * F)) : Word.word (S sz) := let '(x,y) := p in + Word.WS (sign_bit x) (enc y). + + Let point_enc (p : point) : Word.word (S sz) := point_enc_coordinates (E.coordinates p). + + Definition point_dec_coordinates (w : Word.word (S sz)) : option (F * F) := + match dec (Word.wtl w) with + | None => None + | Some y => let x2 := solve_for_x2 y in + let x := sqrt x2 in + if eq_dec (x ^ 2) x2 + then + let p := (if Bool.eqb (whd w) (sign_bit x) then x else opp x, y) in + if (andb (F_eqb x 0) (whd w)) + then None (* special case for 0, since its opposite has the same sign; if the sign bit of 0 is 1, produce None.*) + else Some p + else None + end. + + (* Definition of product equality parameterized over equality of underlying types *) + Definition prod_eq {A B} eqA eqB (x y : (A * B)) := let (xA,xB) := x in let (yA,yB) := y in + (eqA xA yA) /\ (eqB xB yB). + + Lemma prod_eq_dec : forall {A eq} (A_eq_dec : forall a a' : A, {eq a a'} + {not (eq a a')}) + (x y : (A * A)), {prod_eq eq eq x y} + {not (prod_eq eq eq x y)}. + Proof. + intros. + destruct x as [x1 x2]. + destruct y as [y1 y2]. + match goal with + | |- {prod_eq _ _ (?x1, ?x2) (?y1,?y2)} + {not (prod_eq _ _ (?x1, ?x2) (?y1,?y2))} => + destruct (A_eq_dec x1 y1); destruct (A_eq_dec x2 y2) end; + unfold prod_eq; intuition. + Qed. + + Definition option_eq {A} eq (x y : option A) := + match x with + | None => y = None + | Some ax => match y with + | None => False + | Some ay => eq ax ay + end + end. + + Lemma option_eq_dec : forall {A eq} (A_eq_dec : forall a a' : A, {eq a a'} + {not (eq a a')}) + (x y : option A), {option_eq eq x y} + {not (option_eq eq x y)}. + Proof. + unfold option_eq; intros; destruct x as [ax|], y as [ay|]; try tauto; auto. + right; congruence. + Qed. + Definition option_coordinates_eq := option_eq (prod_eq eq eq). + + Lemma option_coordinates_eq_NS : forall x, option_coordinates_eq None (Some x) -> False. + Proof. + unfold option_coordinates_eq, option_eq. + intros; discriminate. + Qed. + + Lemma inversion_option_coordinates_eq : forall x y, + option_coordinates_eq (Some x) (Some y) -> prod_eq eq eq x y. + Proof. + unfold option_coordinates_eq, option_eq; intros; assumption. + Qed. + + Lemma prod_eq_onCurve : forall p q : F * F, prod_eq eq eq p q -> + onCurve p -> onCurve q. + Proof. + unfold prod_eq; intros. + destruct p; destruct q. + eauto using onCurve_subst. + Qed. + + Lemma option_coordinates_eq_iff : forall x1 x2 y1 y2, + option_coordinates_eq (Some (x1,y1)) (Some (x2,y2)) <-> and (eq x1 x2) (eq y1 y2). + Proof. + unfold option_coordinates_eq, option_eq, prod_eq; tauto. + Qed. + + Definition point_eq (p q : point) : Prop := prod_eq eq eq (proj1_sig p) (proj1_sig q). + Definition option_point_eq := option_eq (point_eq). + + Lemma option_point_eq_iff : forall p q, + option_point_eq (Some p) (Some q) <-> + option_coordinates_eq (Some (proj1_sig p)) (Some (proj1_sig q)). + Proof. + unfold option_point_eq, option_coordinates_eq, option_eq, point_eq; intros. + reflexivity. + Qed. + + Lemma option_coordinates_eq_dec : forall p q, + {option_coordinates_eq p q} + {~ option_coordinates_eq p q}. + Proof. + intros. + apply option_eq_dec. + apply prod_eq_dec. + apply eq_dec. + Qed. + + Lemma point_eq_dec : forall p q, {point_eq p q} + {~ point_eq p q}. + Proof. + intros. + apply prod_eq_dec. + apply eq_dec. + Qed. + + Lemma option_point_eq_dec : forall p q, + {option_point_eq p q} + {~ option_point_eq p q}. + Proof. + intros. + apply option_eq_dec. + apply point_eq_dec. + Qed. + + Lemma prod_eq_trans : forall p q r, prod_eq eq eq p q -> prod_eq eq eq q r -> + prod_eq eq eq p r. + Proof. + unfold prod_eq; intros. + repeat break_let. + intuition; etransitivity; eauto. + Qed. + + Lemma option_coordinates_eq_trans : forall p q r, option_coordinates_eq p q -> + option_coordinates_eq q r -> option_coordinates_eq p r. + Proof. + unfold option_coordinates_eq, option_eq; intros. + repeat break_match; subst; congruence || eauto using prod_eq_trans. + Qed. + + Lemma prod_eq_sym : forall p q, prod_eq eq eq p q -> prod_eq eq eq q p. + Proof. + unfold prod_eq; intros. + repeat break_let. + intuition; etransitivity; eauto. + Qed. + + Lemma option_coordinates_eq_sym : forall p q, option_coordinates_eq p q -> + option_coordinates_eq q p. + Proof. + unfold option_coordinates_eq, option_eq; intros. + repeat break_match; subst; congruence || eauto using prod_eq_sym; intuition. + Qed. + + Opaque option_coordinates_eq option_point_eq point_eq option_eq prod_eq. + + Ltac inversion_Some_eq := match goal with [H: Some ?x = Some ?y |- _] => inversion H; subst end. + + Ltac congruence_option_coord := exfalso; eauto using option_coordinates_eq_NS. + + Lemma point_dec_coordinates_onCurve : forall w p, option_coordinates_eq (point_dec_coordinates w) (Some p) -> onCurve p. + Proof. + unfold point_dec_coordinates; intros. + edestruct dec; [ | congruence_option_coord ]. + break_if; [ | congruence_option_coord]. + break_if; [ congruence_option_coord | ]. + apply E.solve_correct in e. + break_if; eapply prod_eq_onCurve; + eauto using inversion_option_coordinates_eq, solve_onCurve, solve_opp_onCurve. + Qed. + + Definition point_dec' w p : option point := + match (option_coordinates_eq_dec (point_dec_coordinates w) (Some p)) with + | left EQ => Some (exist _ p (point_dec_coordinates_onCurve w p EQ)) + | right _ => None (* this case is never reached *) + end. + + Definition point_dec (w : word (S sz)) : option point := + match point_dec_coordinates w with + | Some p => point_dec' w p + | None => None + end. + + Lemma point_coordinates_encoding_canonical : forall w p, + point_dec_coordinates w = Some p -> point_enc_coordinates p = w. + Proof. + unfold point_dec_coordinates, point_enc_coordinates; intros ? ? coord_dec_Some. + case_eq (dec (wtl w)); [ intros ? dec_Some | intros dec_None; rewrite dec_None in *; congruence ]. + destruct p. + rewrite (shatter_word w). + f_equal; rewrite dec_Some in *; + do 2 (break_if; try congruence); inversion coord_dec_Some; subst. + + destruct (eq_dec (sqrt (solve_for_x2 f1)) 0) as [sqrt_0 | ?]. + - break_if; rewrite sign_bit_zero in * by (assumption || (rewrite sqrt_0; ring)); + auto using Bool.eqb_prop. + apply F_eqb_iff in sqrt_0. + rewrite sqrt_0 in *. + destruct (whd w); inversion Heqb0; auto. + - break_if. + symmetry; auto using Bool.eqb_prop. + rewrite <- sign_bit_opp by assumption. + destruct (whd w); inversion Heqb0; break_if; auto. + + inversion coord_dec_Some; subst. + auto using encoding_canonical. + Qed. + + Lemma inversion_point_dec : forall w x, point_dec w = Some x -> + point_dec_coordinates w = Some (E.coordinates x). + Proof. + unfold point_dec, E.coordinates; intros. + break_match; [ | congruence]. + unfold point_dec' in *; break_match; [ | congruence]. + match goal with [ H : Some _ = Some _ |- _ ] => inversion H end. + reflexivity. + Qed. + + Lemma point_encoding_canonical : forall w x, point_dec w = Some x -> point_enc x = w. + Proof. + unfold point_enc; intros. + apply point_coordinates_encoding_canonical. + auto using inversion_point_dec. + Qed. + + Lemma y_decode : forall p, dec (wtl (point_enc_coordinates p)) = Some (snd p). + Proof. + intros; destruct p. cbv [point_enc_coordinates wtl snd]. + exact (encoding_valid _). + Qed. + + Lemma F_eqb_false : forall x y, x !== y -> F_eqb x y = false. + Proof. + intros; unfold F_eqb; destruct (eq_dec x y); congruence. + Qed. + + Lemma eqb_sign_opp_r : forall x y, (y !== 0) -> + Bool.eqb (sign_bit x) (sign_bit y) = false -> + Bool.eqb (sign_bit x) (sign_bit (opp y)) = true. + Proof. + intros x y y_nonzero ?. + specialize (sign_bit_opp y y_nonzero). + destruct (sign_bit x), (sign_bit y); try discriminate; + rewrite <-sign_bit_opp; auto. + Qed. + + Lemma sign_match : forall x y sqrt_y, sqrt_y !== 0 -> (x ^2) == y -> sqrt_y ^2 == y -> + Bool.eqb (sign_bit x) (sign_bit sqrt_y) = true -> + sqrt_y == x. + Proof. + intros. + pose proof (only_two_square_roots_choice x sqrt_y y) as Hchoice. + destruct Hchoice; try assumption; symmetry; try assumption. + rewrite (sign_bit_subst x (opp sqrt_y)) in * by assumption. + rewrite <-sign_bit_opp in * by assumption. + rewrite Bool.eqb_negb1 in *; discriminate. + Qed. + + Lemma point_encoding_coordinates_valid : forall p, onCurve p -> + option_coordinates_eq (point_dec_coordinates (point_enc_coordinates p)) (Some p). + Proof. + intros [x y] onCurve_p. + unfold point_dec_coordinates. + rewrite y_decode. + cbv [whd point_enc_coordinates snd]. + pose proof (square_sqrt (solve_for_x2 y) x) as solve_sqrt_ok. + forward solve_sqrt_ok. { + symmetry. + apply E.solve_correct. + assumption. + } + match goal with [ H1 : ?P, H2 : ?P -> _ |- _ ] => specialize (H2 H1); clear H1 end. + unfold sqrt_ok in solve_sqrt_ok. + break_if; [ | congruence]. + assert (solve_for_x2 y == (x ^2)) as solve_correct by (symmetry; apply E.solve_correct; assumption). + destruct (eq_dec x 0) as [eq_x_0 | neq_x_0]. + + rewrite !sign_bit_zero by + (eauto || (rewrite eq_x_0 in *; rewrite sqrt_square; [ | eauto]; reflexivity)). + rewrite Bool.andb_false_r, Bool.eqb_reflx. + apply option_coordinates_eq_iff; split; try reflexivity. + transitivity (sqrt (x ^2)); auto. + apply (sqrt_square); reflexivity. + + rewrite F_eqb_false, Bool.andb_false_l by (rewrite sqrt_square; [ | eauto]; assumption). + break_if; [ | apply eqb_sign_opp_r in Heqb]; + try (apply option_coordinates_eq_iff; split; try reflexivity); + try eapply sign_match with (y := solve_for_x2 y); eauto; + try solve [symmetry; auto]; rewrite ?square_opp; auto; + (rewrite sqrt_square; [ | eauto]); try apply Ring.opp_nonzero_nonzero; + assumption. +Qed. + +Lemma point_dec'_valid : forall p q, option_coordinates_eq (Some q) (Some (proj1_sig p)) -> + option_point_eq (point_dec' (point_enc_coordinates (proj1_sig p)) q) (Some p). +Proof. + unfold point_dec'; intros. + break_match. + + f_equal. + apply option_point_eq_iff. + destruct p as [[? ?] ?]; simpl in *. + assumption. + + exfalso; apply n. + eapply option_coordinates_eq_trans; [ | eauto using option_coordinates_eq_sym ]. + apply point_encoding_coordinates_valid. + apply (proj2_sig p). +Qed. + +Lemma point_encoding_valid : forall p, + option_point_eq (point_dec (point_enc p)) (Some p). +Proof. + intros. + unfold point_dec. + replace (point_enc p) with (point_enc_coordinates (proj1_sig p)) by reflexivity. + break_match. + + eapply (point_dec'_valid p). + rewrite <-Heqo. + apply point_encoding_coordinates_valid. + apply (proj2_sig p). + + exfalso. + eapply option_coordinates_eq_NS. + pose proof (point_encoding_coordinates_valid _ (proj2_sig p)). + rewrite Heqo in *. + eassumption. +Qed. + +End PointEncodingPre. diff --git a/src/Experiments/EdDSARefinement.v b/src/Experiments/EdDSARefinement.v index 79c9b05dd..484650934 100644 --- a/src/Experiments/EdDSARefinement.v +++ b/src/Experiments/EdDSARefinement.v @@ -1,6 +1,6 @@ Require Import Crypto.Spec.EdDSA Bedrock.Word. Require Import Coq.Classes.Morphisms. -Require Import Crypto.Algebra. Import Group. +Require Import Crypto.Algebra. Import Group ScalarMult. Require Import Util.Decidable Util.Option Util.Tactics. Require Import Omega. diff --git a/src/ModularArithmetic/BarrettReduction/Z.v b/src/ModularArithmetic/BarrettReduction/Z.v new file mode 100644 index 000000000..8b472d5d8 --- /dev/null +++ b/src/ModularArithmetic/BarrettReduction/Z.v @@ -0,0 +1,118 @@ +(*** Barrett Reduction *) +(** This file implements Barrett Reduction on [Z]. We follow Wikipedia. *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.Util.ZUtil Crypto.Util.Tactics. + +Local Open Scope Z_scope. + +Section barrett. + Context (n a : Z) + (n_reasonable : n <> 0). + (** Quoting Wikipedia <https://en.wikipedia.org/wiki/Barrett_reduction>: *) + (** In modular arithmetic, Barrett reduction is a reduction + algorithm introduced in 1986 by P.D. Barrett. A naive way of + computing *) + (** [c = a mod n] *) + (** would be to use a fast division algorithm. Barrett reduction is + an algorithm designed to optimize this operation assuming [n] is + constant, and [a < n²], replacing divisions by + multiplications. *) + + (** * General idea *) + Section general_idea. + (** Let [m = 1 / n] be the inverse of [n] as a floating point + number. Then *) + (** [a mod n = a - ⌊a m⌋ n] *) + (** where [⌊ x ⌋] denotes the floor function. The result is exact, + as long as [m] is computed with sufficient accuracy. *) + + (* [/] is [Z.div], which means truncated division *) + Local Notation "⌊am⌋" := (a / n) (only parsing). + + Theorem naive_barrett_reduction_correct + : a mod n = a - ⌊am⌋ * n. + Proof. + apply Zmod_eq_full; assumption. + Qed. + End general_idea. + + (** * Barrett algorithm *) + Section barrett_algorithm. + (** Barrett algorithm is a fixed-point analog which expresses + everything in terms of integers. Let [k] be the smallest + integer such that [2ᵏ > n]. Think of [n] as representing the + fixed-point number [n 2⁻ᵏ]. We precompute [m] such that [m = + ⌊4ᵏ / n⌋]. Then [m] represents the fixed-point number + [m 2⁻ᵏ ≈ (n 2⁻ᵏ)⁻¹]. *) + (** N.B. We don't need [k] to be the smallest such integer. *) + Context (k : Z) + (k_good : n < 2 ^ k) + (m : Z) + (m_good : m = 4^k / n). (* [/] is [Z.div], which is truncated *) + (** Wikipedia neglects to mention non-negativity, but we need it. + It might be possible to do with a relaxed assumption, such as + the sign of [a] and the sign of [n] being the same; but I + figured it wasn't worth it. *) + Context (n_pos : 0 < n) (* or just [0 <= n], since we have [n <> 0] above *) + (a_nonneg : 0 <= a). + + Lemma k_nonnegative : 0 <= k. + Proof. + destruct (Z_lt_le_dec k 0); try assumption. + rewrite !Z.pow_neg_r in * by lia; lia. + Qed. + + (** Now *) + Let q := (m * a) / 4^k. + Let r := a - q * n. + (** Because of the floor function (in Coq, because [/] means + truncated division), [q] is an integer and [r ≡ a mod n]. *) + Theorem barrett_reduction_equivalent + : r mod n = a mod n. + Proof. + subst r q m. + rewrite <- !Z.add_opp_r, !Zopp_mult_distr_l, !Z_mod_plus_full by assumption. + reflexivity. + Qed. + + Lemma qn_small + : q * n <= a. + Proof. + pose proof k_nonnegative; subst q r m. + assert (0 <= 2^(k-1)) by zero_bounds. + Z.simplify_fractions_le. + Qed. + + (** Also, if [a < n²] then [r < 2n]. *) + (** N.B. It turns out that it is sufficient to assume [a < 4ᵏ]. *) + Context (a_small : a < 4^k). + Lemma r_small : r < 2 * n. + Proof. + Hint Rewrite (Z.div_small a (4^k)) (Z.mod_small a (4^k)) using lia : zsimplify. + Hint Rewrite (Z.mul_div_eq' a n) using lia : zstrip_div. + cut (r + 1 <= 2 * n); [ lia | ]. + pose proof k_nonnegative; subst r q m. + assert (0 <= 2^(k-1)) by zero_bounds. + assert (4^k <> 0) by auto with zarith lia. + assert (a mod n < n) by auto with zarith lia. + pose proof (Z.mod_pos_bound (a * 4^k / n) (4^k)). + transitivity (a - (a * 4 ^ k / n - a) / 4 ^ k * n + 1). + { rewrite <- (Z.mul_comm a); auto 6 with zarith lia. } + rewrite (Z_div_mod_eq (_ * 4^k / n) (4^k)) by lia. + autorewrite with push_Zmul push_Zopp zsimplify zstrip_div. + break_match; auto with lia. + Qed. + + (** In that case, we have *) + Theorem barrett_reduction_small + : a mod n = if r <? n + then r + else r - n. + Proof. + pose proof r_small. pose proof qn_small. + destruct (r <? n) eqn:rlt; Z.ltb_to_lt. + { symmetry; apply (Zmod_unique a n q); subst r; lia. } + { symmetry; apply (Zmod_unique a n (q + 1)); subst r; lia. } + Qed. + End barrett_algorithm. +End barrett. diff --git a/src/ModularArithmetic/ModularBaseSystem.v b/src/ModularArithmetic/ModularBaseSystem.v index 08545bdb4..856cb1e81 100644 --- a/src/ModularArithmetic/ModularBaseSystem.v +++ b/src/ModularArithmetic/ModularBaseSystem.v @@ -8,6 +8,7 @@ Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. Require Import Crypto.ModularArithmetic.ExtendedBaseVector. Require Import Crypto.Tactics.VerdiTactics. Require Import Crypto.Util.Notations. +Require Import Crypto.ModularArithmetic.Pow2Base. Local Open Scope Z_scope. Section PseudoMersenneBase. @@ -19,7 +20,8 @@ Section PseudoMersenneBase. Local Notation "u ~= x" := (rep u x). Local Hint Unfold rep. - Definition encode (x : F modulus) := encode x ++ BaseSystem.zeros (length base - 1)%nat. + (* max must be greater than input; this is used to truncate last digit *) + Definition encode (x : F modulus) := encodeZ limb_widths x. (* Converts from length of extended base to length of base by reduction modulo M.*) Definition reduce (us : digits) : digits := @@ -101,15 +103,6 @@ Section Canonicalization. (* compute at compile time *) Definition modulus_digits := modulus_digits' (length base - 1). - Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C := - match la with - | nil => nil - | a :: la' => match lb with - | nil => nil - | b :: lb' => f a b :: map2 f la' lb' - end - end. - Definition and_term us := if isFull us then max_ones else 0. Definition freeze us := diff --git a/src/ModularArithmetic/ModularBaseSystemOpt.v b/src/ModularArithmetic/ModularBaseSystemOpt.v index 3d7168c5a..332c0cdb2 100644 --- a/src/ModularArithmetic/ModularBaseSystemOpt.v +++ b/src/ModularArithmetic/ModularBaseSystemOpt.v @@ -27,7 +27,7 @@ Definition Z_div_opt := Eval compute in Z.div. Definition Z_pow_opt := Eval compute in Z.pow. Definition Z_opp_opt := Eval compute in Z.opp. Definition Z_shiftl_opt := Eval compute in Z.shiftl. -Definition Z_shiftl_by_opt := Eval compute in Z_shiftl_by. +Definition Z_shiftl_by_opt := Eval compute in Z.shiftl_by. Definition nth_default_opt {A} := Eval compute in @nth_default A. Definition set_nth_opt {A} := Eval compute in @set_nth A. @@ -50,10 +50,12 @@ Ltac opt_step := destruct e end. -Ltac brute_force_indices limb_widths := intros; unfold sum_firstn, limb_widths; simpl in *; +Ltac brute_force_indices limb_widths := + intros; unfold sum_firstn, limb_widths; cbv [length limb_widths] in *; repeat match goal with | _ => progress simpl in * - | _ => reflexivity + | [H : (0 + _ < _)%nat |- _ ] => simpl in H + | [H : (S _ + _ < S _)%nat |- _ ] => simpl in H | [H : (S _ < S _)%nat |- _ ] => apply lt_S_n in H | [H : (?x + _ < _)%nat |- _ ] => is_var x; destruct x | [H : (?x < _)%nat |- _ ] => is_var x; destruct x @@ -76,9 +78,10 @@ Ltac construct_params prime_modulus len k := cbv in lw; eapply Build_PseudoMersenneBaseParams with (limb_widths := lw); [ abstract (apply fold_right_and_True_forall_In_iff; simpl; repeat (split; [omega |]); auto) - | abstract (unfold limb_widths; cbv; congruence) + | abstract (cbv; congruence) | abstract brute_force_indices lw | abstract apply prime_modulus + | abstract (cbv; congruence) | abstract brute_force_indices lw]. Definition construct_mul2modulus {m} (prm : PseudoMersenneBaseParams m) : digits := @@ -471,11 +474,11 @@ Section Multiplication. cbv [BaseSystem.mul mul mul_each mul_bi mul_bi' zeros ext_base reduce]. rewrite <- mul'_opt_correct. change @base with base_opt. - rewrite map_shiftl by apply k_nonneg. + rewrite Z.map_shiftl by apply k_nonneg. rewrite c_subst. rewrite k_subst. change @map with @map_opt. - change @Z_shiftl_by with @Z_shiftl_by_opt. + change @Z.shiftl_by with @Z_shiftl_by_opt. reflexivity. Defined. @@ -600,4 +603,32 @@ Section Canonicalization. Definition freeze_opt_correct us : freeze_opt us = freeze us := proj2_sig (freeze_opt_sig us). + + Lemma freeze_opt_canonical: forall us vs x, + @pre_carry_bounds _ _ int_width us -> rep us x -> + @pre_carry_bounds _ _ int_width vs -> rep vs x -> + freeze_opt us = freeze_opt vs. + Proof. + intros. + rewrite !freeze_opt_correct. + eapply freeze_canonical with (B := int_width); eauto. + Qed. + + Lemma freeze_opt_preserves_rep : forall us x, rep us x -> + rep (freeze_opt us) x. + Proof. + intros. + rewrite freeze_opt_correct. + eapply freeze_preserves_rep; eauto. + Qed. + + Lemma freeze_opt_spec : forall us vs x, rep us x -> rep vs x -> + @pre_carry_bounds _ _ int_width us -> + @pre_carry_bounds _ _ int_width vs -> + (rep (freeze_opt us) x /\ freeze_opt us = freeze_opt vs). + Proof. + split; eauto using freeze_opt_canonical. + auto using freeze_opt_preserves_rep. + Qed. + End Canonicalization. diff --git a/src/ModularArithmetic/ModularBaseSystemProofs.v b/src/ModularArithmetic/ModularBaseSystemProofs.v index 8787c6553..43af6dee0 100644 --- a/src/ModularArithmetic/ModularBaseSystemProofs.v +++ b/src/ModularArithmetic/ModularBaseSystemProofs.v @@ -5,10 +5,15 @@ Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil Crypt Require Import VerdiTactics. Require Crypto.BaseSystem. Require Import Crypto.ModularArithmetic.ModularBaseSystem Crypto.ModularArithmetic.PrimeFieldTheorems. -Require Import Crypto.BaseSystemProofs Crypto.ModularArithmetic.PseudoMersenneBaseParams Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs Crypto.ModularArithmetic.ExtendedBaseVector. +Require Import Crypto.BaseSystemProofs Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Pow2BaseProofs. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParamProofs. +Require Import Crypto.ModularArithmetic.ExtendedBaseVector. Require Import Crypto.Util.Notations. Local Open Scope Z_scope. + Section PseudoMersenneProofs. Context `{prm :PseudoMersenneBaseParams}. @@ -17,6 +22,7 @@ Section PseudoMersenneProofs. Local Notation "u .+ x" := (add u x). Local Notation "u .* x" := (ModularBaseSystem.mul u x). Local Hint Unfold rep. + Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg. Lemma rep_decode : forall us x, us ~= x -> decode us = x. Proof. @@ -34,19 +40,73 @@ Section PseudoMersenneProofs. cbv [rep]; auto. Qed. + Lemma lt_modulus_2k : modulus < 2 ^ k. + Proof. + replace (2 ^ k) with (modulus + c) by (unfold c; ring). + pose proof c_pos; omega. + Qed. Hint Resolve lt_modulus_2k. + + Lemma modulus_pos : 0 < modulus. + Proof. + pose proof (NumTheoryUtil.lt_1_p _ prime_modulus); omega. + Qed. Hint Resolve modulus_pos. + + Lemma encode'_eq : forall (x : F modulus) i, (i <= length limb_widths)%nat -> + encode' limb_widths x i = BaseSystem.encode' base x (2 ^ k) i. + Proof. + rewrite <-base_length; induction i; intros. + + rewrite encode'_zero. reflexivity. + + rewrite encode'_succ, <-IHi by omega. + simpl; do 2 f_equal. + rewrite Z.land_ones, Z.shiftr_div_pow2 by eauto. + match goal with H : (S _ <= length base)%nat |- _ => + apply le_lt_or_eq in H; destruct H end. + - repeat f_equal; unfold base in *; rewrite nth_default_base by (eauto || omega); reflexivity. + - repeat f_equal; try solve [unfold base in *; rewrite nth_default_base by (eauto || omega); reflexivity]. + rewrite nth_default_out_of_bounds by omega. + unfold k. + rewrite <-base_length; congruence. + Qed. + + Lemma encode_eq : forall x : F modulus, + encode x = BaseSystem.encode base x (2 ^ k). + Proof. + unfold encode, BaseSystem.encode; intros. + rewrite base_length; apply encode'_eq; omega. + Qed. + Lemma encode_rep : forall x : F modulus, encode x ~= x. Proof. - intros. unfold encode, rep. + intros. + rewrite encode_eq. + unfold encode, rep. split. { - unfold encode; simpl. - rewrite length_zeros. - pose proof base_length_nonzero; omega. + unfold BaseSystem.encode. + auto using encode'_length. } { unfold decode. - rewrite decode_highzeros. rewrite encode_rep. - apply ZToField_FieldToZ. - apply bv. + + apply ZToField_FieldToZ. + + apply bv. + + split; [ | etransitivity]; try (apply FieldToZ_range; auto using modulus_pos); auto. + + unfold base_max_succ_divide; intros. + match goal with H : (_ <= length base)%nat |- _ => + apply le_lt_or_eq in H; destruct H end. + - apply Z.mod_divide. + * apply nth_default_base_nonzero; auto using bv, two_k_nonzero. + * rewrite !nth_default_eq. + do 2 (erewrite nth_indep with (d := 2 ^ k) (d' := 0) by omega). + rewrite <-!nth_default_eq. + apply base_succ; eauto; omega. + - rewrite nth_default_out_of_bounds with (n := S i) by omega. + unfold base; rewrite nth_default_base by (unfold base in *; omega || eauto). + unfold k. + match goal with H : S _ = length base |- _ => + rewrite base_length in H; rewrite <-H end. + erewrite sum_firstn_succ by (apply nth_error_Some_nth_default with (x0 := 0); omega). + rewrite Z.pow_add_r by (eauto using sum_firstn_limb_widths_nonneg; + apply limb_widths_nonneg; rewrite nth_default_eq; apply nth_In; omega). + apply Z.divide_factor_r. } Qed. @@ -123,7 +183,7 @@ Section PseudoMersenneProofs. rewrite Z.sub_sub_distr, Z.sub_diag. simpl. rewrite Z.mul_comm. - rewrite mod_mult_plus; auto using modulus_nonzero. + rewrite Z.mod_add_l; auto using modulus_nonzero. rewrite <- Zplus_mod; auto. Qed. @@ -260,7 +320,7 @@ Section PseudoMersenneProofs. Proof. intros. apply Z_div_exact_2; try (apply nth_default_base_positive; omega). - apply base_succ; auto. + apply base_succ; eauto. Qed. Lemma Fdecode_decode_mod : forall us x, (length us = length base) -> @@ -271,12 +331,46 @@ Section PseudoMersenneProofs. apply FieldToZ_ZToField. Qed. + Lemma log_cap_nonneg : forall i, 0 <= log_cap i. + Proof. + unfold log_cap, nth_default; intros. + case_eq (nth_error limb_widths i); intros; try omega. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. Local Hint Resolve log_cap_nonneg. + + Definition carry_done us := forall i, (i < length base)%nat -> + 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. + + Lemma carry_done_bounds : forall us, (length us = length base) -> + (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). + Proof. + intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ]. + + destruct (lt_dec i (length base)) as [i_lt | i_nlt]. + - specialize (Hcarry_done i i_lt). + split; [ intuition | ]. + destruct Hcarry_done as [Hnth_nonneg Hshiftr_0]. + apply Z.shiftr_eq_0_iff in Hshiftr_0. + destruct Hshiftr_0 as [nth_0 | []]; [ rewrite nth_0; zero_bounds | ]. + apply Z.log2_lt_pow2; auto. + - rewrite nth_default_out_of_bounds by omega. + split; zero_bounds. + + specialize (Hbounds i). + split; [ intuition | ]. + destruct Hbounds as [nth_nonneg nth_lt_pow2]. + apply Z.shiftr_eq_0_iff. + apply Z.le_lteq in nth_nonneg; destruct nth_nonneg; try solve [left; auto]. + right; split; auto. + apply Z.log2_lt_pow2; auto. + Qed. + End PseudoMersenneProofs. Section CarryProofs. Context `{prm : PseudoMersenneBaseParams}. Local Notation "u ~= x" := (rep u x). Hint Unfold log_cap. + Local Hint Resolve limb_widths_nonneg sum_firstn_limb_widths_nonneg. Lemma base_length_lt_pred : (pred (length base) < length base)%nat. Proof. @@ -284,20 +378,12 @@ Section CarryProofs. Qed. Hint Resolve base_length_lt_pred. - Lemma log_cap_nonneg : forall i, 0 <= log_cap i. - Proof. - unfold log_cap, nth_default; intros. - case_eq (nth_error limb_widths i); intros; try omega. - apply limb_widths_nonneg. - eapply nth_error_value_In; eauto. - Qed. - Lemma nth_default_base_succ : forall i, (S i < length base)%nat -> nth_default 0 base (S i) = 2 ^ log_cap i * nth_default 0 base i. Proof. intros. - repeat rewrite nth_default_base by omega. - rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg). + unfold base; repeat rewrite nth_default_base by (unfold base in *; omega || eauto). + rewrite <- Z.pow_add_r by eauto using log_cap_nonneg. destruct (NPeano.Nat.eq_dec i 0). + subst; f_equal. unfold sum_firstn, log_cap. @@ -322,8 +408,8 @@ Section CarryProofs. rewrite nth_default_base_succ by omega. rewrite Z.mul_assoc. rewrite (Z.mul_comm _ (2 ^ log_cap i)). - rewrite mul_div_eq; try ring. - apply gt_lt_symmetry. + rewrite Z.mul_div_eq; try ring. + apply Z.gt_lt_iff. apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg. Qed. @@ -351,16 +437,16 @@ Section CarryProofs. unfold log_cap. subst; rewrite length_zero, limbs_length, nth_default_nil. reflexivity. - + rewrite nth_default_base by omega. + + unfold base; rewrite nth_default_base by (unfold base in *; omega || eauto). rewrite <- Z.add_opp_l, <- Z.opp_sub_distr. unfold pow2_mod. rewrite Z.land_ones by apply log_cap_nonneg. - rewrite <- mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg). + rewrite <- Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega || apply log_cap_nonneg). rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. rewrite Zopp_mult_distr_r. rewrite Z.mul_comm. rewrite Z.mul_assoc. - rewrite <- Z.pow_add_r by (apply log_cap_nonneg || apply sum_firstn_limb_widths_nonneg). + rewrite <- Z.pow_add_r by eauto using log_cap_nonneg. unfold k. replace (length limb_widths) with (S (pred (length base))) by (subst; rewrite <- base_length; apply NPeano.Nat.succ_pred; omega). @@ -369,6 +455,7 @@ Section CarryProofs. rewrite <- Zopp_mult_distr_r. rewrite Z.mul_comm. rewrite (Z.add_comm (log_cap (pred (length base)))). + unfold base. ring. Qed. @@ -453,7 +540,6 @@ End CarryProofs. Section CanonicalizationProofs. Context `{prm : PseudoMersenneBaseParams} (lt_1_length_base : (1 < length base)%nat) {B} (B_pos : 0 < B) (B_compat : forall w, In w limb_widths -> w <= B) - (c_pos : 0 < c) (* on the first reduce step, we add at most one bit of width to the first digit *) (c_reduce1 : c * (Z.ones (B - log_cap (pred (length base)))) < max_bound 0 + 1) (* on the second reduce step, we add at most one bit of width to the first digit, @@ -517,7 +603,7 @@ Section CanonicalizationProofs. Lemma max_bound_pos : forall i, (i < length base)%nat -> 0 < max_bound i. Proof. - unfold max_bound, log_cap; intros; apply Z_ones_pos_pos. + unfold max_bound, log_cap; intros; apply Z.ones_pos_pos. apply limb_widths_pos. rewrite nth_default_eq. apply nth_In. @@ -527,7 +613,7 @@ Section CanonicalizationProofs. Lemma max_bound_nonneg : forall i, 0 <= max_bound i. Proof. - unfold max_bound; intros; auto using Z_ones_nonneg. + unfold max_bound; intros; auto using Z.ones_nonneg. Qed. Local Hint Resolve max_bound_nonneg. @@ -611,9 +697,6 @@ Section CanonicalizationProofs. Qed. Local Hint Resolve pre_carry_bounds_nonzero. - Definition carry_done us := forall i, (i < length base)%nat -> - 0 <= nth_default 0 us i /\ Z.shiftr (nth_default 0 us i) (log_cap i) = 0. - (* END defs *) (* BEGIN proofs about first carry loop *) @@ -704,23 +787,6 @@ Section CanonicalizationProofs. | subst; apply pow2_mod_log_cap_small; assumption ]). Qed. - Lemma carry_done_bounds : forall us, (length us = length base) -> - (carry_done us <-> forall i, 0 <= nth_default 0 us i < 2 ^ log_cap i). - Proof. - intros ? ?; unfold carry_done; split; [ intros Hcarry_done i | intros Hbounds i i_lt ]. - + destruct (lt_dec i (length base)) as [i_lt | i_nlt]. - - specialize (Hcarry_done i i_lt). - split; [ intuition | ]. - rewrite <- max_bound_log_cap. - apply Z.lt_succ_r. - apply shiftr_eq_0_max_bound; intuition. - - rewrite nth_default_out_of_bounds; try split; try omega; auto. - + specialize (Hbounds i). - split; intuition. - apply max_bound_shiftr_eq_0; auto. - rewrite <-max_bound_log_cap in *; omega. - Qed. - Lemma carry_carry_done_done : forall i us, (length us = length base)%nat -> (i < length base)%nat -> @@ -812,7 +878,7 @@ Section CanonicalizationProofs. do 2 match goal with H : appcontext[S (pred (length base))] |- _ => erewrite <-(S_pred (length base)) in H by eauto end. unfold carry; break_if; [ unfold carry_and_reduce | omega ]. - clear_obvious. + clear_obvious. pose proof c_pos. add_set_nth; [ zero_bounds | ]; apply IHj; auto; omega. Qed. @@ -901,12 +967,12 @@ Section CanonicalizationProofs. simpl. unfold carry, carry_and_reduce; break_if; try omega. clear_obvious; add_set_nth. - split; [zero_bounds; carry_seq_lower_bound | ]. + split; [pose proof c_pos; zero_bounds; carry_seq_lower_bound | ]. rewrite Z.add_comm. apply Z.add_le_mono. + apply carry_bounds_0_upper; auto; omega. - + apply Z.mul_le_mono_pos_l; auto. - apply Z_shiftr_ones; auto; + + apply Z.mul_le_mono_pos_l; auto using c_pos. + apply Z.shiftr_ones; auto; [ | pose proof (B_compat_log_cap (pred (length base))); omega ]. split. - apply carry_bounds_lower; auto; omega. @@ -945,7 +1011,7 @@ Section CanonicalizationProofs. + rewrite <-max_bound_log_cap, <-Z.add_1_l. apply Z.add_le_mono. - rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. - apply Z_div_floor; auto. + apply Z.div_floor; auto. destruct i. * simpl. eapply Z.le_lt_trans; [ apply carry_full_bounds_0; auto | ]. @@ -970,13 +1036,13 @@ Section CanonicalizationProofs. unfold carry, carry_and_reduce; break_if; try omega. clear_obvious; add_set_nth. split. - + zero_bounds; [ | carry_seq_lower_bound]. + + pose proof c_pos; zero_bounds; [ | carry_seq_lower_bound]. apply carry_sequence_carry_full_bounds_same; auto; omega. + rewrite Z.add_comm. apply Z.add_le_mono. - apply carry_bounds_0_upper; carry_length_conditions. - etransitivity; [ | replace c with (c * 1) by ring; reflexivity ]. - apply Z.mul_le_mono_pos_l; try omega. + apply Z.mul_le_mono_pos_l; try (pose proof c_pos; omega). rewrite Z.shiftr_div_pow2 by auto. apply Z.div_le_upper_bound; auto. ring_simplify. @@ -1028,11 +1094,11 @@ Section CanonicalizationProofs. + rewrite <-max_bound_log_cap, <-Z.add_1_l. rewrite Z.shiftr_div_pow2 by apply log_cap_nonneg. apply Z.add_le_mono. - - apply Z_div_floor; auto. + - apply Z.div_floor; auto. eapply Z.le_lt_trans; [ eapply carry_full_2_bounds_0; eauto | ]. replace (Z.succ 1) with (2 ^ 1) by ring. rewrite <-max_bound_log_cap. - ring_simplify. omega. + ring_simplify. pose proof c_pos; omega. - apply carry_full_bounds; carry_length_conditions; carry_seq_lower_bound. Qed. @@ -1122,7 +1188,7 @@ Section CanonicalizationProofs. pose proof carry_full_2_bounds_0. apply Z.le_lt_trans with (m := (max_bound 0 + c) - (1 + max_bound 0)); [ apply Z.sub_le_mono_r; subst x; apply carry_full_2_bounds_0; auto; - ring_simplify | ]; omega. + ring_simplify | ]; pose proof c_pos; omega. + rewrite carry_unaffected_low by carry_length_conditions. assert (0 < S i < length base)%nat by omega. intuition; right. @@ -1148,7 +1214,7 @@ Section CanonicalizationProofs. replace (length l) with (pred (length limb_widths)) by (rewrite limb_widths_eq; auto). rewrite <- base_length. unfold carry, carry_and_reduce; break_if; try omega; intros. - add_set_nth. + add_set_nth. pose proof c_pos. split. + zero_bounds. - eapply carry_full_2_bounds_same; eauto; omega. @@ -1228,7 +1294,7 @@ Section CanonicalizationProofs. Lemma max_ones_nonneg : 0 <= max_ones. Proof. unfold max_ones. - apply Z_ones_nonneg. + apply Z.ones_nonneg. pose proof limb_widths_nonneg. induction limb_widths. cbv; congruence. @@ -1243,19 +1309,19 @@ Section CanonicalizationProofs. unfold max_ones. intros ? ? x_range. rewrite Z.land_comm. - rewrite Z.land_ones by apply Z_le_fold_right_max_initial. + rewrite Z.land_ones by apply Z.le_fold_right_max_initial. apply Z.mod_small. split; try omega. eapply Z.lt_le_trans; try eapply x_range. apply Z.pow_le_mono_r; try omega. rewrite log_cap_eq. destruct (lt_dec i (length limb_widths)). - + apply Z_le_fold_right_max. + + apply Z.le_fold_right_max. - apply limb_widths_nonneg. - rewrite nth_default_eq. auto using nth_In. + rewrite nth_default_out_of_bounds by omega. - apply Z_le_fold_right_max_initial. + apply Z.le_fold_right_max_initial. Qed. Lemma full_isFull'_true : forall j us, (length us = length base) -> @@ -1303,63 +1369,6 @@ Section CanonicalizationProofs. Qed. Local Hint Resolve carry_full_3_length. - Lemma nth_default_map2 : forall {A B C} (f : A -> B -> C) ls1 ls2 i d d1 d2, - nth_default d (map2 f ls1 ls2) i = - if lt_dec i (min (length ls1) (length ls2)) - then f (nth_default d1 ls1 i) (nth_default d2 ls2 i) - else d. - Proof. - induction ls1, ls2. - + cbv [map2 length min]. - intros. - break_if; try omega. - apply nth_default_nil. - + cbv [map2 length min]. - intros. - break_if; try omega. - apply nth_default_nil. - + cbv [map2 length min]. - intros. - break_if; try omega. - apply nth_default_nil. - + simpl. - destruct i. - - intros. rewrite !nth_default_cons. - break_if; auto; omega. - - intros. rewrite !nth_default_cons_S. - rewrite IHls1 with (d1 := d1) (d2 := d2). - repeat break_if; auto; omega. - Qed. - - Lemma map2_cons : forall A B C (f : A -> B -> C) ls1 ls2 a b, - map2 f (a :: ls1) (b :: ls2) = f a b :: map2 f ls1 ls2. - Proof. - reflexivity. - Qed. - - Lemma map2_nil_l : forall A B C (f : A -> B -> C) ls2, - map2 f nil ls2 = nil. - Proof. - reflexivity. - Qed. - - Lemma map2_nil_r : forall A B C (f : A -> B -> C) ls1, - map2 f ls1 nil = nil. - Proof. - destruct ls1; reflexivity. - Qed. - Local Hint Resolve map2_nil_r map2_nil_l. - - Opaque map2. - - Lemma map2_length : forall A B C (f : A -> B -> C) ls1 ls2, - length (map2 f ls1 ls2) = min (length ls1) (length ls2). - Proof. - induction ls1, ls2; intros; try solve [cbv; auto]. - rewrite map2_cons, !length_cons, IHls1. - auto. - Qed. - Lemma modulus_digits'_length : forall i, length (modulus_digits' i) = S i. Proof. induction i; intros; [ cbv; congruence | ]. @@ -1410,15 +1419,6 @@ Section CanonicalizationProofs. Local Hint Resolve limb_widths_nonneg. Local Hint Resolve nth_error_value_In. - (* TODO : move *) - Lemma sum_firstn_all_succ : forall n l, (length l <= n)%nat -> - sum_firstn l (S n) = sum_firstn l n. - Proof. - unfold sum_firstn; intros. - rewrite !firstn_all_strong by omega. - congruence. - Qed. - Lemma decode_carry_done_upper_bound' : forall n us, carry_done us -> (length us = length base) -> BaseSystem.decode (firstn n base) (firstn n us) < 2 ^ (sum_firstn limb_widths n). @@ -1430,7 +1430,9 @@ Section CanonicalizationProofs. destruct (nth_error_length_exists_value _ _ n_lt_length). erewrite sum_firstn_succ; eauto. rewrite Z.pow_add_r; eauto. - rewrite nth_default_base by (rewrite base_length; assumption). + unfold base. + rewrite nth_default_base by + (unfold base in *; try rewrite base_from_limb_widths_length; omega || eauto). rewrite Z.lt_add_lt_sub_r. eapply Z.lt_le_trans; eauto. rewrite Z.mul_comm at 1. @@ -1467,7 +1469,7 @@ Section CanonicalizationProofs. destruct (lt_dec n (length base)) as [ n_lt_length | ? ]. + rewrite decode_firstn_succ by auto. zero_bounds. - - rewrite nth_default_base by assumption. + - unfold base; rewrite nth_default_base by (unfold base in *; omega || eauto). apply Z.pow_nonneg; omega. - match goal with H : carry_done us |- _ => rewrite carry_done_bounds in H by auto; specialize (H n) end. intuition. @@ -1522,11 +1524,10 @@ Section CanonicalizationProofs. intros. rewrite nth_default_modulus_digits. break_if; [ | split; auto; omega]. - break_if; subst; split; auto; try rewrite <- max_bound_log_cap; omega. + break_if; subst; split; auto; try rewrite <- max_bound_log_cap; pose proof c_pos; omega. Qed. Local Hint Resolve carry_done_modulus_digits. - (* TODO : move *) Lemma decode_mod : forall us vs x, (length us = length base) -> (length vs = length base) -> decode us = x -> BaseSystem.decode base us mod modulus = BaseSystem.decode base vs mod modulus -> @@ -1538,23 +1539,6 @@ Section CanonicalizationProofs. assumption. Qed. - Ltac simpl_list_lengths := repeat match goal with - | H : appcontext[length (@nil ?A)] |- _ => rewrite (@nil_length0 A) in H - | H : appcontext[length (_ :: _)] |- _ => rewrite length_cons in H - | |- appcontext[length (@nil ?A)] => rewrite (@nil_length0 A) - | |- appcontext[length (_ :: _)] => rewrite length_cons - end. - - Lemma map2_app : forall A B C (f : A -> B -> C) ls1 ls2 ls1' ls2', - (length ls1 = length ls2) -> - map2 f (ls1 ++ ls1') (ls2 ++ ls2') = map2 f ls1 ls2 ++ map2 f ls1' ls2'. - Proof. - induction ls1, ls2; intros; rewrite ?map2_nil_r, ?app_nil_l; try congruence; - simpl_list_lengths; try omega. - rewrite <-!app_comm_cons, !map2_cons. - rewrite IHls1; auto. - Qed. - Lemma decode_map2_sub : forall us vs, (length us = length vs) -> BaseSystem.decode' base (map2 (fun x y => x - y) us vs) @@ -1582,12 +1566,12 @@ Section CanonicalizationProofs. intros z ? base_eq. rewrite decode'_cons, decode_nil, Z.add_0_r. replace z with (nth_default 0 base 0) by (rewrite base_eq; auto). - rewrite nth_default_base by omega. + unfold base; rewrite nth_default_base by (unfold base in *; omega || eauto). replace (max_bound 0 - c + 1) with (Z.succ (max_bound 0) - c) by ring. rewrite max_bound_log_cap. rewrite sum_firstn_succ with (x := log_cap 0) by (rewrite log_cap_eq; apply nth_error_Some_nth_default; rewrite <-base_length; omega). - rewrite Z.pow_add_r by auto. + rewrite Z.pow_add_r by eauto. cbv [sum_firstn fold_right firstn]. ring. + assert (S i < length base \/ S i = length base)%nat as cases by omega. @@ -1595,8 +1579,9 @@ Section CanonicalizationProofs. - rewrite sum_firstn_succ with (x := log_cap (S i)) by (rewrite log_cap_eq; apply nth_error_Some_nth_default; rewrite <-base_length; omega). - rewrite Z.pow_add_r, <-max_bound_log_cap, set_higher by auto. - rewrite IHi, modulus_digits'_length, nth_default_base by omega. + rewrite Z.pow_add_r, <-max_bound_log_cap, set_higher by eauto. + rewrite IHi, modulus_digits'_length by omega. + unfold base; rewrite nth_default_base by (unfold base in *; omega || eauto). ring. - rewrite sum_firstn_all_succ by (rewrite <-base_length; omega). rewrite decode'_splice, modulus_digits'_length, firstn_all by auto. @@ -1622,7 +1607,7 @@ Section CanonicalizationProofs. f_equal. apply land_max_ones_noop with (i := 0%nat). rewrite <-max_bound_log_cap. - omega. + pose proof c_pos; omega. + unfold modulus_digits'; fold modulus_digits'. rewrite map_app. f_equal; [ apply IHi; omega | ]. @@ -1774,7 +1759,8 @@ Section CanonicalizationProofs. + eapply Z.le_lt_trans. - eapply Z.add_le_mono with (q := nth_default 0 base n * -1); [ apply Z.le_refl | ]. apply Z.mul_le_mono_nonneg_l; try omega. - rewrite nth_default_base by omega; apply Z.pow_nonneg; omega. + unfold base; rewrite nth_default_base by (unfold base in *; omega || eauto). + zero_bounds. - ring_simplify. apply Z.lt_sub_0. apply decode_lt_next_digit; auto. @@ -1844,7 +1830,7 @@ Section CanonicalizationProofs. + match goal with |- (?a ?= ?b) = (?c ?= ?d) => rewrite (Z.compare_antisym b a); rewrite (Z.compare_antisym d c) end. apply CompOpp_inj; rewrite !CompOpp_involutive. - apply gt_lt_symmetry in Hgt. + apply Z.gt_lt_iff in Hgt. etransitivity; try apply Z_compare_decode_step_lt; auto; omega. Qed. @@ -2034,7 +2020,7 @@ Section CanonicalizationProofs. pose proof (carry_full_3_done us PCB lengths_eq) as cf3_done. rewrite carry_done_bounds in cf3_done by simpl_lengths. specialize (cf3_done 0%nat). - omega. + pose proof c_pos; omega. - assert ((0 < i <= length base - 1)%nat) as i_range by (simpl_lengths; apply lt_min_l in l; omega). specialize (high_digits i i_range). @@ -2114,4 +2100,4 @@ Section CanonicalizationProofs. eapply minimal_rep_unique; eauto; rewrite freeze_length; assumption. Qed. -End CanonicalizationProofs. +End CanonicalizationProofs.
\ No newline at end of file diff --git a/src/ModularArithmetic/Pow2Base.v b/src/ModularArithmetic/Pow2Base.v new file mode 100644 index 000000000..847967f52 --- /dev/null +++ b/src/ModularArithmetic/Pow2Base.v @@ -0,0 +1,43 @@ +Require Import Zpower ZArith. +Require Import Crypto.Util.ListUtil. +Require Import Coq.Lists.List. + +Local Open Scope Z_scope. + +Section Pow2Base. + Context (limb_widths : list Z). + Local Notation "w[ i ]" := (nth_default 0 limb_widths i). + + Fixpoint base_from_limb_widths limb_widths := + match limb_widths with + | nil => nil + | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw) + end. + + Local Notation "{base}" := (base_from_limb_widths limb_widths). + + + Definition bounded us := forall i, 0 <= nth_default 0 us i < 2 ^ w[i]. + + Definition upper_bound := 2 ^ (sum_firstn limb_widths (length limb_widths)). + + Fixpoint decode_bitwise' us i acc := + match i with + | O => acc + | S i' => decode_bitwise' us i' (Z.lor (nth_default 0 us i') (Z.shiftl acc w[i'])) + end. + + Definition decode_bitwise us := decode_bitwise' us (length us) 0. + + (* i is current index, counts down *) + Fixpoint encode' z i := + match i with + | O => nil + | S i' => let lw := sum_firstn limb_widths in + encode' z i' ++ (Z.shiftr (Z.land z (Z.ones (lw i))) (lw i')) :: nil + end. + + (* max must be greater than input; this is used to truncate last digit *) + Definition encodeZ x:= encode' x (length limb_widths). + +End Pow2Base.
\ No newline at end of file diff --git a/src/ModularArithmetic/Pow2BaseProofs.v b/src/ModularArithmetic/Pow2BaseProofs.v new file mode 100644 index 000000000..1504ca0df --- /dev/null +++ b/src/ModularArithmetic/Pow2BaseProofs.v @@ -0,0 +1,320 @@ +Require Import Zpower ZArith. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Coq.Lists.List. +Require Import Crypto.Util.ListUtil Crypto.Util.ZUtil. +Require Import Crypto.ModularArithmetic.Pow2Base Crypto.BaseSystemProofs. +Require Crypto.BaseSystem. +Local Open Scope Z_scope. + +Section Pow2BaseProofs. + Context {limb_widths} (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). + Local Notation "{base}" := (base_from_limb_widths limb_widths). + + Lemma base_from_limb_widths_length : length {base} = length limb_widths. + Proof. + induction limb_widths; try reflexivity. + simpl; rewrite map_length. + simpl in limb_widths_nonneg. + rewrite IHl; auto. + Qed. + + Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n. + Proof. + unfold sum_firstn; intros. + apply fold_right_invariant; try omega. + intros y In_y_lw ? ?. + apply Z.add_nonneg_nonneg; try assumption. + apply limb_widths_nonneg. + eapply In_firstn; eauto. + Qed. Hint Resolve sum_firstn_limb_widths_nonneg. + + Lemma base_from_limb_widths_step : forall i b w, (S i < length {base})%nat -> + nth_error {base} i = Some b -> + nth_error limb_widths i = Some w -> + nth_error {base} (S i) = Some (two_p w * b). + Proof. + induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b; + unfold base_from_limb_widths in *; fold base_from_limb_widths in *; + [rewrite (@nil_length0 Z) in *; omega | ]. + simpl in *; rewrite map_length in *. + case_eq i; intros; subst. + + subst; apply nth_error_first in nth_err_w. + apply nth_error_first in nth_err_b; subst. + apply map_nth_error. + case_eq l; intros; subst; [simpl in *; omega | ]. + unfold base_from_limb_widths; fold base_from_limb_widths. + reflexivity. + + simpl in nth_err_w. + apply nth_error_map in nth_err_w. + destruct nth_err_w as [x [A B]]. + subst. + replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring. + apply map_nth_error. + apply IHl; auto. omega. + Qed. + + + Lemma nth_error_base : forall i, (i < length {base})%nat -> + nth_error {base} i = Some (two_p (sum_firstn limb_widths i)). + Proof. + induction i; intros. + + unfold sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity. + intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega. + + assert (i < length {base})%nat as lt_i_length by omega. + specialize (IHi lt_i_length). + rewrite base_from_limb_widths_length in lt_i_length. + destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w]. + erewrite base_from_limb_widths_step; eauto. + f_equal. + simpl. + destruct (NPeano.Nat.eq_dec i 0). + - subst; unfold sum_firstn; simpl. + apply nth_error_exists_first in nth_err_w. + destruct nth_err_w as [l' lw_destruct]; subst. + simpl; ring_simplify. + f_equal; ring. + - erewrite sum_firstn_succ; eauto. + symmetry. + apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg. + apply limb_widths_nonneg. + eapply nth_error_value_In; eauto. + Qed. + + Lemma nth_default_base : forall d i, (i < length {base})%nat -> + nth_default d {base} i = 2 ^ (sum_firstn limb_widths i). + Proof. + intros ? ? i_lt_length. + destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x]. + unfold nth_default. + rewrite nth_err_x. + rewrite nth_error_base in nth_err_x by assumption. + rewrite two_p_correct in nth_err_x. + congruence. + Qed. + + Lemma base_succ : forall i, ((S i) < length {base})%nat -> + nth_default 0 {base} (S i) mod nth_default 0 {base} i = 0. + Proof. + intros. + repeat rewrite nth_default_base by omega. + apply Z.mod_same_pow. + split; [apply sum_firstn_limb_widths_nonneg | ]. + destruct (NPeano.Nat.eq_dec i 0); subst. + + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq. + apply Z.add_nonneg_nonneg; try omega. + apply limb_widths_nonneg. + rewrite lw_eq. + apply in_eq. + + assert (i < length {base})%nat as i_lt_length by omega. + rewrite base_from_limb_widths_length in *. + apply nth_error_length_exists_value in i_lt_length. + destruct i_lt_length as [x nth_err_x]. + erewrite sum_firstn_succ; eauto. + apply nth_error_value_In in nth_err_x. + apply limb_widths_nonneg in nth_err_x. + omega. + Qed. + + Lemma nth_error_subst : forall i b, nth_error {base} i = Some b -> + b = 2 ^ (sum_firstn limb_widths i). + Proof. + intros i b nth_err_b. + pose proof (nth_error_value_length _ _ _ _ nth_err_b). + rewrite nth_error_base in nth_err_b by assumption. + rewrite two_p_correct in nth_err_b. + congruence. + Qed. + +End Pow2BaseProofs. + +Section BitwiseDecodeEncode. + Context {limb_widths} (bv : BaseSystem.BaseVector (base_from_limb_widths limb_widths)) + (limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w). + Local Hint Resolve limb_widths_nonneg. + Local Notation "w[ i ]" := (nth_default 0 limb_widths i). + Local Notation "{base}" := (base_from_limb_widths limb_widths). + Local Notation "{max}" := (upper_bound limb_widths). + + Lemma encode'_spec : forall x i, (i <= length {base})%nat -> + encode' limb_widths x i = BaseSystem.encode' {base} x {max} i. + Proof. + induction i; intros. + + rewrite encode'_zero. reflexivity. + + rewrite encode'_succ, <-IHi by omega. + simpl; do 2 f_equal. + rewrite Z.land_ones, Z.shiftr_div_pow2 by auto using sum_firstn_limb_widths_nonneg. + match goal with H : (S _ <= length {base})%nat |- _ => + apply le_lt_or_eq in H; destruct H end. + - repeat f_equal; rewrite nth_default_base by (omega || auto); reflexivity. + - repeat f_equal; try solve [rewrite nth_default_base by (omega || auto); reflexivity]. + rewrite nth_default_out_of_bounds by omega. + unfold upper_bound. + rewrite <-base_from_limb_widths_length by auto. + congruence. + Qed. + + Lemma nth_default_limb_widths_nonneg : forall i, 0 <= w[i]. + Proof. + intros; apply nth_default_preserves_properties; auto; omega. + Qed. Hint Resolve nth_default_limb_widths_nonneg. + + Lemma base_upper_bound_compatible : @base_max_succ_divide {base} {max}. + Proof. + unfold base_max_succ_divide; intros i lt_Si_length. + rewrite Nat.lt_eq_cases in lt_Si_length; destruct lt_Si_length; + rewrite !nth_default_base by (omega || auto). + + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); + rewrite <-base_from_limb_widths_length by auto; omega). + rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg. + apply Z.divide_factor_r. + + rewrite nth_default_out_of_bounds by omega. + unfold upper_bound. + replace (length limb_widths) with (S (pred (length limb_widths))) by + (rewrite base_from_limb_widths_length in H by auto; omega). + replace i with (pred (length limb_widths)) by + (rewrite base_from_limb_widths_length in H by auto; omega). + erewrite sum_firstn_succ by (eapply nth_error_Some_nth_default with (x := 0); + rewrite <-base_from_limb_widths_length by auto; omega). + rewrite Z.pow_add_r; auto using sum_firstn_limb_widths_nonneg. + apply Z.divide_factor_r. + Qed. + Hint Resolve base_upper_bound_compatible. + + Lemma encodeZ_spec : forall x, + BaseSystem.decode {base} (encodeZ limb_widths x) = x mod {max}. + Proof. + intros. + assert (length {base} = length limb_widths) by auto using base_from_limb_widths_length. + unfold encodeZ; rewrite encode'_spec by omega. + rewrite BaseSystemProofs.encode'_spec; unfold upper_bound; try zero_bounds; + auto using sum_firstn_limb_widths_nonneg. + rewrite nth_default_out_of_bounds by omega. + reflexivity. + Qed. + + Lemma decode_bitwise'_succ : forall us i acc, bounded limb_widths us -> + decode_bitwise' limb_widths us (S i) acc = + decode_bitwise' limb_widths us i (acc * (2 ^ w[i]) + nth_default 0 us i). + Proof. + intros. + simpl; f_equal. + match goal with H : bounded _ _ |- _ => + rewrite Z.lor_shiftl by (auto; unfold bounded in H; specialize (H i); assumption) end. + rewrite Z.shiftl_mul_pow2 by auto. + ring. + Qed. + + (* c is a counter, allows i to count up rather than down *) + Fixpoint partial_decode us i c := + match c with + | O => 0 + | S c' => (partial_decode us (S i) c' * 2 ^ w[i]) + nth_default 0 us i + end. + + Lemma partial_decode_counter_over : forall c us i, (c >= length us - i)%nat -> + partial_decode us i c = partial_decode us i (length us - i). + Proof. + induction c; intros. + + f_equal. omega. + + simpl. rewrite IHc by omega. + case_eq (length us - i)%nat; intros. + - rewrite nth_default_out_of_bounds with (us0 := us) by omega. + replace (length us - S i)%nat with 0%nat by omega. + reflexivity. + - simpl. repeat f_equal. omega. + Qed. + + Lemma partial_decode_counter_subst : forall c c' us i, + (c >= length us - i)%nat -> (c' >= length us - i)%nat -> + partial_decode us i c = partial_decode us i c'. + Proof. + intros. + rewrite partial_decode_counter_over by assumption. + symmetry. + auto using partial_decode_counter_over. + Qed. + + Lemma partial_decode_succ : forall c us i, (c >= length us - i)%nat -> + partial_decode us (S i) c * 2 ^ w[i] + nth_default 0 us i = + partial_decode us i c. + Proof. + intros. + rewrite partial_decode_counter_subst with (i := i) (c' := S c) by omega. + reflexivity. + Qed. + + Lemma partial_decode_intermediate : forall c us i, length us = length limb_widths -> + (c >= length us - i)%nat -> + partial_decode us i c = BaseSystem.decode' (base_from_limb_widths (skipn i limb_widths)) (skipn i us). + Proof. + induction c; intros. + + simpl. rewrite skipn_all by omega. + symmetry; apply decode_base_nil. + + simpl. + destruct (lt_dec i (length limb_widths)). + - rewrite IHc by omega. + do 2 (rewrite skipn_nth_default with (n := i) (d := 0) by (rewrite <-?base_length; omega)). + unfold base_from_limb_widths; fold base_from_limb_widths. + rewrite peel_decode. + fold (BaseSystem.mul_each (two_p w[i])). + rewrite <-mul_each_base, mul_each_rep, two_p_correct. + ring_simplify. + f_equal; ring. + - rewrite <- IHc by omega. + apply partial_decode_succ; omega. + Qed. + + + Lemma decode_bitwise'_succ_partial_decode : forall us i c, + bounded limb_widths us -> length us = length limb_widths -> + decode_bitwise' limb_widths us (S i) (partial_decode us (S i) c) = + decode_bitwise' limb_widths us i (partial_decode us i (S c)). + Proof. + intros. + rewrite decode_bitwise'_succ by auto. + f_equal. + Qed. + + Lemma decode_bitwise'_spec : forall us i, (i <= length limb_widths)%nat -> + bounded limb_widths us -> length us = length limb_widths -> + decode_bitwise' limb_widths us i (partial_decode us i (length us - i)) = + BaseSystem.decode {base} us. + Proof. + induction i; intros. + + rewrite partial_decode_intermediate by auto. + reflexivity. + + rewrite decode_bitwise'_succ_partial_decode by auto. + replace (S (length us - S i)) with (length us - i)%nat by omega. + apply IHi; auto; omega. + Qed. + + Lemma decode_bitwise_spec : forall us, bounded limb_widths us -> + length us = length limb_widths -> + decode_bitwise limb_widths us = BaseSystem.decode {base} us. + Proof. + unfold decode_bitwise; intros. + replace 0 with (partial_decode us (length us) (length us - length us)) by + (rewrite Nat.sub_diag; reflexivity). + apply decode_bitwise'_spec; auto; omega. + Qed. + +End BitwiseDecodeEncode. + +Section Conversion. + Context {limb_widthsA} (limb_widthsA_nonneg : forall w, In w limb_widthsA -> 0 <= w) + {limb_widthsB} (limb_widthsB_nonneg : forall w, In w limb_widthsB -> 0 <= w). + Local Notation "{baseA}" := (base_from_limb_widths limb_widthsA). + Local Notation "{baseB}" := (base_from_limb_widths limb_widthsB). + Context (bvB : BaseSystem.BaseVector {baseB}). + + Definition convert xs := @encodeZ limb_widthsB (@decode_bitwise limb_widthsA xs). + + Lemma convert_spec : forall xs, @bounded limb_widthsA xs -> length xs = length limb_widthsA -> + BaseSystem.decode {baseA} xs mod (@upper_bound limb_widthsB) = BaseSystem.decode {baseB} (convert xs). + Proof. + unfold convert; intros. + rewrite encodeZ_spec, decode_bitwise_spec by auto. + reflexivity. + Qed. + +End Conversion.
\ No newline at end of file diff --git a/src/ModularArithmetic/PrimeFieldTheorems.v b/src/ModularArithmetic/PrimeFieldTheorems.v index 2021e8514..a2f606f30 100644 --- a/src/ModularArithmetic/PrimeFieldTheorems.v +++ b/src/ModularArithmetic/PrimeFieldTheorems.v @@ -460,8 +460,8 @@ Section SquareRootsPrime5Mod8. apply Z2N.inj_iff; try zero_bounds. rewrite <- Z.mul_cancel_l with (p := 2) by omega. ring_simplify. - rewrite mul_div_eq by omega. - rewrite mul_div_eq by omega. + rewrite Z.mul_div_eq by omega. + rewrite Z.mul_div_eq by omega. rewrite (Zmod_div_mod 2 8 q) by (try omega; apply Zmod_divide; omega || auto). rewrite q_5mod8. diff --git a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v index 49b1875ce..c07da850f 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParamProofs.v @@ -4,151 +4,30 @@ Require Import Crypto.Util.ListUtil Crypto.Util.CaseUtil Crypto.Util.ZUtil. Require Import Crypto.ModularArithmetic.PrimeFieldTheorems. Require Import VerdiTactics. Require Import Crypto.ModularArithmetic.PseudoMersenneBaseParams. +Require Import Crypto.ModularArithmetic.Pow2Base Crypto.ModularArithmetic.Pow2BaseProofs. Require Crypto.BaseSystem. Local Open Scope Z_scope. Section PseudoMersenneBaseParamProofs. Context `{prm : PseudoMersenneBaseParams}. - Fixpoint base_from_limb_widths limb_widths := - match limb_widths with - | nil => nil - | w :: lw => 1 :: map (Z.mul (two_p w)) (base_from_limb_widths lw) - end. - - Definition base := base_from_limb_widths limb_widths. - - Lemma base_length : length base = length limb_widths. - Proof. - unfold base. - induction limb_widths; try reflexivity. - simpl; rewrite map_length; auto. - Qed. - - Lemma nth_error_first : forall {T} (a b : T) l, nth_error (a :: l) 0 = Some b -> - a = b. - Proof. - intros; simpl in *. - unfold value in *. - congruence. - Qed. - - Lemma base_from_limb_widths_step : forall i b w, (S i < length base)%nat -> - nth_error base i = Some b -> - nth_error limb_widths i = Some w -> - nth_error base (S i) = Some (two_p w * b). - Proof. - unfold base; induction limb_widths; intros ? ? ? ? nth_err_w nth_err_b; - unfold base_from_limb_widths in *; fold base_from_limb_widths in *; - [rewrite (@nil_length0 Z) in *; omega | ]. - simpl in *; rewrite map_length in *. - case_eq i; intros; subst. - + subst; apply nth_error_first in nth_err_w. - apply nth_error_first in nth_err_b; subst. - apply map_nth_error. - case_eq l; intros; subst; [simpl in *; omega | ]. - unfold base_from_limb_widths; fold base_from_limb_widths. - reflexivity. - + simpl in nth_err_w. - apply nth_error_map in nth_err_w. - destruct nth_err_w as [x [A B]]. - subst. - replace (two_p w * (two_p a * x)) with (two_p a * (two_p w * x)) by ring. - apply map_nth_error. - apply IHl; auto; omega. - Qed. - - Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x), - exists l', l = x :: l'. - Proof. - induction l; try discriminate; eexists. - apply nth_error_first in H. - subst; eauto. - Qed. - - Lemma sum_firstn_succ : forall l i x, - nth_error l i = Some x -> - sum_firstn l (S i) = x + sum_firstn l i. - Proof. - unfold sum_firstn; induction l; - [intros; rewrite (@nth_error_nil_error Z) in *; congruence | ]. - intros ? x nth_err_x; destruct (NPeano.Nat.eq_dec i 0). - + subst; simpl in *; unfold value in *. - congruence. - + rewrite <- (NPeano.Nat.succ_pred i) at 2 by auto. - rewrite <- (NPeano.Nat.succ_pred i) in nth_err_x by auto. - simpl. simpl in nth_err_x. - specialize (IHl (pred i) x). - rewrite NPeano.Nat.succ_pred in IHl by auto. - destruct (NPeano.Nat.eq_dec (pred i) 0). - - replace i with 1%nat in * by omega. - simpl. replace (pred 1) with 0%nat in * by auto. - apply nth_error_exists_first in nth_err_x. - destruct nth_err_x as [l' ?]. - subst; simpl; ring. - - rewrite IHl by auto; ring. - Qed. - Lemma limb_widths_nonneg : forall w, In w limb_widths -> 0 <= w. Proof. intros. apply Z.lt_le_incl. auto using limb_widths_pos. - Qed. - - Lemma sum_firstn_limb_widths_nonneg : forall n, 0 <= sum_firstn limb_widths n. - Proof. - unfold sum_firstn; intros. - apply fold_right_invariant; try omega. - intros y In_y_lw ? ?. - apply Z.add_nonneg_nonneg; try assumption. - apply limb_widths_nonneg. - eapply In_firstn; eauto. - Qed. + Qed. Hint Resolve limb_widths_nonneg. Lemma k_nonneg : 0 <= k. Proof. - apply sum_firstn_limb_widths_nonneg. - Qed. + apply sum_firstn_limb_widths_nonneg; auto. + Qed. Hint Resolve k_nonneg. - Lemma nth_error_base : forall i, (i < length base)%nat -> - nth_error base i = Some (two_p (sum_firstn limb_widths i)). - Proof. - induction i; intros. - + unfold base, sum_firstn, base_from_limb_widths in *; case_eq limb_widths; try reflexivity. - intro lw_nil; rewrite lw_nil, (@nil_length0 Z) in *; omega. - + - assert (i < length base)%nat as lt_i_length by omega. - specialize (IHi lt_i_length). - rewrite base_length in lt_i_length. - destruct (nth_error_length_exists_value _ _ lt_i_length) as [w nth_err_w]. - erewrite base_from_limb_widths_step; eauto. - f_equal. - simpl. - destruct (NPeano.Nat.eq_dec i 0). - - subst; unfold sum_firstn; simpl. - apply nth_error_exists_first in nth_err_w. - destruct nth_err_w as [l' lw_destruct]; subst. - rewrite lw_destruct. - ring_simplify. - f_equal; simpl; ring. - - erewrite sum_firstn_succ; eauto. - symmetry. - apply two_p_is_exp; auto using sum_firstn_limb_widths_nonneg. - apply limb_widths_nonneg. - eapply nth_error_value_In; eauto. - Qed. + Definition base := base_from_limb_widths limb_widths. - Lemma nth_default_base : forall d i, (i < length base)%nat -> - nth_default d base i = 2 ^ (sum_firstn limb_widths i). + Lemma base_length : length base = length limb_widths. Proof. - intros ? ? i_lt_length. - destruct (nth_error_length_exists_value _ _ i_lt_length) as [x nth_err_x]. - unfold nth_default. - rewrite nth_err_x. - rewrite nth_error_base in nth_err_x by assumption. - rewrite two_p_correct in nth_err_x. - congruence. + unfold base; auto using base_from_limb_widths_length. Qed. Lemma base_matches_modulus: forall i j, @@ -163,63 +42,31 @@ Section PseudoMersenneBaseParamProofs. rewrite (Z.mul_comm r). subst r. assert (i + j - length base < length base)%nat by omega. - rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.mul_pos_pos; - [ | subst b; rewrite nth_default_base; try assumption ]; - apply Z.pow_pos_nonneg; omega || apply k_nonneg || apply sum_firstn_limb_widths_nonneg). + rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; apply Z.mul_pos_pos; + [ | subst b; unfold base; rewrite nth_default_base; try assumption ]; + zero_bounds; auto using sum_firstn_limb_widths_nonneg, limb_widths_nonneg). rewrite (Zminus_0_l_reverse (b i * b j)) at 1. f_equal. subst b. - repeat rewrite nth_default_base by assumption. - do 2 rewrite <- Z.pow_add_r by (apply sum_firstn_limb_widths_nonneg || apply k_nonneg). + unfold base; repeat rewrite nth_default_base by auto. + do 2 rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg. symmetry. - apply mod_same_pow. + apply Z.mod_same_pow. split. - + apply Z.add_nonneg_nonneg; apply sum_firstn_limb_widths_nonneg || apply k_nonneg. - + rewrite base_length in *; apply limb_widths_match_modulus; assumption. + + apply Z.add_nonneg_nonneg; auto using sum_firstn_limb_widths_nonneg. + + rewrite base_length, base_from_limb_widths_length in * by auto. + apply limb_widths_match_modulus; auto. Qed. - Lemma base_succ : forall i, ((S i) < length base)%nat -> - nth_default 0 base (S i) mod nth_default 0 base i = 0. - Proof. - intros. - repeat rewrite nth_default_base by omega. - apply mod_same_pow. - split; [apply sum_firstn_limb_widths_nonneg | ]. - destruct (NPeano.Nat.eq_dec i 0); subst. - + case_eq limb_widths; intro; unfold sum_firstn; simpl; try omega; intros l' lw_eq. - apply Z.add_nonneg_nonneg; try omega. - apply limb_widths_nonneg. - rewrite lw_eq. - apply in_eq. - + assert (i < length base)%nat as i_lt_length by omega. - rewrite base_length in *. - apply nth_error_length_exists_value in i_lt_length. - destruct i_lt_length as [x nth_err_x]. - erewrite sum_firstn_succ; eauto. - apply nth_error_value_In in nth_err_x. - apply limb_widths_nonneg in nth_err_x. - omega. - Qed. - - Lemma nth_error_subst : forall i b, nth_error base i = Some b -> - b = 2 ^ (sum_firstn limb_widths i). - Proof. - intros i b nth_err_b. - pose proof (nth_error_value_length _ _ _ _ nth_err_b). - rewrite nth_error_base in nth_err_b by assumption. - rewrite two_p_correct in nth_err_b. - congruence. - Qed. - Lemma base_positive : forall b : Z, In b base -> b > 0. Proof. intros b In_b_base. apply In_nth_error_value in In_b_base. destruct In_b_base as [i nth_err_b]. - apply nth_error_subst in nth_err_b. + apply nth_error_subst in nth_err_b; [ | auto ]. rewrite nth_err_b. - apply gt_lt_symmetry. - apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg. + apply Z.gt_lt_iff. + apply Z.pow_pos_nonneg; omega || auto using sum_firstn_limb_widths_nonneg. Qed. Lemma b0_1 : forall x : Z, nth_default x base 0 = 1. @@ -234,12 +81,14 @@ Section PseudoMersenneBaseParamProofs. b i * b j = r * b (i + j)%nat. Proof. intros; subst b r. - repeat rewrite nth_default_base by omega. + unfold base in *. + repeat rewrite nth_default_base by (omega || auto). rewrite (Z.mul_comm _ (2 ^ (sum_firstn limb_widths (i+j)))). - rewrite mul_div_eq by (apply gt_lt_symmetry; apply Z.pow_pos_nonneg; omega || apply sum_firstn_limb_widths_nonneg). - rewrite <- Z.pow_add_r by apply sum_firstn_limb_widths_nonneg. - rewrite mod_same_pow; try ring. - split; [ apply sum_firstn_limb_widths_nonneg | ]. + rewrite Z.mul_div_eq by (apply Z.gt_lt_iff; zero_bounds; + auto using sum_firstn_limb_widths_nonneg). + rewrite <- Z.pow_add_r by auto using sum_firstn_limb_widths_nonneg. + rewrite Z.mod_same_pow; try ring. + split; [ auto using sum_firstn_limb_widths_nonneg | ]. apply limb_widths_good. rewrite <- base_length; assumption. Qed. @@ -250,4 +99,4 @@ Section PseudoMersenneBaseParamProofs. base_good := base_good }. -End PseudoMersenneBaseParamProofs. +End PseudoMersenneBaseParamProofs.
\ No newline at end of file diff --git a/src/ModularArithmetic/PseudoMersenneBaseParams.v b/src/ModularArithmetic/PseudoMersenneBaseParams.v index e20a7ed09..1f9a0f2f6 100644 --- a/src/ModularArithmetic/PseudoMersenneBaseParams.v +++ b/src/ModularArithmetic/PseudoMersenneBaseParams.v @@ -1,10 +1,9 @@ Require Import ZArith. Require Import List. +Require Import Crypto.Util.ListUtil. Require Crypto.BaseSystem. Local Open Scope Z_scope. -Definition sum_firstn l n := fold_right Z.add 0 (firstn n l). - Class PseudoMersenneBaseParams (modulus : Z) := { limb_widths : list Z; limb_widths_pos : forall w, In w limb_widths -> 0 < w; @@ -15,6 +14,7 @@ Class PseudoMersenneBaseParams (modulus : Z) := { prime_modulus : Znumtheory.prime modulus; k := sum_firstn limb_widths (length limb_widths); c := 2 ^ k - modulus; + c_pos : 0 < c; limb_widths_match_modulus : forall i j, (i < length limb_widths)%nat -> (j < length limb_widths)%nat -> diff --git a/src/Spec/CompleteEdwardsCurve.v b/src/Spec/CompleteEdwardsCurve.v index f6db1c14f..16a1217ce 100644 --- a/src/Spec/CompleteEdwardsCurve.v +++ b/src/Spec/CompleteEdwardsCurve.v @@ -29,8 +29,9 @@ Module E. Definition coordinates (P:point) : (F*F) := proj1_sig P. (** The following points are indeed on the curve -- see [CompleteEdwardsCurve.Pre] for proof *) - Local Obligation Tactic := intros; apply Pre.zeroOnCurve - || apply (Pre.unifiedAdd'_onCurve (char_gt_2:=char_gt_2) (d_nonsquare:=nonsquare_d) + Local Obligation Tactic := intros; + apply (Pre.zeroOnCurve(a_nonzero:=nonzero_a)(char_gt_2:=char_gt_2)) || + apply (Pre.unifiedAdd'_onCurve (char_gt_2:=char_gt_2) (d_nonsquare:=nonsquare_d) (a_nonzero:=nonzero_a) (a_square:=square_a) _ _ (proj2_sig _) (proj2_sig _)). Program Definition zero : point := (0, 1). diff --git a/src/Spec/EdDSA.v b/src/Spec/EdDSA.v index 03a723e10..25109bc4c 100644 --- a/src/Spec/EdDSA.v +++ b/src/Spec/EdDSA.v @@ -37,7 +37,7 @@ Section EdDSA. := { EdDSA_group:@Algebra.group E Eeq Eadd Ezero Eopp; - EdDSA_scalarmult:@Algebra.Group.is_scalarmult E Eeq Eadd Ezero EscalarMult; + EdDSA_scalarmult:@Algebra.ScalarMult.is_scalarmult E Eeq Eadd Ezero EscalarMult; EdDSA_c_valid : c = 2 \/ c = 3; diff --git a/src/Spec/WeierstrassCurve.v b/src/Spec/WeierstrassCurve.v new file mode 100644 index 000000000..7ec5d99ec --- /dev/null +++ b/src/Spec/WeierstrassCurve.v @@ -0,0 +1,84 @@ +Require Crypto.WeierstrassCurve.Pre. + +Module E. + Section WeierstrassCurves. + (* Short Weierstrass curves with addition laws. References: + * <https://hyperelliptic.org/EFD/g1p/auto-shortw.html> + * <https://cr.yp.to/talks/2007.06.07/slides.pdf> + * See also: + * <http://cs.ucsb.edu/~koc/ccs130h/2013/EllipticHyperelliptic-CohenFrey.pdf> (page 79) + *) + + Context {F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv} `{Algebra.field F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv}. + Local Infix "=" := Feq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Infix "=?" := Algebra.eq_dec (at level 70, no associativity) : type_scope. + Local Notation "x =? y" := (Sumbool.bool_of_sumbool (Algebra.eq_dec x y)) : bool_scope. + Local Infix "+" := Fadd. Local Infix "*" := Fmul. + Local Infix "-" := Fsub. Local Infix "/" := Fdiv. + Local Notation "- x" := (Fopp x). + Local Notation "x ^ 2" := (x*x) (at level 30). Local Notation "x ^ 3" := (x*x^2) (at level 30). + Local Notation "'∞'" := unit : type_scope. + Local Notation "'∞'" := (inr tt) : core_scope. + Local Notation "0" := Fzero. Local Notation "1" := Fone. + Local Notation "2" := (1+1). Local Notation "3" := (1+2). Local Notation "4" := (1+3). + Local Notation "8" := (1+(1+(1+(1+4)))). Local Notation "12" := (1+(1+(1+(1+8)))). + Local Notation "16" := (1+(1+(1+(1+12)))). Local Notation "20" := (1+(1+(1+(1+16)))). + Local Notation "24" := (1+(1+(1+(1+20)))). Local Notation "27" := (1+(1+(1+24))). + + Local Notation "( x , y )" := (inl (pair x y)). + Local Open Scope core_scope. + + Context {a b: F}. + + (** N.B. We may require more conditions to prove that points form + a group under addition (associativity, in particular. If + that's the case, more fields will be added to this class. *) + Class weierstrass_params := + { + char_gt_2 : 2 <> 0; + char_ne_3 : 3 <> 0; + nonzero_discriminant : -(16) * (4 * a^3 + 27 * b^2) <> 0 + }. + Context `{weierstrass_params}. + + Definition point := { P | match P with + | (x, y) => y^2 = x^3 + a*x + b + | ∞ => True + end }. + Definition coordinates (P:point) : (F*F + ∞) := proj1_sig P. + + (** The following points are indeed on the curve -- see [WeierstrassCurve.Pre] for proof *) + Local Obligation Tactic := + try solve [ Program.Tactics.program_simpl + | intros; apply (Pre.unifiedAdd'_onCurve _ _ (proj2_sig _) (proj2_sig _)) ]. + + Program Definition zero : point := ∞. + + Program Definition add (P1 P2:point) : point + := exist + _ + (match coordinates P1, coordinates P2 return _ with + | (x1, y1), (x2, y2) => + if x1 =? x2 then + if y2 =? -y1 then ∞ + else ((3*x1^2+a)^2 / (2*y1)^2 - x1 - x1, + (2*x1+x1)*(3*x1^2+a) / (2*y1) - (3*x1^2+a)^3/(2*y1)^3-y1) + else ((y2-y1)^2 / (x2-x1)^2 - x1 - x2, + (2*x1+x2)*(y2-y1) / (x2-x1) - (y2-y1)^3 / (x2-x1)^3 - y1) + | ∞, ∞ => ∞ + | ∞, _ => coordinates P2 + | _, ∞ => coordinates P1 + end) + _. + + Fixpoint mul (n:nat) (P : point) : point := + match n with + | O => zero + | S n' => add P (mul n' P) + end. + End WeierstrassCurves. +End E. + +Delimit Scope E_scope with E. +Infix "+" := E.add : E_scope. +Infix "*" := E.mul : E_scope. diff --git a/src/Tactics/Nsatz.v b/src/Tactics/Algebra_syntax/Nsatz.v index 04f35c200..a5b04cfa2 100644 --- a/src/Tactics/Nsatz.v +++ b/src/Tactics/Algebra_syntax/Nsatz.v @@ -121,11 +121,14 @@ Ltac nsatz_sugar_power sugar power := let domain := nsatz_guess_domain in nsatz_domain_sugar_power domain sugar power. -Tactic Notation "nsatz" constr(n) := - let nn := (eval compute in (BinNat.N.of_nat n)) in - nsatz_sugar_power BinInt.Z0 nn. +Ltac nsatz_power power := + let power_N := (eval compute in (BinNat.N.of_nat power)) in + nsatz_sugar_power BinInt.Z0 power_N. -Tactic Notation "nsatz" := nsatz 1%nat || nsatz 2%nat || nsatz 3%nat || nsatz 4%nat || nsatz 5%nat. +Ltac nsatz := nsatz_power 1%nat || nsatz_power 2%nat || nsatz_power 3%nat || nsatz_power 4%nat || nsatz_power 5%nat. + +Tactic Notation "nsatz" := nsatz. +Tactic Notation "nsatz" constr(n) := nsatz_power n. (** If the goal is of the form [?x <> ?y] and assuming [?x = ?y] contradicts any hypothesis of the form [?x' <> ?y'], we turn this diff --git a/src/Testbit.v b/src/Testbit.v index 2bfcc3df6..f9de9092b 100644 --- a/src/Testbit.v +++ b/src/Testbit.v @@ -107,7 +107,7 @@ Proof. rewrite <- nth_default_eq in uniform. erewrite nth_error_value_eq_nth_default in uniform; eauto. subst. - destruct r; [ | apply pos_pow_nat_pos | pose proof (Zlt_neg_0 p) ] ; omega. + destruct r; [ | apply Z.pos_pow_nat_pos | pose proof (Zlt_neg_0 p) ] ; omega. + intros. rewrite nth_default_eq. rewrite uniform; auto. @@ -151,7 +151,7 @@ Proof. induction us; boring. rewrite <- (IHus base) by (omega || eauto using no_overflow_tail). rewrite decode_cons by (eapply uniform_base_BaseVector; eauto; - rewrite gt_lt_symmetry; apply Z_pow_gt0; omega). + rewrite Z.gt_lt_iff; apply Z.pow_pos_nonneg; omega). simpl. f_equal. + symmetry. eapply no_overflow_cons; eauto. @@ -174,14 +174,15 @@ Proof. auto using Z.land_0_l. + destruct i; simpl. - rewrite nth_default_cons. - rewrite Z.shiftr_0_r, Z_land_add_land by omega. + rewrite Z.shiftr_0_r, Z.land_add_land by omega. symmetry; eapply no_overflow_cons; eauto. - rewrite nth_default_cons_S. erewrite IHus; eauto using no_overflow_tail. remember (i * limb_width)%nat as k. - rewrite Z_shiftr_add_land by omega. - replace (limb_width + k - limb_width)%nat with k by omega. - reflexivity. + rewrite Z.shiftr_add_shiftl_high; rewrite ?Nat2Z.inj_add; + repeat f_equal; try omega. + rewrite Z.land_ones by apply Nat2Z.is_nonneg. + apply Z.mod_pos_bound. zero_bounds. Qed. Lemma unfold_bits_testbit : forall limb_width us n, (0 < limb_width)%nat -> @@ -190,7 +191,7 @@ Lemma unfold_bits_testbit : forall limb_width us n, (0 < limb_width)%nat -> Proof. unfold testbit; intros. erewrite unfold_bits_indexing; eauto. - rewrite <- Z_testbit_low by + rewrite <- Z.testbit_low by (split; try apply Nat2Z.inj_lt; pose proof (mod_bound_pos n limb_width); omega). rewrite Z.shiftr_spec by apply Nat2Z.is_nonneg. f_equal. diff --git a/src/Util/AdditionChainExponentiation.v b/src/Util/AdditionChainExponentiation.v new file mode 100644 index 000000000..ca1394115 --- /dev/null +++ b/src/Util/AdditionChainExponentiation.v @@ -0,0 +1,102 @@ +Require Import Coq.Lists.List Coq.Lists.SetoidList. Import ListNotations. +Require Import Crypto.Util.ListUtil. +Require Import Algebra. Import Monoid ScalarMult. +Require Import VerdiTactics. +Require Import Crypto.Util.Option. + +Section AddChainExp. + Function add_chain (is:list (nat*nat)) : list nat := + match is with + | nil => nil + | (i,j)::is' => + let chain' := add_chain is' in + nth_default 1 chain' i + nth_default 1 chain' j::chain' + end. + +Example wikipedia_addition_chain : add_chain (rev [ +(0, 0); (* 2 = 1 + 1 *) (* the indices say how far back the chain to look *) +(0, 1); (* 3 = 2 + 1 *) +(0, 0); (* 6 = 3 + 3 *) +(0, 0); (* 12 = 6 + 6 *) +(0, 0); (* 24 = 12 + 12 *) +(0, 2); (* 30 = 24 + 6 *) +(0, 6)] (* 31 = 30 + 1 *) +) = [31; 30; 24; 12; 6; 3; 2]. reflexivity. Qed. + + Context {G eq op id} {monoid:@Algebra.monoid G eq op id}. + Local Infix "=" := eq : type_scope. + + Function add_chain_exp (is : list (nat*nat)) (x : G) : list G := + match is with + | nil => nil + | (i,j)::is' => + let chain' := add_chain_exp is' x in + op (nth_default x chain' i) (nth_default x chain' j) ::chain' + end. + + Fixpoint scalarmult n (x : G) : G := match n with + | O => id + | S n' => op x (scalarmult n' x) + end. + + Lemma add_chain_exp_step : forall i j is x, + (forall n, nth_default x (add_chain_exp is x) n = scalarmult (nth_default 1 (add_chain is) n) x) -> + (eqlistA eq) + (add_chain_exp ((i,j) :: is) x) + (op (scalarmult (nth_default 1 (add_chain is) i) x) + (scalarmult (nth_default 1 (add_chain is) j) x) :: add_chain_exp is x). + Proof. + intros. + unfold add_chain_exp; fold add_chain_exp. + apply eqlistA_cons; [ | reflexivity]. + f_equiv; auto. + Qed. + + Lemma scalarmult_same : forall c x y, eq x y -> eq (scalarmult c x) (scalarmult c y). + Proof. + induction c; intros. + + reflexivity. + + simpl. f_equiv; auto. + Qed. + + Lemma scalarmult_pow_add : forall a b x, scalarmult (a + b) x = op (scalarmult a x) (scalarmult b x). + Proof. + intros; eapply scalarmult_add_l. + Grab Existential Variables. + 2:eauto. + econstructor; try reflexivity. + repeat intro; subst. + auto using scalarmult_same. + Qed. + + Lemma add_chain_exp_spec : forall is x, + (forall n, nth_default x (add_chain_exp is x) n = scalarmult (nth_default 1 (add_chain is) n) x). + Proof. + induction is; intros. + + simpl; rewrite !nth_default_nil. cbv. + symmetry; apply right_identity. + + destruct a. + rewrite add_chain_exp_step by auto. + unfold add_chain; fold add_chain. + destruct n. + - rewrite !nth_default_cons, scalarmult_pow_add. reflexivity. + - rewrite !nth_default_cons_S; auto. + Qed. + + Lemma add_chain_exp_answer : forall is x n, Logic.eq (head (add_chain is)) (Some n) -> + option_eq eq (Some (scalarmult n x)) (head (add_chain_exp is x)). + Proof. + intros. + change head with (fun {T} (xs : list T) => nth_error xs 0) in *. + cbv beta in *. + cbv [option_eq]. + destruct is; [ discriminate | ]. + destruct p. + simpl in *. + injection H; clear H; intro H. + subst n. + rewrite !add_chain_exp_spec. + apply scalarmult_pow_add. + Qed. + +End AddChainExp.
\ No newline at end of file diff --git a/src/Util/ListUtil.v b/src/Util/ListUtil.v index 0426c0834..169564c23 100644 --- a/src/Util/ListUtil.v +++ b/src/Util/ListUtil.v @@ -1,22 +1,112 @@ Require Import Coq.Lists.List. Require Import Coq.omega.Omega. Require Import Coq.Arith.Peano_dec. +Require Import Coq.Classes.Morphisms. Require Import Crypto.Tactics.VerdiTactics. +Require Import Coq.Numbers.Natural.Peano.NPeano. +Require Import Crypto.Util.NatUtil. + +Create HintDb distr_length discriminated. +Create HintDb simpl_set_nth discriminated. +Create HintDb simpl_update_nth discriminated. +Create HintDb simpl_nth_default discriminated. +Create HintDb simpl_nth_error discriminated. +Create HintDb simpl_firstn discriminated. +Create HintDb simpl_skipn discriminated. +Create HintDb simpl_fold_right discriminated. +Create HintDb simpl_sum_firstn discriminated. +Create HintDb pull_nth_error discriminated. +Create HintDb push_nth_error discriminated. +Create HintDb pull_nth_default discriminated. +Create HintDb push_nth_default discriminated. +Create HintDb pull_firstn discriminated. +Create HintDb push_firstn discriminated. +Create HintDb pull_update_nth discriminated. +Create HintDb push_update_nth discriminated. + +Hint Rewrite + @app_length + @rev_length + @map_length + @seq_length + @fold_left_length + @split_length_l + @split_length_r + @firstn_length + @combine_length + @prod_length + : distr_length. + +Definition sum_firstn l n := fold_right Z.add 0%Z (firstn n l). + +Fixpoint map2 {A B C} (f : A -> B -> C) (la : list A) (lb : list B) : list C := + match la with + | nil => nil + | a :: la' => match lb with + | nil => nil + | b :: lb' => f a b :: map2 f la' lb' + end + end. + +(* xs[n] := f xs[n] *) +Fixpoint update_nth {T} n f (xs:list T) {struct n} := + match n with + | O => match xs with + | nil => nil + | x'::xs' => f x'::xs' + end + | S n' => match xs with + | nil => nil + | x'::xs' => x'::update_nth n' f xs' + end + end. + +(* xs[n] := x *) +Definition set_nth {T} n x (xs:list T) + := update_nth n (fun _ => x) xs. + +Definition splice_nth {T} n (x:T) xs := firstn n xs ++ x :: skipn (S n) xs. +Hint Unfold splice_nth. Ltac boring := simpl; intuition; repeat match goal with | [ H : _ |- _ ] => rewrite H; clear H | _ => progress autounfold in * - | _ => progress try autorewrite with core + | _ => progress autorewrite with core | _ => progress simpl in * | _ => progress intuition end; eauto. +Ltac boring_list := + repeat match goal with + | _ => progress boring + | _ => progress autorewrite with distr_length simpl_nth_default simpl_update_nth simpl_set_nth simpl_nth_error in * + end. + +Lemma nth_default_cons : forall {T} (x u0 : T) us, nth_default x (u0 :: us) 0 = u0. +Proof. auto. Qed. + +Hint Rewrite @nth_default_cons : simpl_nth_default. +Hint Rewrite @nth_default_cons : push_nth_default. + +Lemma nth_default_cons_S : forall {A} us (u0 : A) n d, + nth_default d (u0 :: us) (S n) = nth_default d us n. +Proof. boring. Qed. + +Hint Rewrite @nth_default_cons_S : simpl_nth_default. +Hint Rewrite @nth_default_cons_S : push_nth_default. + +Lemma nth_default_nil : forall {T} n (d : T), nth_default d nil n = d. +Proof. induction n; boring. Qed. + +Hint Rewrite @nth_default_nil : simpl_nth_default. +Hint Rewrite @nth_default_nil : push_nth_default. + Lemma nth_error_nil_error : forall {A} n, nth_error (@nil A) n = None. -Proof. -intros. induction n; boring. -Qed. +Proof. induction n; boring. Qed. + +Hint Rewrite @nth_error_nil_error : simpl_nth_error. Ltac nth_tac' := intros; simpl in *; unfold error,value in *; repeat progress (match goal with @@ -69,6 +159,7 @@ Proof. induction i; destruct xs; nth_tac'; rewrite IHi by omega; auto. Qed. Hint Resolve nth_error_length_error. +Hint Rewrite @nth_error_length_error using omega : simpl_nth_error. Lemma map_nth_default : forall (A B : Type) (f : A -> B) n x y l, (n < length l) -> nth_default y (map f l) n = f (nth_default x l n). @@ -82,6 +173,8 @@ Proof. omega. Qed. +Hint Rewrite @map_nth_default using omega : push_nth_default. + Ltac nth_tac := repeat progress (try nth_tac'; try (match goal with | [ H: nth_error (map _ _) _ = Some _ |- _ ] => destruct (nth_error_map _ _ _ _ _ _ H); clear H @@ -90,46 +183,139 @@ Ltac nth_tac := end)). Lemma app_cons_app_app : forall T xs (y:T) ys, xs ++ y :: ys = (xs ++ (y::nil)) ++ ys. +Proof. induction xs; boring. Qed. + +Lemma unfold_set_nth {T} n x + : forall xs, + @set_nth T n x xs + = match n with + | O => match xs with + | nil => nil + | x'::xs' => x::xs' + end + | S n' => match xs with + | nil => nil + | x'::xs' => x'::set_nth n' x xs' + end + end. Proof. - induction xs; boring. + induction n; destruct xs; reflexivity. Qed. -(* xs[n] := x *) -Fixpoint set_nth {T} n x (xs:list T) {struct n} := - match n with - | O => match xs with - | nil => nil - | x'::xs' => x::xs' - end - | S n' => match xs with - | nil => nil - | x'::xs' => x'::set_nth n' x xs' - end - end. +Lemma simpl_set_nth_0 {T} x + : forall xs, + @set_nth T 0 x xs + = match xs with + | nil => nil + | x'::xs' => x::xs' + end. +Proof. intro; rewrite unfold_set_nth; reflexivity. Qed. -Lemma nth_set_nth : forall m {T} (xs:list T) (n:nat) (x x':T), - nth_error (set_nth m x xs) n = +Lemma simpl_set_nth_S {T} x n + : forall xs, + @set_nth T (S n) x xs + = match xs with + | nil => nil + | x'::xs' => x'::set_nth n x xs' + end. +Proof. intro; rewrite unfold_set_nth; reflexivity. Qed. + +Hint Rewrite @simpl_set_nth_S @simpl_set_nth_0 : simpl_set_nth. + +Lemma update_nth_ext {T} f g n + : forall xs, (forall x, nth_error xs n = Some x -> f x = g x) + -> @update_nth T n f xs = @update_nth T n g xs. +Proof. + induction n; destruct xs; simpl; intros H; + try rewrite IHn; try rewrite H; + try congruence; trivial. +Qed. + +Global Instance update_nth_Proper {T} + : Proper (eq ==> pointwise_relation _ eq ==> eq ==> eq) (@update_nth T). +Proof. repeat intro; subst; apply update_nth_ext; trivial. Qed. + +Lemma update_nth_id_eq_specific {T} f n + : forall (xs : list T) (H : forall x, nth_error xs n = Some x -> f x = x), + update_nth n f xs = xs. +Proof. + induction n; destruct xs; simpl; intros; + try rewrite IHn; try rewrite H; unfold value in *; + try congruence; assumption. +Qed. + +Hint Rewrite @update_nth_id_eq_specific using congruence : simpl_update_nth. + +Lemma update_nth_id_eq : forall {T} f (H : forall x, f x = x) n (xs : list T), + update_nth n f xs = xs. +Proof. intros; apply update_nth_id_eq_specific; trivial. Qed. + +Hint Rewrite @update_nth_id_eq using congruence : simpl_update_nth. + +Lemma update_nth_id : forall {T} n (xs : list T), + update_nth n (fun x => x) xs = xs. +Proof. intros; apply update_nth_id_eq; trivial. Qed. + +Hint Rewrite @update_nth_id : simpl_update_nth. + +Lemma nth_update_nth : forall m {T} (xs:list T) (n:nat) (f:T -> T), + nth_error (update_nth m f xs) n = if eq_nat_dec n m - then (if lt_dec n (length xs) then Some x else None) + then option_map f (nth_error xs n) else nth_error xs n. Proof. - induction m. + induction m. + { destruct n, xs; auto. } + { destruct xs, n; intros; simpl; auto; + [ | rewrite IHm ]; clear IHm; + edestruct eq_nat_dec; reflexivity. } +Qed. - destruct n, xs; auto. +Hint Rewrite @nth_update_nth : push_nth_error. +Hint Rewrite <- @nth_update_nth : pull_nth_error. - intros; destruct xs, n; auto. - simpl; unfold error; match goal with - [ |- None = if ?x then None else None ] => destruct x - end; auto. +Lemma length_update_nth : forall {T} i f (xs:list T), length (update_nth i f xs) = length xs. +Proof. + induction i, xs; boring. +Qed. - simpl nth_error; erewrite IHm by auto; clear IHm. - destruct (eq_nat_dec n m), (eq_nat_dec (S n) (S m)); nth_tac. +Hint Rewrite @length_update_nth : distr_length. + +(** TODO: this is in the stdlib in 8.5; remove this when we move to 8.5-only *) +Lemma nth_error_None : forall (A : Type) (l : list A) (n : nat), nth_error l n = None <-> length l <= n. +Proof. + intros A l n. + destruct (le_lt_dec (length l) n) as [H|H]; + split; intro H'; + try omega; + try (apply nth_error_length_error in H; tauto); + try (apply nth_error_error_length in H'; omega). Qed. -Lemma length_set_nth : forall {T} i (x:T) xs, length (set_nth i x xs) = length xs. - induction i, xs; boring. +(** TODO: this is in the stdlib in 8.5; remove this when we move to 8.5-only *) +Lemma nth_error_Some : forall (A : Type) (l : list A) (n : nat), nth_error l n <> None <-> n < length l. +Proof. intros; rewrite nth_error_None; split; omega. Qed. + +Lemma nth_set_nth : forall m {T} (xs:list T) (n:nat) x, + nth_error (set_nth m x xs) n = + if eq_nat_dec n m + then (if lt_dec n (length xs) then Some x else None) + else nth_error xs n. +Proof. + intros; unfold set_nth; rewrite nth_update_nth. + destruct (nth_error xs n) eqn:?, (lt_dec n (length xs)) as [p|p]; + rewrite <- nth_error_Some in p; + solve [ reflexivity + | exfalso; apply p; congruence ]. Qed. +Hint Rewrite @nth_set_nth : push_nth_error. + +Lemma length_set_nth : forall {T} i x (xs:list T), length (set_nth i x xs) = length xs. +Proof. intros; apply length_update_nth. Qed. + +Hint Rewrite @length_set_nth : distr_length. + Lemma nth_error_length_exists_value : forall {A} (i : nat) (xs : list A), (i < length xs)%nat -> exists x, nth_error xs i = Some x. Proof. @@ -143,52 +329,88 @@ Proof. destruct (nth_error_length_exists_value i xs); intuition; congruence. Qed. -Lemma nth_error_value_eq_nth_default : forall {T} i xs (x d:T), +Lemma nth_error_value_eq_nth_default : forall {T} i (x : T) xs, nth_error xs i = Some x -> forall d, nth_default d xs i = x. Proof. unfold nth_default; boring. Qed. +Hint Rewrite @nth_error_value_eq_nth_default using eassumption : simpl_nth_default. + Lemma skipn0 : forall {T} (xs:list T), skipn 0 xs = xs. +Proof. auto. Qed. + +Lemma firstn0 : forall {T} (xs:list T), firstn 0 xs = nil. +Proof. auto. Qed. + +Lemma splice_nth_equiv_update_nth : forall {T} n f d (xs:list T), + splice_nth n (f (nth_default d xs n)) xs = + if lt_dec n (length xs) + then update_nth n f xs + else xs ++ (f d)::nil. Proof. - auto. + induction n, xs; boring_list. + do 2 break_if; auto; omega. Qed. -Lemma firstn0 : forall {T} (xs:list T), firstn 0 xs = nil. +Lemma splice_nth_equiv_update_nth_update : forall {T} n f d (xs:list T), + n < length xs -> + splice_nth n (f (nth_default d xs n)) xs = update_nth n f xs. Proof. - auto. + intros. + rewrite splice_nth_equiv_update_nth. + break_if; auto; omega. Qed. -Definition splice_nth {T} n (x:T) xs := firstn n xs ++ x :: skipn (S n) xs. -Hint Unfold splice_nth. +Lemma splice_nth_equiv_update_nth_snoc : forall {T} n f d (xs:list T), + n >= length xs -> + splice_nth n (f (nth_default d xs n)) xs = xs ++ (f d)::nil. +Proof. + intros. + rewrite splice_nth_equiv_update_nth. + break_if; auto; omega. +Qed. + +Definition IMPOSSIBLE {T} : list T. exact nil. Qed. + +Ltac remove_nth_error := + repeat match goal with + | _ => exfalso; solve [ eauto using @nth_error_length_not_error ] + | [ |- context[match nth_error ?ls ?n with _ => _ end] ] + => destruct (nth_error ls n) eqn:? + end. + +Lemma update_nth_equiv_splice_nth: forall {T} n f (xs:list T), + update_nth n f xs = + if lt_dec n (length xs) + then match nth_error xs n with + | Some v => splice_nth n (f v) xs + | None => IMPOSSIBLE + end + else xs. +Proof. + induction n; destruct xs; intros; + autorewrite with simpl_update_nth simpl_nth_default in *; simpl in *; + try (erewrite IHn; clear IHn); auto. + repeat break_match; remove_nth_error; try reflexivity; try omega. +Qed. Lemma splice_nth_equiv_set_nth : forall {T} n x (xs:list T), splice_nth n x xs = if lt_dec n (length xs) then set_nth n x xs else xs ++ x::nil. -Proof. - induction n, xs; boring. - break_if; break_if; auto; omega. -Qed. +Proof. intros; rewrite splice_nth_equiv_update_nth with (f := fun _ => x); auto. Qed. Lemma splice_nth_equiv_set_nth_set : forall {T} n x (xs:list T), n < length xs -> splice_nth n x xs = set_nth n x xs. -Proof. - intros. - rewrite splice_nth_equiv_set_nth. - break_if; auto; omega. -Qed. +Proof. intros; rewrite splice_nth_equiv_update_nth_update with (f := fun _ => x); auto. Qed. Lemma splice_nth_equiv_set_nth_snoc : forall {T} n x (xs:list T), n >= length xs -> splice_nth n x xs = xs ++ x::nil. -Proof. - intros. - rewrite splice_nth_equiv_set_nth. - break_if; auto; omega. -Qed. +Proof. intros; rewrite splice_nth_equiv_update_nth_snoc with (f := fun _ => x); auto. Qed. Lemma set_nth_equiv_splice_nth: forall {T} n x (xs:list T), set_nth n x xs = @@ -196,11 +418,48 @@ Lemma set_nth_equiv_splice_nth: forall {T} n x (xs:list T), then splice_nth n x xs else xs. Proof. - induction n; destruct xs; intros; simpl in *; - try (rewrite IHn; clear IHn); auto. - break_if; break_if; auto; omega. + intros; unfold set_nth; rewrite update_nth_equiv_splice_nth with (f := fun _ => x); auto. + repeat break_match; remove_nth_error; trivial. Qed. +Lemma combine_update_nth : forall {A B} n f g (xs:list A) (ys:list B), + combine (update_nth n f xs) (update_nth n g ys) = + update_nth n (fun xy => (f (fst xy), g (snd xy))) (combine xs ys). +Proof. + induction n; destruct xs, ys; simpl; try rewrite IHn; reflexivity. +Qed. + +(* grumble, grumble, [rewrite] is bad at inferring the identity function, and constant functions *) +Ltac rewrite_rev_combine_update_nth := + let lem := match goal with + | [ |- appcontext[update_nth ?n (fun xy => (@?f xy, @?g xy)) (combine ?xs ?ys)] ] + => let f := match (eval cbv [fst] in (fun y x => f (x, y))) with + | fun _ => ?f => f + end in + let g := match (eval cbv [snd] in (fun x y => g (x, y))) with + | fun _ => ?g => g + end in + constr:(@combine_update_nth _ _ n f g xs ys) + end in + rewrite <- lem. + +Lemma combine_update_nth_l : forall {A B} n (f : A -> A) xs (ys:list B), + combine (update_nth n f xs) ys = + update_nth n (fun xy => (f (fst xy), snd xy)) (combine xs ys). +Proof. + intros ??? f xs ys. + etransitivity; [ | apply combine_update_nth with (g := fun x => x) ]. + rewrite update_nth_id; reflexivity. +Qed. + +Lemma combine_update_nth_r : forall {A B} n (g : B -> B) (xs:list A) (ys:list B), + combine xs (update_nth n g ys) = + update_nth n (fun xy => (fst xy, g (snd xy))) (combine xs ys). +Proof. + intros ??? g xs ys. + etransitivity; [ | apply combine_update_nth with (f := fun x => x) ]. + rewrite update_nth_id; reflexivity. +Qed. Lemma combine_set_nth : forall {A B} n (x:A) xs (ys:list B), combine (set_nth n x xs) ys = @@ -209,12 +468,12 @@ Lemma combine_set_nth : forall {A B} n (x:A) xs (ys:list B), | Some y => set_nth n (x,y) (combine xs ys) end. Proof. - (* TODO(andreser): this proof can totally be automated, but requires writing ltac that vets multiple hypotheses at once *) - induction n, xs, ys; nth_tac; try rewrite IHn; nth_tac; - try (f_equal; specialize (IHn x xs ys ); rewrite H in IHn; rewrite <- IHn); - try (specialize (nth_error_value_length _ _ _ _ H); omega). - assert (Some b0=Some b1) as HA by (rewrite <-H, <-H0; auto). - injection HA; intros; subst; auto. + intros; unfold set_nth; rewrite combine_update_nth_l. + nth_tac; + [ repeat rewrite_rev_combine_update_nth; apply f_equal2 + | assert (nth_error (combine xs ys) n = None) + by (apply nth_error_None; rewrite combine_length; omega * ) ]; + autorewrite with simpl_update_nth; reflexivity. Qed. Lemma nth_error_value_In : forall {T} n xs (x:T), @@ -258,6 +517,8 @@ Proof. destruct (lt_dec n (length xs)); auto. Qed. +Hint Rewrite @nth_default_app : push_nth_default. + Lemma combine_truncate_r : forall {A B} (xs : list A) (ys : list B), combine xs ys = combine xs (firstn (length xs) ys). Proof. @@ -278,14 +539,34 @@ Proof. Qed. Lemma firstn_nil : forall {A} n, firstn n nil = @nil A. -Proof. - destruct n; auto. -Qed. +Proof. destruct n; auto. Qed. + +Hint Rewrite @firstn_nil : simpl_firstn. Lemma skipn_nil : forall {A} n, skipn n nil = @nil A. -Proof. - destruct n; auto. -Qed. +Proof. destruct n; auto. Qed. + +Hint Rewrite @skipn_nil : simpl_skipn. + +Lemma firstn_0 : forall {A} xs, @firstn A 0 xs = nil. +Proof. reflexivity. Qed. + +Hint Rewrite @firstn_0 : simpl_firstn. + +Lemma skipn_0 : forall {A} xs, @skipn A 0 xs = xs. +Proof. reflexivity. Qed. + +Hint Rewrite @skipn_0 : simpl_skipn. + +Lemma firstn_cons_S : forall {A} n x xs, @firstn A (S n) (x::xs) = x::@firstn A n xs. +Proof. reflexivity. Qed. + +Hint Rewrite @firstn_cons_S : simpl_firstn. + +Lemma skipn_cons_S : forall {A} n x xs, @skipn A (S n) (x::xs) = @skipn A n xs. +Proof. reflexivity. Qed. + +Hint Rewrite @skipn_cons_S : simpl_skipn. Lemma firstn_app : forall {A} n (xs ys : list A), firstn n (xs ++ ys) = firstn n xs ++ firstn (n - length xs) ys. @@ -351,15 +632,12 @@ Proof. reflexivity. Qed. +Hint Rewrite @fold_right_cons : simpl_fold_right. + Lemma length_cons : forall {T} (x:T) xs, length (x::xs) = S (length xs). reflexivity. Qed. -Lemma S_pred_nonzero : forall a, (a > 0 -> S (pred a) = a)%nat. -Proof. - destruct a; omega. -Qed. - Lemma cons_length : forall A (xs : list A) a, length (a :: xs) = S (length xs). Proof. auto. @@ -428,17 +706,6 @@ Proof. auto. Qed. -Lemma nth_default_cons : forall {T} (x u0 : T) us, nth_default x (u0 :: us) 0 = u0. -Proof. - auto. -Qed. - -Lemma nth_default_cons_S : forall {A} us (u0 : A) n d, - nth_default d (u0 :: us) (S n) = nth_default d us n. -Proof. - boring. -Qed. - Lemma nth_error_Some_nth_default : forall {T} i x (l : list T), (i < length l)%nat -> nth_error l i = Some (nth_default x l i). Proof. @@ -449,47 +716,52 @@ Proof. reflexivity. Qed. +Lemma update_nth_cons : forall {T} f (u0 : T) us, update_nth 0 f (u0 :: us) = (f u0) :: us. +Proof. reflexivity. Qed. + +Hint Rewrite @update_nth_cons : simpl_update_nth. + Lemma set_nth_cons : forall {T} (x u0 : T) us, set_nth 0 x (u0 :: us) = x :: us. -Proof. - auto. -Qed. +Proof. intros; apply update_nth_cons. Qed. + +Hint Rewrite @set_nth_cons : simpl_set_nth. -Create HintDb distr_length discriminated. Hint Rewrite @nil_length0 @length_cons - @app_length - @rev_length - @map_length - @seq_length - @fold_left_length - @split_length_l - @split_length_r - @firstn_length @skipn_length - @combine_length - @prod_length + @length_update_nth @length_set_nth : distr_length. Ltac distr_length := autorewrite with distr_length in *; try solve [simpl in *; omega]. -Lemma cons_set_nth : forall {T} n (x y : T) us, - y :: set_nth n x us = set_nth (S n) x (y :: us). +Lemma cons_update_nth : forall {T} n f (y : T) us, + y :: update_nth n f us = update_nth (S n) f (y :: us). Proof. induction n; boring. Qed. -Lemma set_nth_nil : forall {T} n (x : T), set_nth n x nil = nil. -Proof. - induction n; boring. -Qed. +Hint Rewrite <- @cons_update_nth : simpl_update_nth. -Lemma nth_default_nil : forall {T} n (d : T), nth_default d nil n = d. +Lemma update_nth_nil : forall {T} n f, update_nth n f (@nil T) = @nil T. Proof. induction n; boring. Qed. +Hint Rewrite @update_nth_nil : simpl_update_nth. + +Lemma cons_set_nth : forall {T} n (x y : T) us, + y :: set_nth n x us = set_nth (S n) x (y :: us). +Proof. intros; apply cons_update_nth. Qed. + +Hint Rewrite <- @cons_set_nth : simpl_set_nth. + +Lemma set_nth_nil : forall {T} n (x : T), set_nth n x nil = nil. +Proof. intros; apply update_nth_nil. Qed. + +Hint Rewrite @set_nth_nil : simpl_set_nth. + Lemma skipn_nth_default : forall {T} n us (d : T), (n < length us)%nat -> skipn n us = nth_default d us n :: skipn (S n) us. Proof. @@ -508,6 +780,8 @@ Proof. congruence. Qed. +Hint Rewrite @nth_default_out_of_bounds using omega : simpl_nth_default. + Ltac nth_error_inbounds := match goal with | [ |- context[match nth_error ?xs ?i with Some _ => _ | None => _ end ] ] => @@ -535,11 +809,20 @@ Ltac set_nth_inbounds := match goal with | [ H : ~ (i < (length xs))%nat |- _ ] => destruct H | [ H : (i < (length xs))%nat |- _ ] => try solve [distr_length] - end; - idtac + end + end. +Ltac update_nth_inbounds := + match goal with + | [ |- context[update_nth ?i ?f ?xs] ] => + rewrite (update_nth_equiv_splice_nth i f xs); + destruct (lt_dec i (length xs)); + match goal with + | [ H : ~ (i < (length xs))%nat |- _ ] => destruct H + | [ H : (i < (length xs))%nat |- _ ] => remove_nth_error; try solve [distr_length] + end end. -Ltac nth_inbounds := nth_error_inbounds || set_nth_inbounds. +Ltac nth_inbounds := nth_error_inbounds || set_nth_inbounds || update_nth_inbounds. Lemma cons_eq_head : forall {T} (x y:T) xs ys, x::xs = y::ys -> x=y. Proof. @@ -557,6 +840,8 @@ Proof. nth_tac. Qed. +Hint Rewrite @map_nth_default_always : push_nth_default. + Lemma fold_right_and_True_forall_In_iff : forall {T} (l : list T) (P : T -> Prop), (forall x, In x l -> P x) <-> fold_right and True (map P l). Proof. @@ -617,11 +902,236 @@ Proof. omega. Qed. +Lemma update_nth_out_of_bounds : forall {A} n f xs, n >= length xs -> @update_nth A n f xs = xs. +Proof. + induction n; destruct xs; simpl; try congruence; try omega; intros. + rewrite IHn by omega; reflexivity. +Qed. + +Hint Rewrite @update_nth_out_of_bounds using omega : simpl_update_nth. + + +Lemma update_nth_nth_default_full : forall {A} (d:A) n f l i, + nth_default d (update_nth n f l) i = + if lt_dec i (length l) then + if (eq_nat_dec i n) then f (nth_default d l i) + else nth_default d l i + else d. +Proof. + induction n; (destruct l; simpl in *; [ intros; destruct i; simpl; try reflexivity; omega | ]); + intros; repeat break_if; subst; try destruct i; + repeat first [ progress break_if + | progress subst + | progress boring + | progress autorewrite with simpl_nth_default + | omega ]. +Qed. + +Hint Rewrite @update_nth_nth_default_full : push_nth_default. + +Lemma update_nth_nth_default : forall {A} (d:A) n f l i, (0 <= i < length l)%nat -> + nth_default d (update_nth n f l) i = + if (eq_nat_dec i n) then f (nth_default d l i) else nth_default d l i. +Proof. intros; rewrite update_nth_nth_default_full; repeat break_if; boring. Qed. + +Hint Rewrite @update_nth_nth_default using (omega || distr_length; omega) : push_nth_default. + +Lemma set_nth_nth_default_full : forall {A} (d:A) n v l i, + nth_default d (set_nth n v l) i = + if lt_dec i (length l) then + if (eq_nat_dec i n) then v + else nth_default d l i + else d. +Proof. intros; apply update_nth_nth_default_full; assumption. Qed. + +Hint Rewrite @set_nth_nth_default_full : push_nth_default. + Lemma set_nth_nth_default : forall {A} (d:A) n x l i, (0 <= i < length l)%nat -> nth_default d (set_nth n x l) i = if (eq_nat_dec i n) then x else nth_default d l i. +Proof. intros; apply update_nth_nth_default; assumption. Qed. + +Hint Rewrite @set_nth_nth_default using (omega || distr_length; omega) : push_nth_default. + +Lemma nth_default_preserves_properties : forall {A} (P : A -> Prop) l n d, + (forall x, In x l -> P x) -> P d -> P (nth_default d l n). +Proof. + intros; rewrite nth_default_eq. + destruct (nth_in_or_default n l d); auto. + congruence. +Qed. + +Lemma nth_error_first : forall {T} (a b : T) l, + nth_error (a :: l) 0 = Some b -> a = b. +Proof. + intros; simpl in *. + unfold value in *. + congruence. +Qed. + +Lemma nth_error_exists_first : forall {T} l (x : T) (H : nth_error l 0 = Some x), + exists l', l = x :: l'. +Proof. + induction l; try discriminate; eexists. + apply nth_error_first in H. + subst; eauto. +Qed. + +Lemma list_elementwise_eq : forall {T} (l1 l2 : list T), + (forall i, nth_error l1 i = nth_error l2 i) -> l1 = l2. +Proof. + induction l1, l2; intros; try reflexivity; + pose proof (H 0%nat) as Hfirst; simpl in Hfirst; inversion Hfirst. + f_equal. + apply IHl1. + intros i; specialize (H (S i)). + boring. +Qed. + +Lemma sum_firstn_all_succ : forall n l, (length l <= n)%nat -> + sum_firstn l (S n) = sum_firstn l n. +Proof. + unfold sum_firstn; intros. + rewrite !firstn_all_strong by omega. + congruence. +Qed. + +Hint Rewrite @sum_firstn_all_succ using omega : simpl_sum_firstn. + +Lemma sum_firstn_succ_default : forall l i, + sum_firstn l (S i) = (nth_default 0 l i + sum_firstn l i)%Z. +Proof. + unfold sum_firstn; induction l, i; + intros; autorewrite with simpl_nth_default simpl_firstn simpl_fold_right in *; + try reflexivity. + rewrite IHl; omega. +Qed. + +Hint Rewrite @sum_firstn_succ_default : simpl_sum_firstn. + +Lemma sum_firstn_0 : forall xs, + sum_firstn xs 0 = 0%Z. +Proof. + destruct xs; reflexivity. +Qed. + +Hint Rewrite @sum_firstn_0 : simpl_sum_firstn. + +Lemma sum_firstn_succ : forall l i x, + nth_error l i = Some x -> + sum_firstn l (S i) = (x + sum_firstn l i)%Z. +Proof. + intros; rewrite sum_firstn_succ_default. + erewrite nth_error_value_eq_nth_default by eassumption; reflexivity. +Qed. + +Hint Rewrite @sum_firstn_succ using congruence : simpl_sum_firstn. + +Lemma sum_firstn_succ_default_rev : forall l i, + sum_firstn l i = (sum_firstn l (S i) - nth_default 0 l i)%Z. +Proof. + intros; rewrite sum_firstn_succ_default; omega. +Qed. + +Lemma sum_firstn_succ_rev : forall l i x, + nth_error l i = Some x -> + sum_firstn l i = (sum_firstn l (S i) - x)%Z. +Proof. + intros; erewrite sum_firstn_succ by eassumption; omega. +Qed. + +Lemma nth_default_map2 : forall {A B C} (f : A -> B -> C) ls1 ls2 i d d1 d2, + nth_default d (map2 f ls1 ls2) i = + if lt_dec i (min (length ls1) (length ls2)) + then f (nth_default d1 ls1 i) (nth_default d2 ls2 i) + else d. +Proof. + induction ls1, ls2. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + cbv [map2 length min]. + intros. + break_if; try omega. + apply nth_default_nil. + + simpl. + destruct i. + - intros. rewrite !nth_default_cons. + break_if; auto; omega. + - intros. rewrite !nth_default_cons_S. + rewrite IHls1 with (d1 := d1) (d2 := d2). + repeat break_if; auto; omega. +Qed. + +Lemma map2_cons : forall A B C (f : A -> B -> C) ls1 ls2 a b, + map2 f (a :: ls1) (b :: ls2) = f a b :: map2 f ls1 ls2. +Proof. + reflexivity. +Qed. + +Lemma map2_nil_l : forall A B C (f : A -> B -> C) ls2, + map2 f nil ls2 = nil. +Proof. + reflexivity. +Qed. + +Lemma map2_nil_r : forall A B C (f : A -> B -> C) ls1, + map2 f ls1 nil = nil. +Proof. + destruct ls1; reflexivity. +Qed. +Local Hint Resolve map2_nil_r map2_nil_l. + +Opaque map2. + +Lemma map2_length : forall A B C (f : A -> B -> C) ls1 ls2, + length (map2 f ls1 ls2) = min (length ls1) (length ls2). +Proof. + induction ls1, ls2; intros; try solve [cbv; auto]. + rewrite map2_cons, !length_cons, IHls1. + auto. +Qed. + +Ltac simpl_list_lengths := repeat match goal with + | H : appcontext[length (@nil ?A)] |- _ => rewrite (@nil_length0 A) in H + | H : appcontext[length (_ :: _)] |- _ => rewrite length_cons in H + | |- appcontext[length (@nil ?A)] => rewrite (@nil_length0 A) + | |- appcontext[length (_ :: _)] => rewrite length_cons + end. + +Lemma map2_app : forall A B C (f : A -> B -> C) ls1 ls2 ls1' ls2', + (length ls1 = length ls2) -> + map2 f (ls1 ++ ls1') (ls2 ++ ls2') = map2 f ls1 ls2 ++ map2 f ls1' ls2'. +Proof. + induction ls1, ls2; intros; rewrite ?map2_nil_r, ?app_nil_l; try congruence; + simpl_list_lengths; try omega. + rewrite <-!app_comm_cons, !map2_cons. + rewrite IHls1; auto. +Qed. + +Lemma firstn_update_nth {A} + : forall f m n (xs : list A), firstn m (update_nth n f xs) = update_nth n f (firstn m xs). +Proof. + induction m; destruct n, xs; + autorewrite with simpl_firstn simpl_update_nth; + congruence. +Qed. + +Hint Rewrite @firstn_update_nth : push_firstn. +Hint Rewrite @firstn_update_nth : pull_update_nth. +Hint Rewrite <- @firstn_update_nth : pull_firstn. +Hint Rewrite <- @firstn_update_nth : push_update_nth. + +Require Import Coq.Lists.SetoidList. +Global Instance Proper_nth_default : forall A eq, + Proper (eq==>eqlistA eq==>Logic.eq==>eq) (nth_default (A:=A)). Proof. - induction n; (destruct l; [intros; simpl in *; omega | ]); simpl; - destruct i; break_if; try omega; intros; try apply nth_default_cons; - rewrite !nth_default_cons_S, ?IHn; try break_if; omega || reflexivity. + do 5 intro; subst; induction 1. + + repeat intro; rewrite !nth_default_nil; assumption. + + repeat intro; subst; destruct y0; rewrite ?nth_default_cons, ?nth_default_cons_S; auto. Qed.
\ No newline at end of file diff --git a/src/Util/NatUtil.v b/src/Util/NatUtil.v index 0cdfd784f..83375f99a 100644 --- a/src/Util/NatUtil.v +++ b/src/Util/NatUtil.v @@ -1,9 +1,60 @@ +Require Coq.Logic.Eqdep_dec. Require Import Coq.Numbers.Natural.Peano.NPeano Coq.omega.Omega. Require Import Coq.micromega.Psatz. Import Nat. +Create HintDb natsimplify discriminated. + +Hint Resolve mod_bound_pos : arith. +Hint Resolve (fun x y p q => proj1 (@Nat.mod_bound_pos x y p q)) (fun x y p q => proj2 (@Nat.mod_bound_pos x y p q)) : arith. + +Hint Rewrite @mod_small @mod_mod @mod_1_l @mod_1_r succ_pred using omega : natsimplify. + Local Open Scope nat_scope. +Lemma min_def {x y} : min x y = x - (x - y). +Proof. apply Min.min_case_strong; omega. Qed. +Lemma max_def {x y} : max x y = x + (y - x). +Proof. apply Max.max_case_strong; omega. Qed. +Ltac coq_omega := omega. +Ltac handle_min_max_for_omega_gen min max := + repeat match goal with + | [ H : context[min _ _] |- _ ] => rewrite !min_def in H || setoid_rewrite min_def in H + | [ H : context[max _ _] |- _ ] => rewrite !max_def in H || setoid_rewrite max_def in H + | [ |- context[min _ _] ] => rewrite !min_def || setoid_rewrite min_def + | [ |- context[max _ _] ] => rewrite !max_def || setoid_rewrite max_def + end. +Ltac handle_min_max_for_omega_case_gen min max := + repeat match goal with + | [ H : context[min _ _] |- _ ] => revert H + | [ H : context[max _ _] |- _ ] => revert H + | [ |- context[min _ _] ] => apply Min.min_case_strong + | [ |- context[max _ _] ] => apply Max.max_case_strong + end; + intros. +Ltac handle_min_max_for_omega := handle_min_max_for_omega_gen min max. +Ltac handle_min_max_for_omega_case := handle_min_max_for_omega_case_gen min max. +(* In 8.4, Nat.min is a definition, so we need to unfold it *) +Ltac handle_min_max_for_omega_compat_84 := + let min := (eval cbv [min] in min) in + let max := (eval cbv [max] in max) in + handle_min_max_for_omega_gen min max. +Ltac handle_min_max_for_omega_case_compat_84 := + let min := (eval cbv [min] in min) in + let max := (eval cbv [max] in max) in + handle_min_max_for_omega_case_gen min max. +Ltac omega_with_min_max := + handle_min_max_for_omega; + try handle_min_max_for_omega_compat_84; + omega. +Ltac omega_with_min_max_case := + handle_min_max_for_omega_case; + try handle_min_max_for_omega_case_compat_84; + omega. +Tactic Notation "omega" := coq_omega. +Tactic Notation "omega" "*" := omega_with_min_max_case. +Tactic Notation "omega" "**" := omega_with_min_max. + Lemma div_minus : forall a b, b <> 0 -> (a + b) / b = a / b + 1. Proof. intros. @@ -86,3 +137,53 @@ Proof. Qed. Hint Resolve pow_nonzero : arith. + +Lemma S_pred_nonzero : forall a, (a > 0 -> S (pred a) = a)%nat. +Proof. + destruct a; simpl; omega. +Qed. + +Hint Rewrite S_pred_nonzero using omega : natsimplify. + +Lemma mod_same_eq a b : a <> 0 -> a = b -> b mod a = 0. +Proof. intros; subst; apply mod_same; assumption. Qed. + +Hint Rewrite @mod_same_eq using omega : natsimplify. +Hint Resolve mod_same_eq : arith. + +Lemma mod_mod_eq a b c : a <> 0 -> b = c mod a -> b mod a = b. +Proof. intros; subst; autorewrite with natsimplify; reflexivity. Qed. + +Hint Rewrite @mod_mod_eq using (reflexivity || omega) : natsimplify. + +Local Arguments minus !_ !_. + +Lemma S_mod_full a b : a <> 0 -> (S b) mod a = if eq_nat_dec (S (b mod a)) a + then 0 + else S (b mod a). +Proof. + change (S b) with (1+b); intros. + pose proof (mod_bound_pos b a). + rewrite add_mod by assumption. + destruct (eq_nat_dec (S (b mod a)) a) as [H'|H']; + destruct a as [|[|a]]; autorewrite with natsimplify in *; + try congruence; try reflexivity. +Qed. + +Hint Rewrite S_mod_full using omega : natsimplify. + +Lemma S_mod a b : a <> 0 -> S (b mod a) <> a -> (S b) mod a = S (b mod a). +Proof. + intros; rewrite S_mod_full by assumption. + edestruct eq_nat_dec; omega. +Qed. + +Hint Rewrite S_mod using (omega || autorewrite with natsimplify; omega) : natsimplify. + +Lemma eq_nat_dec_refl x : eq_nat_dec x x = left (Logic.eq_refl x). +Proof. + edestruct eq_nat_dec; try congruence. + apply f_equal, Eqdep_dec.UIP_dec, eq_nat_dec. +Qed. + +Hint Rewrite eq_nat_dec_refl : natsimplify. diff --git a/src/Util/Notations.v b/src/Util/Notations.v index c3f776766..4526e6dce 100644 --- a/src/Util/Notations.v +++ b/src/Util/Notations.v @@ -16,8 +16,8 @@ Reserved Notation "x ^ 2" (at level 30, format "x ^ 2"). Reserved Notation "x ^ 3" (at level 30, format "x ^ 3"). Reserved Infix "mod" (at level 40, no associativity). Reserved Notation "'canonical' 'encoding' 'of' T 'as' B" (at level 50). -Reserved Infix "<<" (at level 50). -Reserved Infix "&" (at level 50). -Reserved Infix "<<" (at level 50). +Reserved Infix "<<" (at level 30, no associativity). +Reserved Infix ">>" (at level 30, no associativity). Reserved Infix "&" (at level 50). +Reserved Infix "∣" (at level 50). Reserved Infix "~=" (at level 70). diff --git a/src/Util/NumTheoryUtil.v b/src/Util/NumTheoryUtil.v index 10ce148b0..c16b87639 100644 --- a/src/Util/NumTheoryUtil.v +++ b/src/Util/NumTheoryUtil.v @@ -66,7 +66,7 @@ Qed. Lemma p_odd : Z.odd p = true. Proof. - pose proof (prime_odd_or_2 p prime_p). + pose proof (Z.prime_odd_or_2 p prime_p). destruct H; auto. Qed. @@ -124,12 +124,12 @@ Proof. assert (b mod p <> 0) as b_nonzero. { intuition. rewrite <- Z.pow_2_r in a_square. - rewrite mod_exp_0 in a_square by prime_bound. + rewrite Z.mod_exp_0 in a_square by prime_bound. rewrite <- a_square in a_nonzero. auto. } pose proof (squared_fermat_little b b_nonzero). - rewrite mod_pow in * by prime_bound. + rewrite Z.mod_pow in * by prime_bound. rewrite <- a_square. rewrite Z.mod_mod; prime_bound. Qed. @@ -172,10 +172,10 @@ Proof. intros. destruct (exists_primitive_root_power) as [y [in_ZPGroup_y [y_order gpow_y]]]; auto. destruct (gpow_y a a_range) as [j [j_range pow_y_j]]; clear gpow_y. - rewrite mod_pow in pow_a_x by prime_bound. + rewrite Z.mod_pow in pow_a_x by prime_bound. replace a with (a mod p) in pow_y_j by (apply Z.mod_small; omega). rewrite <- pow_y_j in pow_a_x. - rewrite <- mod_pow in pow_a_x by prime_bound. + rewrite <- Z.mod_pow in pow_a_x by prime_bound. rewrite <- Z.pow_mul_r in pow_a_x by omega. assert (p - 1 | j * x) as divide_mul_j_x. { rewrite <- phi_is_order in y_order. @@ -193,13 +193,13 @@ Proof. rewrite <- Z_div_plus by omega. rewrite Z.mul_comm. rewrite x_id_inv in divide_mul_j_x; auto. - apply (divide_mul_div _ j 2) in divide_mul_j_x; + apply (Z.divide_mul_div _ j 2) in divide_mul_j_x; try (apply prime_pred_divide2 || prime_bound); auto. rewrite <- Zdivide_Zdiv_eq by (auto || omega). rewrite Zplus_diag_eq_mult_2. replace (a mod p) with a in pow_y_j by (symmetry; apply Z.mod_small; omega). rewrite Z_div_mult by omega; auto. - apply divide2_even_iff. + apply Z.divide2_even_iff. apply prime_pred_even. Qed. @@ -281,7 +281,7 @@ Lemma div2_p_1mod4 : forall (p : Z) (prime_p : prime p) (neq_p_2: p <> 2), (p / 2) * 2 + 1 = p. Proof. intros. - destruct (prime_odd_or_2 p prime_p); intuition. + destruct (Z.prime_odd_or_2 p prime_p); intuition. rewrite <- Zdiv2_div. pose proof (Zdiv2_odd_eqn p); break_if; congruence || omega. Qed. diff --git a/src/Util/Option.v b/src/Util/Option.v index db4b69dde..2c11771ff 100644 --- a/src/Util/Option.v +++ b/src/Util/Option.v @@ -60,3 +60,12 @@ Ltac simpl_option_rect := (* deal with [option_rect _ _ _ None] and [option_rect | [ |- context[option_rect ?P ?S ?N (Some ?x) ] ] => change (option_rect P S N (Some x)) with (S x); cbv beta end. + +Definition option_eq {A} eq (x y : option A) := + match x with + | None => y = None + | Some ax => match y with + | None => False + | Some ay => eq ax ay + end + end. diff --git a/src/Util/Tactics.v b/src/Util/Tactics.v index ab98bb7f2..83ec603a0 100644 --- a/src/Util/Tactics.v +++ b/src/Util/Tactics.v @@ -7,6 +7,9 @@ Tactic Notation "test" tactic3(tac) := (** [not tac] is equivalent to [fail tac "succeeds"] if [tac] succeeds, and is equivalent to [idtac] if [tac] fails *) Tactic Notation "not" tactic3(tac) := try ((test tac); fail 1 tac "succeeds"). +Ltac get_goal := + match goal with |- ?G => G end. + (** find the head of the given expression *) Ltac head expr := match expr with @@ -270,3 +273,30 @@ Ltac side_conditions_before_to_side_conditions_after tac_in H := here, after evars are instantiated, and not above. *) move H after H'; clear H' | .. ]. + +(** Do something with every hypothesis. *) +Ltac do_with_hyp' tac := + match goal with + | [ H : _ |- _ ] => tac H + end. + +(** Rewrite with any applicable hypothesis. *) +Tactic Notation "rewrite_hyp" "*" := do_with_hyp' ltac:(fun H => rewrite H). +Tactic Notation "rewrite_hyp" "->" "*" := do_with_hyp' ltac:(fun H => rewrite -> H). +Tactic Notation "rewrite_hyp" "<-" "*" := do_with_hyp' ltac:(fun H => rewrite <- H). +Tactic Notation "rewrite_hyp" "?*" := repeat do_with_hyp' ltac:(fun H => rewrite !H). +Tactic Notation "rewrite_hyp" "->" "?*" := repeat do_with_hyp' ltac:(fun H => rewrite -> !H). +Tactic Notation "rewrite_hyp" "<-" "?*" := repeat do_with_hyp' ltac:(fun H => rewrite <- !H). +Tactic Notation "rewrite_hyp" "!*" := progress rewrite_hyp ?*. +Tactic Notation "rewrite_hyp" "->" "!*" := progress rewrite_hyp -> ?*. +Tactic Notation "rewrite_hyp" "<-" "!*" := progress rewrite_hyp <- ?*. + +Tactic Notation "rewrite_hyp" "*" "in" "*" := do_with_hyp' ltac:(fun H => rewrite H in * ). +Tactic Notation "rewrite_hyp" "->" "*" "in" "*" := do_with_hyp' ltac:(fun H => rewrite -> H in * ). +Tactic Notation "rewrite_hyp" "<-" "*" "in" "*" := do_with_hyp' ltac:(fun H => rewrite <- H in * ). +Tactic Notation "rewrite_hyp" "?*" "in" "*" := repeat do_with_hyp' ltac:(fun H => rewrite !H in * ). +Tactic Notation "rewrite_hyp" "->" "?*" "in" "*" := repeat do_with_hyp' ltac:(fun H => rewrite -> !H in * ). +Tactic Notation "rewrite_hyp" "<-" "?*" "in" "*" := repeat do_with_hyp' ltac:(fun H => rewrite <- !H in * ). +Tactic Notation "rewrite_hyp" "!*" "in" "*" := progress rewrite_hyp ?* in *. +Tactic Notation "rewrite_hyp" "->" "!*" "in" "*" := progress rewrite_hyp -> ?* in *. +Tactic Notation "rewrite_hyp" "<-" "!*" "in" "*" := progress rewrite_hyp <- ?* in *. diff --git a/src/Util/Tuple.v b/src/Util/Tuple.v index 4232c7bf8..13f8bd386 100644 --- a/src/Util/Tuple.v +++ b/src/Util/Tuple.v @@ -166,7 +166,9 @@ end. Lemma from_list_default'_eq : forall {T} (d : T) xs n y pf, from_list_default' d y n xs = from_list' y n xs pf. Proof. - induction xs; destruct n; intros; simpl in *; congruence. + induction xs; destruct n; intros; simpl in *; + solve [ congruence (* 8.5 *) + | erewrite IHxs; reflexivity ]. (* 8.4 *) Qed. Lemma from_list_default_eq : forall {T} (d : T) xs n pf, diff --git a/src/Util/ZUtil.v b/src/Util/ZUtil.v index a8b18ffef..2bcbabaac 100644 --- a/src/Util/ZUtil.v +++ b/src/Util/ZUtil.v @@ -1,10 +1,15 @@ Require Import Coq.ZArith.Zpower Coq.ZArith.Znumtheory Coq.ZArith.ZArith Coq.ZArith.Zdiv. Require Import Coq.omega.Omega Coq.micromega.Psatz Coq.Numbers.Natural.Peano.NPeano Coq.Arith.Arith. Require Import Crypto.Util.NatUtil. +Require Import Crypto.Util.Notations. Require Import Coq.Lists.List. Import Nat. Local Open Scope Z. +Infix ">>" := Z.shiftr : Z_scope. +Infix "<<" := Z.shiftl : Z_scope. +Infix "&" := Z.land : Z_scope. + Hint Extern 1 => lia : lia. Hint Extern 1 => lra : lra. Hint Extern 1 => nia : nia. @@ -17,7 +22,7 @@ Hint Resolve (fun a b H => proj1 (Z.mod_pos_bound a b H)) (fun a b H => proj2 (Z this database. *) Create HintDb zsimplify discriminated. Hint Rewrite Z.div_1_r Z.mul_1_r Z.mul_1_l Z.sub_diag Z.mul_0_r Z.mul_0_l Z.add_0_l Z.add_0_r Z.opp_involutive Z.sub_0_r : zsimplify. -Hint Rewrite Z.div_mul Z.div_1_l Z.div_same Z.mod_same Z.div_small Z.mod_small Z.div_add Z.div_add_l using lia : zsimplify. +Hint Rewrite Z.div_mul Z.div_1_l Z.div_same Z.mod_same Z.div_small Z.mod_small Z.div_add Z.div_add_l Z.mod_add Z.div_0_l using lia : zsimplify. (** "push" means transform [-f x] to [f (-x)]; "pull" means go the other way *) Create HintDb push_Zopp discriminated. @@ -43,318 +48,326 @@ Hint Rewrite Z.div_small_iff using lia : zstrip_div. We'll put, e.g., [mul_div_eq] into it below. *) Create HintDb zstrip_div. -Lemma gt_lt_symmetry: forall n m, n > m <-> m < n. -Proof. - intros; split; omega. -Qed. - -Lemma positive_is_nonzero : forall x, x > 0 -> x <> 0. -Proof. - intros; omega. -Qed. -Hint Resolve positive_is_nonzero. - -Lemma div_positive_gt_0 : forall a b, a > 0 -> b > 0 -> a mod b = 0 -> - a / b > 0. -Proof. - intros; rewrite gt_lt_symmetry. - apply Z.div_str_pos. - split; intuition. - apply Z.divide_pos_le; try (apply Zmod_divide); omega. -Qed. - -Lemma elim_mod : forall a b m, a = b -> a mod m = b mod m. -Proof. - intros; subst; auto. -Qed. -Hint Resolve elim_mod. - -Lemma mod_mult_plus: forall a b c, (b <> 0) -> (a * b + c) mod b = c mod b. -Proof. - intros. - rewrite Zplus_mod. - rewrite Z.mod_mul; auto; simpl. - rewrite Zmod_mod; auto. -Qed. - -Lemma pos_pow_nat_pos : forall x n, - Z.pos x ^ Z.of_nat n > 0. - do 2 (intros; induction n; subst; simpl in *; auto with zarith). - rewrite <- Pos.add_1_r, Zpower_pos_is_exp. - apply Zmult_gt_0_compat; auto; reflexivity. -Qed. - -Lemma Z_div_mul' : forall a b : Z, b <> 0 -> (b * a) / b = a. - intros. rewrite Z.mul_comm. apply Z.div_mul; auto. -Qed. - -Hint Rewrite Z_div_mul' using lia : zsimplify. - -Lemma Zgt0_neq0 : forall x, x > 0 -> x <> 0. - intuition. -Qed. - -Lemma pow_Z2N_Zpow : forall a n, 0 <= a -> - ((Z.to_nat a) ^ n = Z.to_nat (a ^ Z.of_nat n)%Z)%nat. -Proof. - intros; induction n; try reflexivity. - rewrite Nat2Z.inj_succ. - rewrite pow_succ_r by apply le_0_n. - rewrite Z.pow_succ_r by apply Zle_0_nat. - rewrite IHn. - rewrite Z2Nat.inj_mul; auto using Z.pow_nonneg. -Qed. - -Lemma pow_Zpow : forall a n : nat, Z.of_nat (a ^ n) = Z.of_nat a ^ Z.of_nat n. -Proof with auto using Zle_0_nat, Z.pow_nonneg. - intros; apply Z2Nat.inj... - rewrite <- pow_Z2N_Zpow, !Nat2Z.id... -Qed. - -Lemma mod_exp_0 : forall a x m, x > 0 -> m > 1 -> a mod m = 0 -> - a ^ x mod m = 0. -Proof. - intros. - replace x with (Z.of_nat (Z.to_nat x)) in * by (apply Z2Nat.id; omega). - induction (Z.to_nat x). { - simpl in *; omega. - } { - rewrite Nat2Z.inj_succ in *. - rewrite Z.pow_succ_r by omega. - rewrite Z.mul_mod by omega. - case_eq n; intros. { - subst. simpl. - rewrite Zmod_1_l by omega. - rewrite H1. - apply Zmod_0_l. +Module Z. + Definition pow2_mod n i := (n & (Z.ones i)). + + Lemma positive_is_nonzero : forall x, x > 0 -> x <> 0. + Proof. intros; omega. Qed. + + Hint Resolve positive_is_nonzero : zarith. + + Lemma div_positive_gt_0 : forall a b, a > 0 -> b > 0 -> a mod b = 0 -> + a / b > 0. + Proof. + intros; rewrite Z.gt_lt_iff. + apply Z.div_str_pos. + split; intuition. + apply Z.divide_pos_le; try (apply Zmod_divide); omega. + Qed. + + Lemma elim_mod : forall a b m, a = b -> a mod m = b mod m. + Proof. intros; subst; auto. Qed. + + Hint Resolve elim_mod : zarith. + + Lemma mod_add_l : forall a b c, b <> 0 -> (a * b + c) mod b = c mod b. + Proof. intros; rewrite (Z.add_comm _ c); autorewrite with zsimplify; reflexivity. Qed. + Hint Rewrite mod_add_l using lia : zsimplify. + + Lemma mod_add' : forall a b c, b <> 0 -> (a + b * c) mod b = a mod b. + Proof. intros; rewrite (Z.mul_comm _ c); autorewrite with zsimplify; reflexivity. Qed. + Lemma mod_add_l' : forall a b c, a <> 0 -> (a * b + c) mod a = c mod a. + Proof. intros; rewrite (Z.mul_comm _ b); autorewrite with zsimplify; reflexivity. Qed. + Hint Rewrite mod_add' mod_add_l' using lia : zsimplify. + + Lemma pos_pow_nat_pos : forall x n, + Z.pos x ^ Z.of_nat n > 0. + Proof. + do 2 (intros; induction n; subst; simpl in *; auto with zarith). + rewrite <- Pos.add_1_r, Zpower_pos_is_exp. + apply Zmult_gt_0_compat; auto; reflexivity. + Qed. + + Lemma div_mul' : forall a b : Z, b <> 0 -> (b * a) / b = a. + Proof. intros. rewrite Z.mul_comm. apply Z.div_mul; auto. Qed. + Hint Rewrite div_mul' using lia : zsimplify. + + (** TODO: Should we get rid of this duplicate? *) + Notation gt0_neq0 := positive_is_nonzero (only parsing). + + Lemma pow_Z2N_Zpow : forall a n, 0 <= a -> + ((Z.to_nat a) ^ n = Z.to_nat (a ^ Z.of_nat n)%Z)%nat. + Proof. + intros; induction n; try reflexivity. + rewrite Nat2Z.inj_succ. + rewrite pow_succ_r by apply le_0_n. + rewrite Z.pow_succ_r by apply Zle_0_nat. + rewrite IHn. + rewrite Z2Nat.inj_mul; auto using Z.pow_nonneg. + Qed. + + Lemma pow_Zpow : forall a n : nat, Z.of_nat (a ^ n) = Z.of_nat a ^ Z.of_nat n. + Proof with auto using Zle_0_nat, Z.pow_nonneg. + intros; apply Z2Nat.inj... + rewrite <- pow_Z2N_Zpow, !Nat2Z.id... + Qed. + + Lemma mod_exp_0 : forall a x m, x > 0 -> m > 1 -> a mod m = 0 -> + a ^ x mod m = 0. + Proof. + intros. + replace x with (Z.of_nat (Z.to_nat x)) in * by (apply Z2Nat.id; omega). + induction (Z.to_nat x). { + simpl in *; omega. + } { + rewrite Nat2Z.inj_succ in *. + rewrite Z.pow_succ_r by omega. + rewrite Z.mul_mod by omega. + case_eq n; intros. { + subst. simpl. + rewrite Zmod_1_l by omega. + rewrite H1. + apply Zmod_0_l. + } { + subst. + rewrite IHn by (rewrite Nat2Z.inj_succ in *; omega). + rewrite H1. + auto. + } + } + Qed. + + Lemma mod_pow : forall (a m b : Z), (0 <= b) -> (m <> 0) -> + a ^ b mod m = (a mod m) ^ b mod m. + Proof. + intros; rewrite <- (Z2Nat.id b) by auto. + induction (Z.to_nat b); auto. + rewrite Nat2Z.inj_succ. + do 2 rewrite Z.pow_succ_r by apply Nat2Z.is_nonneg. + rewrite Z.mul_mod by auto. + rewrite (Z.mul_mod (a mod m) ((a mod m) ^ Z.of_nat n) m) by auto. + rewrite <- IHn by auto. + rewrite Z.mod_mod by auto. + reflexivity. + Qed. + + Ltac divide_exists_mul := let k := fresh "k" in + match goal with + | [ H : (?a | ?b) |- _ ] => apply Z.mod_divide in H; try apply Zmod_divides in H; destruct H as [k H] + | [ |- (?a | ?b) ] => apply Z.mod_divide; try apply Zmod_divides + end; (omega || auto). + + Lemma divide_mul_div: forall a b c (a_nonzero : a <> 0) (c_nonzero : c <> 0), + (a | b * (a / c)) -> (c | a) -> (c | b). + Proof. + intros ? ? ? ? ? divide_a divide_c_a; do 2 divide_exists_mul. + rewrite divide_c_a in divide_a. + rewrite div_mul' in divide_a by auto. + replace (b * k) with (k * b) in divide_a by ring. + replace (c * k * k0) with (k * (k0 * c)) in divide_a by ring. + rewrite Z.mul_cancel_l in divide_a by (intuition; rewrite H in divide_c_a; ring_simplify in divide_a; intuition). + eapply Zdivide_intro; eauto. + Qed. + + Lemma divide2_even_iff : forall n, (2 | n) <-> Z.even n = true. + Proof. + intro; split. { + intro divide2_n. + divide_exists_mul; [ | pose proof (Z.mod_pos_bound n 2); omega]. + rewrite divide2_n. + apply Z.even_mul. } { - subst. - rewrite IHn by (rewrite Nat2Z.inj_succ in *; omega). - rewrite H1. - auto. + intro n_even. + pose proof (Zmod_even n). + rewrite n_even in H. + apply Zmod_divide; omega || auto. } - } -Qed. - -Lemma mod_pow : forall (a m b : Z), (0 <= b) -> (m <> 0) -> - a ^ b mod m = (a mod m) ^ b mod m. -Proof. - intros; rewrite <- (Z2Nat.id b) by auto. - induction (Z.to_nat b); auto. - rewrite Nat2Z.inj_succ. - do 2 rewrite Z.pow_succ_r by apply Nat2Z.is_nonneg. - rewrite Z.mul_mod by auto. - rewrite (Z.mul_mod (a mod m) ((a mod m) ^ Z.of_nat n) m) by auto. - rewrite <- IHn by auto. - rewrite Z.mod_mod by auto. - reflexivity. -Qed. - -Ltac Zdivide_exists_mul := let k := fresh "k" in -match goal with -| [ H : (?a | ?b) |- _ ] => apply Z.mod_divide in H; try apply Zmod_divides in H; destruct H as [k H] -| [ |- (?a | ?b) ] => apply Z.mod_divide; try apply Zmod_divides -end; (omega || auto). - -Lemma divide_mul_div: forall a b c (a_nonzero : a <> 0) (c_nonzero : c <> 0), - (a | b * (a / c)) -> (c | a) -> (c | b). -Proof. - intros ? ? ? ? ? divide_a divide_c_a; do 2 Zdivide_exists_mul. - rewrite divide_c_a in divide_a. - rewrite Z_div_mul' in divide_a by auto. - replace (b * k) with (k * b) in divide_a by ring. - replace (c * k * k0) with (k * (k0 * c)) in divide_a by ring. - rewrite Z.mul_cancel_l in divide_a by (intuition; rewrite H in divide_c_a; ring_simplify in divide_a; intuition). - eapply Zdivide_intro; eauto. -Qed. - -Lemma divide2_even_iff : forall n, (2 | n) <-> Z.even n = true. -Proof. - intro; split. { - intro divide2_n. - Zdivide_exists_mul; [ | pose proof (Z.mod_pos_bound n 2); omega]. - rewrite divide2_n. - apply Z.even_mul. - } { - intro n_even. - pose proof (Zmod_even n). - rewrite n_even in H. - apply Zmod_divide; omega || auto. - } -Qed. - -Lemma prime_odd_or_2 : forall p (prime_p : prime p), p = 2 \/ Z.odd p = true. -Proof. - intros. - apply Decidable.imp_not_l; try apply Z.eq_decidable. - intros p_neq2. - pose proof (Zmod_odd p) as mod_odd. - destruct (Sumbool.sumbool_of_bool (Z.odd p)) as [? | p_not_odd]; auto. - rewrite p_not_odd in mod_odd. - apply Zmod_divides in mod_odd; try omega. - destruct mod_odd as [c c_id]. - rewrite Z.mul_comm in c_id. - apply Zdivide_intro in c_id. - apply prime_divisors in c_id; auto. - destruct c_id; [omega | destruct H; [omega | destruct H; auto]]. - pose proof (prime_ge_2 p prime_p); omega. -Qed. - -Lemma mul_div_eq : (forall a m, m > 0 -> m * (a / m) = (a - a mod m))%Z. -Proof. - intros. - rewrite (Z_div_mod_eq a m) at 2 by auto. - ring. -Qed. - -Lemma mul_div_eq' : (forall a m, m > 0 -> (a / m) * m = (a - a mod m))%Z. -Proof. - intros. - rewrite (Z_div_mod_eq a m) at 2 by auto. - ring. -Qed. - -Hint Rewrite mul_div_eq mul_div_eq' using lia : zdiv_to_mod. -Hint Rewrite <- mul_div_eq' using lia : zmod_to_div. - -Ltac prime_bound := match goal with -| [ H : prime ?p |- _ ] => pose proof (prime_ge_2 p H); try omega -end. - -Lemma Zlt_minus_lt_0 : forall n m, m < n -> 0 < n - m. -Proof. - intros; omega. -Qed. - - -Lemma Z_testbit_low : forall n x i, (0 <= i < n) -> - Z.testbit x i = Z.testbit (Z.land x (Z.ones n)) i. -Proof. - intros. - rewrite Z.land_ones by omega. - symmetry. - apply Z.mod_pow2_bits_low. - omega. -Qed. - - -Lemma Z_testbit_shiftl : forall i, (0 <= i) -> forall a b n, (i < n) -> - Z.testbit (a + Z.shiftl b n) i = Z.testbit a i. -Proof. - intros. - erewrite Z_testbit_low; eauto. - rewrite Z.land_ones, Z.shiftl_mul_pow2 by omega. - rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 n); omega). - auto using Z.mod_pow2_bits_low. -Qed. - -Lemma Z_mod_div_eq0 : forall a b, 0 < b -> (a mod b) / b = 0. -Proof. - intros. - apply Z.div_small. - auto using Z.mod_pos_bound. -Qed. - -Lemma Z_shiftr_add_land : forall n m a b, (n <= m)%nat -> - Z.shiftr ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.of_nat m) = Z.shiftr b (Z.of_nat (m - n)). -Proof. - intros. - rewrite Z.land_ones by apply Nat2Z.is_nonneg. - rewrite !Z.shiftr_div_pow2 by apply Nat2Z.is_nonneg. - rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg. - rewrite (le_plus_minus n m) at 1 by assumption. - rewrite Nat2Z.inj_add. - rewrite Z.pow_add_r by apply Nat2Z.is_nonneg. - rewrite <- Z.div_div by first - [ pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega - | apply Z.pow_pos_nonneg; omega ]. - rewrite Z.div_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega). - rewrite Z_mod_div_eq0 by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat n)); omega); auto. -Qed. - -Lemma Z_land_add_land : forall n m a b, (m <= n)%nat -> - Z.land ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.ones (Z.of_nat m)) = Z.land a (Z.ones (Z.of_nat m)). -Proof. - intros. - rewrite !Z.land_ones by apply Nat2Z.is_nonneg. - rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg. - replace (b * 2 ^ Z.of_nat n) with - ((b * 2 ^ Z.of_nat (n - m)) * 2 ^ Z.of_nat m) by - (rewrite (le_plus_minus m n) at 2; try assumption; - rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg; ring). - rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat m)); omega). - symmetry. apply Znumtheory.Zmod_div_mod; try (apply Z.pow_pos_nonneg; omega). - rewrite (le_plus_minus m n) by assumption. - rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg. - apply Z.divide_factor_l. -Qed. - -Lemma Z_pow_gt0 : forall a, 0 < a -> forall b, 0 <= b -> 0 < a ^ b. -Proof. - intros until 1. - apply natlike_ind; try (simpl; omega). - intros. - rewrite Z.pow_succ_r by assumption. - apply Z.mul_pos_pos; assumption. -Qed. - -Lemma div_pow2succ : forall n x, (0 <= x) -> - n / 2 ^ Z.succ x = Z.div2 (n / 2 ^ x). -Proof. - intros. - rewrite Z.pow_succ_r, Z.mul_comm by auto. - rewrite <- Z.div_div by (try apply Z.pow_nonzero; omega). - rewrite Zdiv2_div. - reflexivity. -Qed. - -Lemma shiftr_succ : forall n x, - Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1. -Proof. - intros. - rewrite Z.shiftr_shiftr by omega. - reflexivity. -Qed. - - -Definition Z_shiftl_by n a := Z.shiftl a n. - -Lemma Z_shiftl_by_mul_pow2 : forall n a, 0 <= n -> Z.mul (2 ^ n) a = Z_shiftl_by n a. -Proof. - intros. - unfold Z_shiftl_by. - rewrite Z.shiftl_mul_pow2 by assumption. - apply Z.mul_comm. -Qed. - -Lemma map_shiftl : forall n l, 0 <= n -> map (Z.mul (2 ^ n)) l = map (Z_shiftl_by n) l. -Proof. - intros; induction l; auto using Z_shiftl_by_mul_pow2. - simpl. - rewrite IHl. - f_equal. - apply Z_shiftl_by_mul_pow2. - assumption. -Qed. - -Lemma Z_odd_mod : forall a b, (b <> 0)%Z -> - Z.odd (a mod b) = if Z.odd b then xorb (Z.odd a) (Z.odd (a / b)) else Z.odd a. -Proof. - intros. - rewrite Zmod_eq_full by assumption. - rewrite <-Z.add_opp_r, Z.odd_add, Z.odd_opp, Z.odd_mul. - case_eq (Z.odd b); intros; rewrite ?Bool.andb_true_r, ?Bool.andb_false_r; auto using Bool.xorb_false_r. -Qed. - -Lemma mod_same_pow : forall a b c, 0 <= c <= b -> a ^ b mod a ^ c = 0. -Proof. - intros. - replace b with (b - c + c) by ring. - rewrite Z.pow_add_r by omega. - apply Z_mod_mult. -Qed. - - Lemma Z_ones_succ : forall x, (0 <= x) -> + Qed. + + Lemma prime_odd_or_2 : forall p (prime_p : prime p), p = 2 \/ Z.odd p = true. + Proof. + intros. + apply Decidable.imp_not_l; try apply Z.eq_decidable. + intros p_neq2. + pose proof (Zmod_odd p) as mod_odd. + destruct (Sumbool.sumbool_of_bool (Z.odd p)) as [? | p_not_odd]; auto. + rewrite p_not_odd in mod_odd. + apply Zmod_divides in mod_odd; try omega. + destruct mod_odd as [c c_id]. + rewrite Z.mul_comm in c_id. + apply Zdivide_intro in c_id. + apply prime_divisors in c_id; auto. + destruct c_id; [omega | destruct H; [omega | destruct H; auto]]. + pose proof (prime_ge_2 p prime_p); omega. + Qed. + + Lemma mul_div_eq : forall a m, m > 0 -> m * (a / m) = (a - a mod m). + Proof. + intros. + rewrite (Z_div_mod_eq a m) at 2 by auto. + ring. + Qed. + + Lemma mul_div_eq' : (forall a m, m > 0 -> (a / m) * m = (a - a mod m))%Z. + Proof. + intros. + rewrite (Z_div_mod_eq a m) at 2 by auto. + ring. + Qed. + + Hint Rewrite mul_div_eq mul_div_eq' using lia : zdiv_to_mod. + Hint Rewrite <- mul_div_eq' using lia : zmod_to_div. + + Ltac prime_bound := match goal with + | [ H : prime ?p |- _ ] => pose proof (prime_ge_2 p H); try omega + end. + + Lemma testbit_low : forall n x i, (0 <= i < n) -> + Z.testbit x i = Z.testbit (Z.land x (Z.ones n)) i. + Proof. + intros. + rewrite Z.land_ones by omega. + symmetry. + apply Z.mod_pow2_bits_low. + omega. + Qed. + + + Lemma testbit_add_shiftl_low : forall i, (0 <= i) -> forall a b n, (i < n) -> + Z.testbit (a + Z.shiftl b n) i = Z.testbit a i. + Proof. + intros. + erewrite Z.testbit_low; eauto. + rewrite Z.land_ones, Z.shiftl_mul_pow2 by omega. + rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 n); omega). + auto using Z.mod_pow2_bits_low. + Qed. + + Lemma mod_div_eq0 : forall a b, 0 < b -> (a mod b) / b = 0. + Proof. + intros. + apply Z.div_small. + auto using Z.mod_pos_bound. + Qed. + Hint Rewrite mod_div_eq0 using lia : zsimplify. + + Lemma shiftr_add_shiftl_high : forall n m a b, 0 <= n <= m -> 0 <= a < 2 ^ n -> + Z.shiftr (a + (Z.shiftl b n)) m = Z.shiftr b (m - n). + Proof. + intros. + rewrite !Z.shiftr_div_pow2, Z.shiftl_mul_pow2 by omega. + replace (2 ^ m) with (2 ^ n * 2 ^ (m - n)) by + (rewrite <-Z.pow_add_r by omega; f_equal; ring). + rewrite <-Z.div_div, Z.div_add, (Z.div_small a) ; try solve + [assumption || apply Z.pow_nonzero || apply Z.pow_pos_nonneg; omega]. + f_equal; ring. + Qed. + + Lemma shiftr_add_shiftl_low : forall n m a b, 0 <= m <= n -> 0 <= a < 2 ^ n -> + Z.shiftr (a + (Z.shiftl b n)) m = Z.shiftr a m + Z.shiftr b (m - n). + Proof. + intros. + rewrite !Z.shiftr_div_pow2, Z.shiftl_mul_pow2, Z.shiftr_mul_pow2 by omega. + replace (2 ^ n) with (2 ^ (n - m) * 2 ^ m) by + (rewrite <-Z.pow_add_r by omega; f_equal; ring). + rewrite Z.mul_assoc, Z.div_add by (apply Z.pow_nonzero; omega). + repeat f_equal; ring. + Qed. + + Lemma testbit_add_shiftl_high : forall i, (0 <= i) -> forall a b n, (0 <= n <= i) -> + 0 <= a < 2 ^ n -> + Z.testbit (a + Z.shiftl b n) i = Z.testbit b (i - n). + Proof. + intros ? ?. + apply natlike_ind with (x := i); intros; try assumption; + (destruct (Z_eq_dec 0 n); [ subst; rewrite Z.pow_0_r in *; + replace a with 0 by omega; f_equal; ring | ]); try omega. + rewrite <-Z.add_1_r at 1. rewrite <-Z.shiftr_spec by assumption. + replace (Z.succ x - n) with (x - (n - 1)) by ring. + rewrite shiftr_add_shiftl_low, <-Z.shiftl_opp_r with (a := b) by omega. + rewrite <-H1 with (a := Z.shiftr a 1); try omega; [ repeat f_equal; ring | ]. + rewrite Z.shiftr_div_pow2 by omega. + split; apply Z.div_pos || apply Z.div_lt_upper_bound; + try solve [rewrite ?Z.pow_1_r; omega]. + rewrite <-Z.pow_add_r by omega. + replace (1 + (n - 1)) with n by ring; omega. + Qed. + + Lemma land_add_land : forall n m a b, (m <= n)%nat -> + Z.land ((Z.land a (Z.ones (Z.of_nat n))) + (Z.shiftl b (Z.of_nat n))) (Z.ones (Z.of_nat m)) = Z.land a (Z.ones (Z.of_nat m)). + Proof. + intros. + rewrite !Z.land_ones by apply Nat2Z.is_nonneg. + rewrite Z.shiftl_mul_pow2 by apply Nat2Z.is_nonneg. + replace (b * 2 ^ Z.of_nat n) with + ((b * 2 ^ Z.of_nat (n - m)) * 2 ^ Z.of_nat m) by + (rewrite (le_plus_minus m n) at 2; try assumption; + rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg; ring). + rewrite Z.mod_add by (pose proof (Z.pow_pos_nonneg 2 (Z.of_nat m)); omega). + symmetry. apply Znumtheory.Zmod_div_mod; try (apply Z.pow_pos_nonneg; omega). + rewrite (le_plus_minus m n) by assumption. + rewrite Nat2Z.inj_add, Z.pow_add_r by apply Nat2Z.is_nonneg. + apply Z.divide_factor_l. + Qed. + + Lemma div_pow2succ : forall n x, (0 <= x) -> + n / 2 ^ Z.succ x = Z.div2 (n / 2 ^ x). + Proof. + intros. + rewrite Z.pow_succ_r, Z.mul_comm by auto. + rewrite <- Z.div_div by (try apply Z.pow_nonzero; omega). + rewrite Zdiv2_div. + reflexivity. + Qed. + + Lemma shiftr_succ : forall n x, + Z.shiftr n (Z.succ x) = Z.shiftr (Z.shiftr n x) 1. + Proof. + intros. + rewrite Z.shiftr_shiftr by omega. + reflexivity. + Qed. + + + Definition shiftl_by n a := Z.shiftl a n. + + Lemma shiftl_by_mul_pow2 : forall n a, 0 <= n -> Z.mul (2 ^ n) a = Z.shiftl_by n a. + Proof. + intros. + unfold Z.shiftl_by. + rewrite Z.shiftl_mul_pow2 by assumption. + apply Z.mul_comm. + Qed. + + Lemma map_shiftl : forall n l, 0 <= n -> map (Z.mul (2 ^ n)) l = map (Z.shiftl_by n) l. + Proof. + intros; induction l; auto using Z.shiftl_by_mul_pow2. + simpl. + rewrite IHl. + f_equal. + apply Z.shiftl_by_mul_pow2. + assumption. + Qed. + + Lemma odd_mod : forall a b, (b <> 0)%Z -> + Z.odd (a mod b) = if Z.odd b then xorb (Z.odd a) (Z.odd (a / b)) else Z.odd a. + Proof. + intros. + rewrite Zmod_eq_full by assumption. + rewrite <-Z.add_opp_r, Z.odd_add, Z.odd_opp, Z.odd_mul. + case_eq (Z.odd b); intros; rewrite ?Bool.andb_true_r, ?Bool.andb_false_r; auto using Bool.xorb_false_r. + Qed. + + Lemma mod_same_pow : forall a b c, 0 <= c <= b -> a ^ b mod a ^ c = 0. + Proof. + intros. + replace b with (b - c + c) by ring. + rewrite Z.pow_add_r by omega. + apply Z_mod_mult. + Qed. + Hint Rewrite mod_same_pow using lia : zsimplify. + + Lemma ones_succ : forall x, (0 <= x) -> Z.ones (Z.succ x) = 2 ^ x + Z.ones x. Proof. unfold Z.ones; intros. @@ -365,14 +378,14 @@ Qed. rewrite Z.pow_succ_r; omega. Qed. - Lemma Z_div_floor : forall a b c, 0 < b -> a < b * (Z.succ c) -> a / b <= c. + Lemma div_floor : forall a b c, 0 < b -> a < b * (Z.succ c) -> a / b <= c. Proof. intros. apply Z.lt_succ_r. apply Z.div_lt_upper_bound; try omega. Qed. - Lemma Z_shiftr_1_r_le : forall a b, a <= b -> + Lemma shiftr_1_r_le : forall a b, a <= b -> Z.shiftr a 1 <= Z.shiftr b 1. Proof. intros. @@ -380,7 +393,7 @@ Qed. apply Z.div_le_mono; omega. Qed. - Lemma Z_ones_pred : forall i, 0 < i -> Z.ones (Z.pred i) = Z.shiftr (Z.ones i) 1. + Lemma ones_pred : forall i, 0 < i -> Z.ones (Z.pred i) = Z.shiftr (Z.ones i) 1. Proof. induction i; [ | | pose proof (Pos2Z.neg_is_neg p) ]; try omega. intros. @@ -394,7 +407,7 @@ Qed. f_equal. omega. Qed. - Lemma Z_shiftr_ones' : forall a n, 0 <= a < 2 ^ n -> forall i, (0 <= i) -> + Lemma shiftr_ones' : forall a n, 0 <= a < 2 ^ n -> forall i, (0 <= i) -> Z.shiftr a i <= Z.ones (n - i) \/ n <= i. Proof. intros until 1. @@ -408,17 +421,17 @@ Qed. left. rewrite shiftr_succ. replace (n - Z.succ x) with (Z.pred (n - x)) by omega. - rewrite Z_ones_pred by omega. - apply Z_shiftr_1_r_le. + rewrite Z.ones_pred by omega. + apply Z.shiftr_1_r_le. assumption. Qed. - Lemma Z_shiftr_ones : forall a n i, 0 <= a < 2 ^ n -> (0 <= i) -> (i <= n) -> + Lemma shiftr_ones : forall a n i, 0 <= a < 2 ^ n -> (0 <= i) -> (i <= n) -> Z.shiftr a i <= Z.ones (n - i) . Proof. intros a n i G G0 G1. destruct (Z_le_lt_eq_dec i n G1). - + destruct (Z_shiftr_ones' a n G i G0); omega. + + destruct (Z.shiftr_ones' a n G i G0); omega. + subst; rewrite Z.sub_diag. destruct (Z_eq_dec a 0). - subst; rewrite Z.shiftr_0_l; reflexivity. @@ -426,7 +439,7 @@ Qed. apply Z.log2_lt_pow2; omega. Qed. - Lemma Z_shiftr_upper_bound : forall a n, 0 <= n -> 0 <= a <= 2 ^ n -> Z.shiftr a n <= 1. + Lemma shiftr_upper_bound : forall a n, 0 <= n -> 0 <= a <= 2 ^ n -> Z.shiftr a n <= 1. Proof. intros a ? ? [a_nonneg a_upper_bound]. apply Z_le_lt_eq_dec in a_upper_bound. @@ -442,439 +455,465 @@ Qed. omega. Qed. -(* prove that combinations of known positive/nonnegative numbers are positive/nonnegative *) -Ltac zero_bounds' := - repeat match goal with - | [ |- 0 <= _ + _] => apply Z.add_nonneg_nonneg - | [ |- 0 <= _ - _] => apply Z.le_0_sub - | [ |- 0 <= _ * _] => apply Z.mul_nonneg_nonneg - | [ |- 0 <= _ / _] => apply Z.div_pos - | [ |- 0 <= _ ^ _ ] => apply Z.pow_nonneg - | [ |- 0 <= Z.shiftr _ _] => apply Z.shiftr_nonneg - | [ |- 0 < _ + _] => try solve [apply Z.add_pos_nonneg; zero_bounds']; - try solve [apply Z.add_nonneg_pos; zero_bounds'] - | [ |- 0 < _ - _] => apply Z.lt_0_sub - | [ |- 0 < _ * _] => apply Z.lt_0_mul; left; split - | [ |- 0 < _ / _] => apply Z.div_str_pos - | [ |- 0 < _ ^ _ ] => apply Z.pow_pos_nonneg - end; try omega; try prime_bound; auto. - -Ltac zero_bounds := try omega; try prime_bound; zero_bounds'. - -Hint Extern 1 => progress zero_bounds : zero_bounds. - -Lemma Z_ones_nonneg : forall i, (0 <= i) -> 0 <= Z.ones i. -Proof. - apply natlike_ind. - + unfold Z.ones. simpl; omega. - + intros. - rewrite Z_ones_succ by assumption. - zero_bounds. -Qed. - -Lemma Z_ones_pos_pos : forall i, (0 < i) -> 0 < Z.ones i. -Proof. - intros. - unfold Z.ones. - rewrite Z.shiftl_1_l. - apply Z.lt_succ_lt_pred. - apply Z.pow_gt_1; omega. -Qed. - -Lemma N_le_1_l : forall p, (1 <= N.pos p)%N. -Proof. - destruct p; cbv; congruence. -Qed. - -Lemma Pos_land_upper_bound_l : forall a b, (Pos.land a b <= N.pos a)%N. -Proof. - induction a; destruct b; intros; try solve [cbv; congruence]; - simpl; specialize (IHa b); case_eq (Pos.land a b); intro; simpl; - try (apply N_le_1_l || apply N.le_0_l); intro land_eq; - rewrite land_eq in *; unfold N.le, N.compare in *; - rewrite ?Pos.compare_xI_xI, ?Pos.compare_xO_xI, ?Pos.compare_xO_xO; - try assumption. - destruct (p ?=a)%positive; cbv; congruence. -Qed. - -Lemma Z_land_upper_bound_l : forall a b, (0 <= a) -> (0 <= b) -> - Z.land a b <= a. -Proof. - intros. - destruct a, b; try solve [exfalso; auto]; try solve [cbv; congruence]. - cbv [Z.land]. - rewrite <-N2Z.inj_pos, <-N2Z.inj_le. - auto using Pos_land_upper_bound_l. -Qed. - -Lemma Z_land_upper_bound_r : forall a b, (0 <= a) -> (0 <= b) -> - Z.land a b <= b. -Proof. - intros. - rewrite Z.land_comm. - auto using Z_land_upper_bound_l. -Qed. - -Lemma Z_le_fold_right_max : forall low l x, (forall y, In y l -> low <= y) -> - In x l -> x <= fold_right Z.max low l. -Proof. - induction l; intros ? lower_bound In_list; [cbv [In] in *; intuition | ]. - simpl. - destruct (in_inv In_list); subst. - + apply Z.le_max_l. - + etransitivity. - - apply IHl; auto; intuition. - - apply Z.le_max_r. -Qed. - -Lemma Z_le_fold_right_max_initial : forall low l, low <= fold_right Z.max low l. -Proof. - induction l; intros; try reflexivity. - etransitivity; [ apply IHl | apply Z.le_max_r ]. -Qed. - -Ltac Zltb_to_Zlt := - repeat match goal with - | [ H : (?x <? ?y) = ?b |- _ ] - => let H' := fresh in - rename H into H'; - pose proof (Zlt_cases x y) as H; - rewrite H' in H; - clear H' - end. - -Ltac Zcompare_to_sgn := - repeat match goal with - | [ H : _ |- _ ] => progress rewrite <- ?Z.sgn_neg_iff, <- ?Z.sgn_pos_iff, <- ?Z.sgn_null_iff in H - | _ => progress rewrite <- ?Z.sgn_neg_iff, <- ?Z.sgn_pos_iff, <- ?Z.sgn_null_iff - end. - -Local Ltac replace_to_const c := - repeat match goal with - | [ H : ?x = ?x |- _ ] => clear H - | [ H : ?x = c, H' : context[?x] |- _ ] => rewrite H in H' - | [ H : c = ?x, H' : context[?x] |- _ ] => rewrite <- H in H' - | [ H : ?x = c |- context[?x] ] => rewrite H - | [ H : c = ?x |- context[?x] ] => rewrite <- H - end. - -Lemma Zlt_div_0 n m : n / m < 0 <-> ((n < 0 < m \/ m < 0 < n) /\ 0 < -(n / m)). -Proof. - Zcompare_to_sgn; rewrite Z.sgn_opp; simpl. - pose proof (Zdiv_sgn n m) as H. - pose proof (Z.sgn_spec (n / m)) as H'. - repeat first [ progress intuition - | progress simpl in * - | congruence - | lia - | progress replace_to_const (-1) - | progress replace_to_const 0 - | progress replace_to_const 1 - | match goal with - | [ x : Z |- _ ] => destruct x - end ]. -Qed. - -Lemma two_times_x_minus_x x : 2 * x - x = x. -Proof. lia. Qed. - -Lemma Zmul_div_le x y z - (Hx : 0 <= x) (Hy : 0 <= y) (Hz : 0 < z) - (Hyz : y <= z) - : x * y / z <= x. -Proof. - transitivity (x * z / z); [ | rewrite Z.div_mul by lia; lia ]. - apply Z_div_le; nia. -Qed. - -Lemma Zdiv_mul_diff a b c - (Ha : 0 <= a) (Hb : 0 < b) (Hc : 0 <= c) - : c * a / b - c * (a / b) <= c. -Proof. - pose proof (Z.mod_pos_bound a b). - etransitivity; [ | apply (Zmul_div_le c (a mod b) b); lia ]. - rewrite (Z_div_mod_eq a b) at 1 by lia. - rewrite Z.mul_add_distr_l. - replace (c * (b * (a / b))) with ((c * (a / b)) * b) by lia. - rewrite Z.div_add_l by lia. - lia. -Qed. - -Lemma Zdiv_mul_le_le a b c - : 0 <= a -> 0 < b -> 0 <= c -> c * (a / b) <= c * a / b <= c * (a / b) + c. -Proof. - pose proof (Zdiv_mul_diff a b c); split; try apply Z.div_mul_le; lia. -Qed. - -Lemma Zdiv_mul_le_le_offset a b c - : 0 <= a -> 0 < b -> 0 <= c -> c * a / b - c <= c * (a / b). -Proof. - pose proof (Zdiv_mul_le_le a b c); lia. -Qed. - -Hint Resolve Zmult_le_compat_r Zmult_le_compat_l Z_div_le Zdiv_mul_le_le_offset Z.add_le_mono Z.sub_le_mono : zarith. - -(** * [Zsimplify_fractions_le] *) -(** The culmination of this series of tactics, - [Zsimplify_fractions_le], will use the fact that [a * (b / c) <= - (a * b) / c], and do some reasoning modulo associativity and - commutativity in [Z] to perform such a reduction. It may leave - over goals if it cannot prove that some denominators are non-zero. - If the rewrite [a * (b / c)] → [(a * b) / c] is safe to do on the - LHS of the goal, this tactic should not turn a solvable goal into - an unsolvable one. - - After running, the tactic does some basic rewriting to simplify - fractions, e.g., that [a * b / b = a]. *) -Ltac Zsplit_sums_step := - match goal with - | [ |- _ + _ <= _ ] - => etransitivity; [ eapply Z.add_le_mono | ] - | [ |- _ - _ <= _ ] - => etransitivity; [ eapply Z.sub_le_mono | ] - end. -Ltac Zsplit_sums := - try (Zsplit_sums_step; [ Zsplit_sums.. | ]). -Ltac Zpre_reorder_fractions_step := - match goal with - | [ |- context[?x / ?y * ?z] ] - => rewrite (Z.mul_comm (x / y) z) - | _ => let LHS := match goal with |- ?LHS <= ?RHS => LHS end in - match LHS with - | context G[?x * (?y / ?z)] - => let G' := context G[(x * y) / z] in - transitivity G' - end - end. -Ltac Zpre_reorder_fractions := - try first [ Zsplit_sums_step; [ Zpre_reorder_fractions.. | ] - | Zpre_reorder_fractions_step; [ .. | Zpre_reorder_fractions ] ]. -Ltac Zsplit_comparison := - match goal with - | [ |- ?x <= ?x ] => reflexivity - | [ H : _ >= _ |- _ ] - => apply Z.ge_le_iff in H - | [ |- ?x * ?y <= ?z * ?w ] - => lazymatch goal with - | [ H : 0 <= x |- _ ] => idtac - | [ H : x < 0 |- _ ] => fail - | _ => destruct (Z_lt_le_dec x 0) - end; - [ .. - | lazymatch goal with - | [ H : 0 <= y |- _ ] => idtac - | [ H : y < 0 |- _ ] => fail - | _ => destruct (Z_lt_le_dec y 0) + Lemma lor_shiftl : forall a b n, 0 <= n -> 0 <= a < 2 ^ n -> + Z.lor a (Z.shiftl b n) = a + (Z.shiftl b n). + Proof. + intros. + apply Z.bits_inj'; intros t ?. + rewrite Z.lor_spec, Z.shiftl_spec by assumption. + destruct (Z_lt_dec t n). + + rewrite testbit_add_shiftl_low by omega. + rewrite Z.testbit_neg_r with (n := t - n) by omega. + apply Bool.orb_false_r. + + rewrite testbit_add_shiftl_high by omega. + replace (Z.testbit a t) with false; [ apply Bool.orb_false_l | ]. + symmetry. + apply Z.testbit_false; try omega. + rewrite Z.div_small; try reflexivity. + split; try eapply Z.lt_le_trans with (m := 2 ^ n); try omega. + apply Z.pow_le_mono_r; omega. + Qed. + + (* prove that combinations of known positive/nonnegative numbers are positive/nonnegative *) + Ltac zero_bounds' := + repeat match goal with + | [ |- 0 <= _ + _] => apply Z.add_nonneg_nonneg + | [ |- 0 <= _ - _] => apply Z.le_0_sub + | [ |- 0 <= _ * _] => apply Z.mul_nonneg_nonneg + | [ |- 0 <= _ / _] => apply Z.div_pos + | [ |- 0 <= _ ^ _ ] => apply Z.pow_nonneg + | [ |- 0 <= Z.shiftr _ _] => apply Z.shiftr_nonneg + | [ |- 0 <= _ mod _] => apply Z.mod_pos_bound + | [ |- 0 < _ + _] => try solve [apply Z.add_pos_nonneg; zero_bounds']; + try solve [apply Z.add_nonneg_pos; zero_bounds'] + | [ |- 0 < _ - _] => apply Z.lt_0_sub + | [ |- 0 < _ * _] => apply Z.lt_0_mul; left; split + | [ |- 0 < _ / _] => apply Z.div_str_pos + | [ |- 0 < _ ^ _ ] => apply Z.pow_pos_nonneg + end; try omega; try prime_bound; auto. + + Ltac zero_bounds := try omega; try prime_bound; zero_bounds'. + + Hint Extern 1 => progress zero_bounds : zero_bounds. + + Lemma ones_nonneg : forall i, (0 <= i) -> 0 <= Z.ones i. + Proof. + apply natlike_ind. + + unfold Z.ones. simpl; omega. + + intros. + rewrite Z.ones_succ by assumption. + zero_bounds. + Qed. + + Lemma ones_pos_pos : forall i, (0 < i) -> 0 < Z.ones i. + Proof. + intros. + unfold Z.ones. + rewrite Z.shiftl_1_l. + apply Z.lt_succ_lt_pred. + apply Z.pow_gt_1; omega. + Qed. + + Lemma N_le_1_l : forall p, (1 <= N.pos p)%N. + Proof. + destruct p; cbv; congruence. + Qed. + + Lemma Pos_land_upper_bound_l : forall a b, (Pos.land a b <= N.pos a)%N. + Proof. + induction a; destruct b; intros; try solve [cbv; congruence]; + simpl; specialize (IHa b); case_eq (Pos.land a b); intro; simpl; + try (apply N_le_1_l || apply N.le_0_l); intro land_eq; + rewrite land_eq in *; unfold N.le, N.compare in *; + rewrite ?Pos.compare_xI_xI, ?Pos.compare_xO_xI, ?Pos.compare_xO_xO; + try assumption. + destruct (p ?=a)%positive; cbv; congruence. + Qed. + + Lemma land_upper_bound_l : forall a b, (0 <= a) -> (0 <= b) -> + Z.land a b <= a. + Proof. + intros. + destruct a, b; try solve [exfalso; auto]; try solve [cbv; congruence]. + cbv [Z.land]. + rewrite <-N2Z.inj_pos, <-N2Z.inj_le. + auto using Pos_land_upper_bound_l. + Qed. + + Lemma land_upper_bound_r : forall a b, (0 <= a) -> (0 <= b) -> + Z.land a b <= b. + Proof. + intros. + rewrite Z.land_comm. + auto using Z.land_upper_bound_l. + Qed. + + Lemma le_fold_right_max : forall low l x, (forall y, In y l -> low <= y) -> + In x l -> x <= fold_right Z.max low l. + Proof. + induction l; intros ? lower_bound In_list; [cbv [In] in *; intuition | ]. + simpl. + destruct (in_inv In_list); subst. + + apply Z.le_max_l. + + etransitivity. + - apply IHl; auto; intuition. + - apply Z.le_max_r. + Qed. + + Lemma le_fold_right_max_initial : forall low l, low <= fold_right Z.max low l. + Proof. + induction l; intros; try reflexivity. + etransitivity; [ apply IHl | apply Z.le_max_r ]. + Qed. + + Ltac ltb_to_lt := + repeat match goal with + | [ H : (?x <? ?y) = ?b |- _ ] + => let H' := fresh in + rename H into H'; + pose proof (Zlt_cases x y) as H; + rewrite H' in H; + clear H' + end. + + Ltac compare_to_sgn := + repeat match goal with + | [ H : _ |- _ ] => progress rewrite <- ?Z.sgn_neg_iff, <- ?Z.sgn_pos_iff, <- ?Z.sgn_null_iff in H + | _ => progress rewrite <- ?Z.sgn_neg_iff, <- ?Z.sgn_pos_iff, <- ?Z.sgn_null_iff + end. + + Local Ltac replace_to_const c := + repeat match goal with + | [ H : ?x = ?x |- _ ] => clear H + | [ H : ?x = c, H' : context[?x] |- _ ] => rewrite H in H' + | [ H : c = ?x, H' : context[?x] |- _ ] => rewrite <- H in H' + | [ H : ?x = c |- context[?x] ] => rewrite H + | [ H : c = ?x |- context[?x] ] => rewrite <- H + end. + + Lemma lt_div_0 n m : n / m < 0 <-> ((n < 0 < m \/ m < 0 < n) /\ 0 < -(n / m)). + Proof. + Z.compare_to_sgn; rewrite Z.sgn_opp; simpl. + pose proof (Zdiv_sgn n m) as H. + pose proof (Z.sgn_spec (n / m)) as H'. + repeat first [ progress intuition + | progress simpl in * + | congruence + | lia + | progress replace_to_const (-1) + | progress replace_to_const 0 + | progress replace_to_const 1 + | match goal with + | [ x : Z |- _ ] => destruct x + end ]. + Qed. + + Lemma two_times_x_minus_x x : 2 * x - x = x. + Proof. lia. Qed. + + Lemma mul_div_le x y z + (Hx : 0 <= x) (Hy : 0 <= y) (Hz : 0 < z) + (Hyz : y <= z) + : x * y / z <= x. + Proof. + transitivity (x * z / z); [ | rewrite Z.div_mul by lia; lia ]. + apply Z_div_le; nia. + Qed. + + Lemma div_mul_diff a b c + (Ha : 0 <= a) (Hb : 0 < b) (Hc : 0 <= c) + : c * a / b - c * (a / b) <= c. + Proof. + pose proof (Z.mod_pos_bound a b). + etransitivity; [ | apply (mul_div_le c (a mod b) b); lia ]. + rewrite (Z_div_mod_eq a b) at 1 by lia. + rewrite Z.mul_add_distr_l. + replace (c * (b * (a / b))) with ((c * (a / b)) * b) by lia. + rewrite Z.div_add_l by lia. + lia. + Qed. + + Lemma div_mul_le_le a b c + : 0 <= a -> 0 < b -> 0 <= c -> c * (a / b) <= c * a / b <= c * (a / b) + c. + Proof. + pose proof (Z.div_mul_diff a b c); split; try apply Z.div_mul_le; lia. + Qed. + + Lemma div_mul_le_le_offset a b c + : 0 <= a -> 0 < b -> 0 <= c -> c * a / b - c <= c * (a / b). + Proof. + pose proof (Z.div_mul_le_le a b c); lia. + Qed. + + Hint Resolve Zmult_le_compat_r Zmult_le_compat_l Z_div_le Z.div_mul_le_le_offset Z.add_le_mono Z.sub_le_mono : zarith. + + (** * [Z.simplify_fractions_le] *) + (** The culmination of this series of tactics, + [Z.simplify_fractions_le], will use the fact that [a * (b / c) <= + (a * b) / c], and do some reasoning modulo associativity and + commutativity in [Z] to perform such a reduction. It may leave + over goals if it cannot prove that some denominators are non-zero. + If the rewrite [a * (b / c)] → [(a * b) / c] is safe to do on the + LHS of the goal, this tactic should not turn a solvable goal into + an unsolvable one. + + After running, the tactic does some basic rewriting to simplify + fractions, e.g., that [a * b / b = a]. *) + Ltac split_sums_step := + match goal with + | [ |- _ + _ <= _ ] + => etransitivity; [ eapply Z.add_le_mono | ] + | [ |- _ - _ <= _ ] + => etransitivity; [ eapply Z.sub_le_mono | ] + end. + Ltac split_sums := + try (split_sums_step; [ split_sums.. | ]). + Ltac pre_reorder_fractions_step := + match goal with + | [ |- context[?x / ?y * ?z] ] + => rewrite (Z.mul_comm (x / y) z) + | _ => let LHS := match goal with |- ?LHS <= ?RHS => LHS end in + match LHS with + | context G[?x * (?y / ?z)] + => let G' := context G[(x * y) / z] in + transitivity G' + end + end. + Ltac pre_reorder_fractions := + try first [ split_sums_step; [ pre_reorder_fractions.. | ] + | pre_reorder_fractions_step; [ .. | pre_reorder_fractions ] ]. + Ltac split_comparison := + match goal with + | [ |- ?x <= ?x ] => reflexivity + | [ H : _ >= _ |- _ ] + => apply Z.ge_le_iff in H + | [ |- ?x * ?y <= ?z * ?w ] + => lazymatch goal with + | [ H : 0 <= x |- _ ] => idtac + | [ H : x < 0 |- _ ] => fail + | _ => destruct (Z_lt_le_dec x 0) end; [ .. - | apply Zmult_le_compat; [ | | assumption | assumption ] ] ] - | [ |- ?x / ?y <= ?z / ?y ] - => lazymatch goal with - | [ H : 0 < y |- _ ] => idtac - | [ H : y <= 0 |- _ ] => fail - | _ => destruct (Z_lt_le_dec 0 y) - end; - [ apply Z_div_le; [ apply gt_lt_symmetry; assumption | ] - | .. ] - | [ |- ?x / ?y <= ?x / ?z ] - => lazymatch goal with - | [ H : 0 <= x |- _ ] => idtac - | [ H : x < 0 |- _ ] => fail - | _ => destruct (Z_lt_le_dec x 0) - end; - [ .. - | lazymatch goal with - | [ H : 0 < z |- _ ] => idtac - | [ H : z <= 0 |- _ ] => fail - | _ => destruct (Z_lt_le_dec 0 z) + | lazymatch goal with + | [ H : 0 <= y |- _ ] => idtac + | [ H : y < 0 |- _ ] => fail + | _ => destruct (Z_lt_le_dec y 0) + end; + [ .. + | apply Zmult_le_compat; [ | | assumption | assumption ] ] ] + | [ |- ?x / ?y <= ?z / ?y ] + => lazymatch goal with + | [ H : 0 < y |- _ ] => idtac + | [ H : y <= 0 |- _ ] => fail + | _ => destruct (Z_lt_le_dec 0 y) end; - [ apply Z.div_le_compat_l; [ assumption | split; [ assumption | ] ] - | .. ] ] - | [ |- _ + _ <= _ + _ ] - => apply Z.add_le_mono - | [ |- _ - _ <= _ - _ ] - => apply Z.sub_le_mono - | [ |- ?x * (?y / ?z) <= (?x * ?y) / ?z ] - => apply Z.div_mul_le - end. -Ltac Zsplit_comparison_fin_step := - match goal with - | _ => assumption - | _ => lia - | _ => progress subst - | [ H : ?n * ?m < 0 |- _ ] - => apply (proj1 (Z.lt_mul_0 n m)) in H; destruct H as [[??]|[??]] - | [ H : ?n / ?m < 0 |- _ ] - => apply (proj1 (Zlt_div_0 n m)) in H; destruct H as [[[??]|[??]]?] - | [ H : (?x^?y) <= ?n < _, H' : ?n < 0 |- _ ] - => assert (0 <= x^y) by zero_bounds; lia - | [ H : (?x^?y) < 0 |- _ ] - => assert (0 <= x^y) by zero_bounds; lia - | [ H : (?x^?y) <= 0 |- _ ] - => let H' := fresh in - assert (H' : 0 <= x^y) by zero_bounds; - assert (x^y = 0) by lia; - clear H H' - | [ H : _^_ = 0 |- _ ] - => apply Z.pow_eq_0_iff in H; destruct H as [?|[??]] - | [ H : 0 <= ?x, H' : ?x - 1 < 0 |- _ ] - => assert (x = 0) by lia; clear H H' - | [ |- ?x <= ?y ] => is_evar x; reflexivity - | [ |- ?x <= ?y ] => is_evar y; reflexivity - end. -Ltac Zsplit_comparison_fin := repeat Zsplit_comparison_fin_step. -Ltac Zsimplify_fractions_step := - match goal with - | _ => rewrite Z.div_mul by (try apply Z.pow_nonzero; zero_bounds) - | [ |- context[?x * ?y / ?x] ] - => rewrite (Z.mul_comm x y) - | [ |- ?x <= ?x ] => reflexivity - end. -Ltac Zsimplify_fractions := repeat Zsimplify_fractions_step. -Ltac Zsimplify_fractions_le := - Zpre_reorder_fractions; - [ repeat Zsplit_comparison; Zsplit_comparison_fin; zero_bounds.. - | Zsimplify_fractions ]. - -Lemma Zlog2_nonneg' n a : n <= 0 -> n <= Z.log2 a. -Proof. - intros; transitivity 0; auto with zarith. -Qed. + [ apply Z_div_le; [ apply Z.gt_lt_iff; assumption | ] + | .. ] + | [ |- ?x / ?y <= ?x / ?z ] + => lazymatch goal with + | [ H : 0 <= x |- _ ] => idtac + | [ H : x < 0 |- _ ] => fail + | _ => destruct (Z_lt_le_dec x 0) + end; + [ .. + | lazymatch goal with + | [ H : 0 < z |- _ ] => idtac + | [ H : z <= 0 |- _ ] => fail + | _ => destruct (Z_lt_le_dec 0 z) + end; + [ apply Z.div_le_compat_l; [ assumption | split; [ assumption | ] ] + | .. ] ] + | [ |- _ + _ <= _ + _ ] + => apply Z.add_le_mono + | [ |- _ - _ <= _ - _ ] + => apply Z.sub_le_mono + | [ |- ?x * (?y / ?z) <= (?x * ?y) / ?z ] + => apply Z.div_mul_le + end. + Ltac split_comparison_fin_step := + match goal with + | _ => assumption + | _ => lia + | _ => progress subst + | [ H : ?n * ?m < 0 |- _ ] + => apply (proj1 (Z.lt_mul_0 n m)) in H; destruct H as [[??]|[??]] + | [ H : ?n / ?m < 0 |- _ ] + => apply (proj1 (lt_div_0 n m)) in H; destruct H as [[[??]|[??]]?] + | [ H : (?x^?y) <= ?n < _, H' : ?n < 0 |- _ ] + => assert (0 <= x^y) by zero_bounds; lia + | [ H : (?x^?y) < 0 |- _ ] + => assert (0 <= x^y) by zero_bounds; lia + | [ H : (?x^?y) <= 0 |- _ ] + => let H' := fresh in + assert (H' : 0 <= x^y) by zero_bounds; + assert (x^y = 0) by lia; + clear H H' + | [ H : _^_ = 0 |- _ ] + => apply Z.pow_eq_0_iff in H; destruct H as [?|[??]] + | [ H : 0 <= ?x, H' : ?x - 1 < 0 |- _ ] + => assert (x = 0) by lia; clear H H' + | [ |- ?x <= ?y ] => is_evar x; reflexivity + | [ |- ?x <= ?y ] => is_evar y; reflexivity + end. + Ltac split_comparison_fin := repeat split_comparison_fin_step. + Ltac simplify_fractions_step := + match goal with + | _ => rewrite Z.div_mul by (try apply Z.pow_nonzero; zero_bounds) + | [ |- context[?x * ?y / ?x] ] + => rewrite (Z.mul_comm x y) + | [ |- ?x <= ?x ] => reflexivity + end. + Ltac simplify_fractions := repeat simplify_fractions_step. + Ltac simplify_fractions_le := + pre_reorder_fractions; + [ repeat split_comparison; split_comparison_fin; zero_bounds.. + | simplify_fractions ]. + + Lemma log2_nonneg' n a : n <= 0 -> n <= Z.log2 a. + Proof. + intros; transitivity 0; auto with zarith. + Qed. -Hint Resolve Zlog2_nonneg' : zarith. + Hint Resolve log2_nonneg' : zarith. -(** We create separate databases for two directions of transformations - involving [Z.log2]; combining them leads to loops. *) -(* for hints that take in hypotheses of type [log2 _], and spit out conclusions of type [_ ^ _] *) -Create HintDb hyp_log2. + (** We create separate databases for two directions of transformations + involving [Z.log2]; combining them leads to loops. *) + (* for hints that take in hypotheses of type [log2 _], and spit out conclusions of type [_ ^ _] *) + Create HintDb hyp_log2. -(* for hints that take in hypotheses of type [_ ^ _], and spit out conclusions of type [log2 _] *) -Create HintDb concl_log2. + (* for hints that take in hypotheses of type [_ ^ _], and spit out conclusions of type [log2 _] *) + Create HintDb concl_log2. -Hint Resolve (fun a b H => proj1 (Z.log2_lt_pow2 a b H)) (fun a b H => proj1 (Z.log2_le_pow2 a b H)) : concl_log2. -Hint Resolve (fun a b H => proj2 (Z.log2_lt_pow2 a b H)) (fun a b H => proj2 (Z.log2_le_pow2 a b H)) : hyp_log2. + Hint Resolve (fun a b H => proj1 (Z.log2_lt_pow2 a b H)) (fun a b H => proj1 (Z.log2_le_pow2 a b H)) : concl_log2. + Hint Resolve (fun a b H => proj2 (Z.log2_lt_pow2 a b H)) (fun a b H => proj2 (Z.log2_le_pow2 a b H)) : hyp_log2. -Lemma Zle_lt_to_log2 x y z : 0 <= z -> 0 < y -> 2^x <= y < 2^z -> x <= Z.log2 y < z. -Proof. - destruct (Z_le_gt_dec 0 x); auto with concl_log2 lia. -Qed. + Lemma le_lt_to_log2 x y z : 0 <= z -> 0 < y -> 2^x <= y < 2^z -> x <= Z.log2 y < z. + Proof. + destruct (Z_le_gt_dec 0 x); auto with concl_log2 lia. + Qed. + + Lemma div_x_y_x x y : 0 < x -> 0 < y -> x / y / x = 1 / y. + Proof. + intros; rewrite Z.div_div, (Z.mul_comm y x), <- Z.div_div, Z.div_same by lia. + reflexivity. + Qed. -Lemma Zdiv_x_y_x x y : 0 < x -> 0 < y -> x / y / x = 1 / y. -Proof. - intros; rewrite Z.div_div, (Z.mul_comm y x), <- Z.div_div, Z.div_same by lia. - reflexivity. -Qed. + Hint Rewrite div_x_y_x using lia : zsimplify. -Hint Rewrite Zdiv_x_y_x using lia : zsimplify. + Lemma mod_opp_l_z_iff a b (H : b <> 0) : a mod b = 0 <-> (-a) mod b = 0. + Proof. + split; intro H'; apply Z.mod_opp_l_z in H'; rewrite ?Z.opp_involutive in H'; assumption. + Qed. -Lemma Zmod_opp_l_z_iff a b (H : b <> 0) : a mod b = 0 <-> (-a) mod b = 0. -Proof. - split; intro H'; apply Z.mod_opp_l_z in H'; rewrite ?Z.opp_involutive in H'; assumption. -Qed. + Lemma opp_eq_0_iff a : -a = 0 <-> a = 0. + Proof. lia. Qed. -Lemma Zopp_eq_0_iff a : -a = 0 <-> a = 0. -Proof. lia. Qed. + Hint Rewrite <- mod_opp_l_z_iff using lia : zsimplify. + Hint Rewrite opp_eq_0_iff : zsimplify. -Hint Rewrite <- Zmod_opp_l_z_iff using lia : zsimplify. -Hint Rewrite Zopp_eq_0_iff : zsimplify. + Lemma sub_pos_bound a b X : 0 <= a < X -> 0 <= b < X -> -X < a - b < X. + Proof. lia. Qed. -Lemma Zsub_pos_bound a b X : 0 <= a < X -> 0 <= b < X -> -X < a - b < X. -Proof. lia. Qed. + Lemma div_opp_l_complete a b (Hb : b <> 0) : -a/b = -(a/b) - (if Z_zerop (a mod b) then 0 else 1). + Proof. + destruct (Z_zerop (a mod b)); autorewrite with zsimplify push_Zopp; reflexivity. + Qed. -Lemma Zdiv_opp_l_complete a b (Hb : b <> 0) : -a/b = -(a/b) - (if Z_zerop (a mod b) then 0 else 1). -Proof. - destruct (Z_zerop (a mod b)); autorewrite with zsimplify push_Zopp; reflexivity. -Qed. + Lemma div_opp_l_complete' a b (Hb : b <> 0) : -(a/b) = -a/b + (if Z_zerop (a mod b) then 0 else 1). + Proof. + destruct (Z_zerop (a mod b)); autorewrite with zsimplify pull_Zopp; lia. + Qed. -Lemma Zdiv_opp_l_complete' a b (Hb : b <> 0) : -(a/b) = -a/b + (if Z_zerop (a mod b) then 0 else 1). -Proof. - destruct (Z_zerop (a mod b)); autorewrite with zsimplify pull_Zopp; lia. -Qed. + Hint Rewrite Z.div_opp_l_complete using lia : pull_Zopp. + Hint Rewrite Z.div_opp_l_complete' using lia : push_Zopp. -Hint Rewrite Zdiv_opp_l_complete using lia : pull_Zopp. -Hint Rewrite Zdiv_opp_l_complete' using lia : push_Zopp. + Lemma div_opp a : a <> 0 -> -a / a = -1. + Proof. + intros; autorewrite with pull_Zopp zsimplify; lia. + Qed. -Lemma Zdiv_opp a : a <> 0 -> -a / a = -1. -Proof. - intros; autorewrite with pull_Zopp zsimplify; lia. -Qed. + Hint Rewrite Z.div_opp using lia : zsimplify. -Hint Rewrite Zdiv_opp using lia : zsimplify. + Lemma div_sub_1_0 x : x > 0 -> (x - 1) / x = 0. + Proof. auto with zarith lia. Qed. -Lemma Zdiv_sub_1_0 x : x > 0 -> (x - 1) / x = 0. -Proof. auto with zarith lia. Qed. + Hint Rewrite div_sub_1_0 using lia : zsimplify. -Hint Rewrite Zdiv_sub_1_0 using lia : zsimplify. + Lemma sub_pos_bound_div a b X : 0 <= a < X -> 0 <= b < X -> -1 <= (a - b) / X <= 0. + Proof. + intros H0 H1; pose proof (Z.sub_pos_bound a b X H0 H1). + assert (Hn : -X <= a - b) by lia. + assert (Hp : a - b <= X - 1) by lia. + split; etransitivity; [ | apply Z_div_le, Hn; lia | apply Z_div_le, Hp; lia | ]; + instantiate; autorewrite with zsimplify; try reflexivity. + Qed. -Lemma Zsub_pos_bound_div a b X : 0 <= a < X -> 0 <= b < X -> -1 <= (a - b) / X <= 0. -Proof. - intros H0 H1; pose proof (Zsub_pos_bound a b X H0 H1). - assert (Hn : -X <= a - b) by lia. - assert (Hp : a - b <= X - 1) by lia. - split; etransitivity; [ | apply Z_div_le, Hn; lia | apply Z_div_le, Hp; lia | ]; - instantiate; autorewrite with zsimplify; try reflexivity. -Qed. + Hint Resolve (fun a b X H0 H1 => proj1 (Z.sub_pos_bound_div a b X H0 H1)) + (fun a b X H0 H1 => proj1 (Z.sub_pos_bound_div a b X H0 H1)) : zarith. -Hint Resolve (fun a b X H0 H1 => proj1 (Zsub_pos_bound_div a b X H0 H1)) - (fun a b X H0 H1 => proj1 (Zsub_pos_bound_div a b X H0 H1)) : zarith. + Lemma sub_pos_bound_div_eq a b X : 0 <= a < X -> 0 <= b < X -> (a - b) / X = if a <? b then -1 else 0. + Proof. + intros H0 H1; pose proof (Z.sub_pos_bound_div a b X H0 H1). + destruct (a <? b) eqn:?; Z.ltb_to_lt. + { cut ((a - b) / X <> 0); [ lia | ]. + autorewrite with zstrip_div; auto with zarith lia. } + { autorewrite with zstrip_div; auto with zarith lia. } + Qed. -Lemma Zsub_pos_bound_div_eq a b X : 0 <= a < X -> 0 <= b < X -> (a - b) / X = if a <? b then -1 else 0. -Proof. - intros H0 H1; pose proof (Zsub_pos_bound_div a b X H0 H1). - destruct (a <? b) eqn:?; Zltb_to_Zlt. - { cut ((a - b) / X <> 0); [ lia | ]. - autorewrite with zstrip_div; auto with zarith lia. } - { autorewrite with zstrip_div; auto with zarith lia. } -Qed. + Lemma add_opp_pos_bound_div_eq a b X : 0 <= a < X -> 0 <= b < X -> (-b + a) / X = if a <? b then -1 else 0. + Proof. + rewrite !(Z.add_comm (-_)), !Z.add_opp_r. + apply Z.sub_pos_bound_div_eq. + Qed. -Lemma Zadd_opp_pos_bound_div_eq a b X : 0 <= a < X -> 0 <= b < X -> (-b + a) / X = if a <? b then -1 else 0. -Proof. - rewrite !(Z.add_comm (-_)), !Z.add_opp_r. - apply Zsub_pos_bound_div_eq. -Qed. + Hint Rewrite Z.sub_pos_bound_div_eq Z.add_opp_pos_bound_div_eq using lia : zstrip_div. -Hint Rewrite Zsub_pos_bound_div_eq Zadd_opp_pos_bound_div_eq using lia : zstrip_div. + Lemma div_small_sym a b : 0 <= a < b -> 0 = a / b. + Proof. intros; symmetry; apply Z.div_small; assumption. Qed. -Lemma Zdiv_small_sym a b : 0 <= a < b -> 0 = a / b. -Proof. intros; symmetry; apply Z.div_small; assumption. Qed. + Lemma mod_small_sym a b : 0 <= a < b -> a = a mod b. + Proof. intros; symmetry; apply Z.mod_small; assumption. Qed. -Lemma Zmod_small_sym a b : 0 <= a < b -> a = a mod b. -Proof. intros; symmetry; apply Z.mod_small; assumption. Qed. + Hint Resolve div_small_sym mod_small_sym : zarith. -Hint Resolve Zdiv_small_sym Zmod_small_sym : zarith. + Lemma div_add' a b c : c <> 0 -> (a + c * b) / c = a / c + b. + Proof. intro; rewrite <- Z.div_add, (Z.mul_comm c); try lia. Qed. -Lemma Zdiv_add' a b c : c <> 0 -> (a + c * b) / c = a / c + b. -Proof. intro; rewrite <- Z.div_add, (Z.mul_comm c); try lia. Qed. + Lemma div_add_l' a b c : b <> 0 -> (b * a + c) / b = a + c / b. + Proof. intro; rewrite <- Z.div_add_l, (Z.mul_comm b); lia. Qed. -Lemma Zdiv_add_l' a b c : b <> 0 -> (b * a + c) / b = a + c / b. -Proof. intro; rewrite <- Z.div_add_l, (Z.mul_comm b); lia. Qed. + Hint Rewrite div_add_l' div_add' using lia : zsimplify. -Hint Rewrite Zdiv_add_l' Zdiv_add' using lia : zsimplify. + Lemma div_add_sub_l a b c d : b <> 0 -> (a * b + c - d) / b = a + (c - d) / b. + Proof. rewrite <- Z.add_sub_assoc; apply Z.div_add_l. Qed. -Lemma Zdiv_add_sub_l a b c d : b <> 0 -> (a * b + c - d) / b = a + (c - d) / b. -Proof. rewrite <- Z.add_sub_assoc; apply Z.div_add_l. Qed. + Lemma div_add_sub_l' a b c d : b <> 0 -> (b * a + c - d) / b = a + (c - d) / b. + Proof. rewrite <- Z.add_sub_assoc; apply Z.div_add_l'. Qed. -Lemma Zdiv_add_sub_l' a b c d : b <> 0 -> (b * a + c - d) / b = a + (c - d) / b. -Proof. rewrite <- Z.add_sub_assoc; apply Zdiv_add_l'. Qed. + Lemma div_add_sub a b c d : c <> 0 -> (a + b * c - d) / c = (a - d) / c + b. + Proof. rewrite (Z.add_comm _ (_ * _)), (Z.add_comm (_ / _)); apply Z.div_add_sub_l. Qed. -Lemma Zdiv_add_sub a b c d : c <> 0 -> (a + b * c - d) / c = (a - d) / c + b. -Proof. rewrite (Z.add_comm _ (_ * _)), (Z.add_comm (_ / _)); apply Zdiv_add_sub_l. Qed. + Lemma div_add_sub' a b c d : c <> 0 -> (a + c * b - d) / c = (a - d) / c + b. + Proof. rewrite (Z.add_comm _ (_ * _)), (Z.add_comm (_ / _)); apply Z.div_add_sub_l'. Qed. -Lemma Zdiv_add_sub' a b c d : c <> 0 -> (a + c * b - d) / c = (a - d) / c + b. -Proof. rewrite (Z.add_comm _ (_ * _)), (Z.add_comm (_ / _)); apply Zdiv_add_sub_l'. Qed. + Hint Rewrite Z.div_add_sub Z.div_add_sub' Z.div_add_sub_l Z.div_add_sub_l' using lia : zsimplify. -Hint Rewrite Zdiv_add_sub Zdiv_add_sub' Zdiv_add_sub_l Zdiv_add_sub_l' using lia : zsimplify. + Lemma div_mul_skip a b k : 0 < b -> 0 < k -> a * b / k / b = a / k. + Proof. + intros; rewrite Z.div_div, (Z.mul_comm k), <- Z.div_div by lia. + autorewrite with zsimplify; reflexivity. + Qed. -Lemma Zdiv_mul_skip a b k : 0 < b -> 0 < k -> a * b / k / b = a / k. -Proof. - intros; rewrite Z.div_div, (Z.mul_comm k), <- Z.div_div by lia. - autorewrite with zsimplify; reflexivity. -Qed. + Lemma div_mul_skip' a b k : 0 < b -> 0 < k -> b * a / k / b = a / k. + Proof. + intros; rewrite Z.div_div, (Z.mul_comm k), <- Z.div_div by lia. + autorewrite with zsimplify; reflexivity. + Qed. -Lemma Zdiv_mul_skip' a b k : 0 < b -> 0 < k -> b * a / k / b = a / k. -Proof. - intros; rewrite Z.div_div, (Z.mul_comm k), <- Z.div_div by lia. - autorewrite with zsimplify; reflexivity. -Qed. + Hint Rewrite Z.div_mul_skip Z.div_mul_skip' using lia : zsimplify. +End Z. -Hint Rewrite Zdiv_mul_skip Zdiv_mul_skip' using lia : zsimplify. +Module Export BoundsTactics. + Ltac prime_bound := Z.prime_bound. + Ltac zero_bounds := Z.zero_bounds. +End BoundsTactics. diff --git a/src/WeierstrassCurve/Pre.v b/src/WeierstrassCurve/Pre.v new file mode 100644 index 000000000..b140e95b5 --- /dev/null +++ b/src/WeierstrassCurve/Pre.v @@ -0,0 +1,56 @@ +Require Import Coq.Classes.Morphisms. Require Coq.Setoids.Setoid. +Require Import Crypto.Algebra. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.Notations. + +Local Open Scope core_scope. + +Generalizable All Variables. +Section Pre. + Context {F eq zero one opp add sub mul inv div} `{field F eq zero one opp add sub mul inv div}. + Local Infix "=" := eq. Local Notation "a <> b" := (not (a = b)). + Local Infix "=" := eq : type_scope. Local Notation "a <> b" := (not (a = b)) : type_scope. + Local Notation "0" := zero. Local Notation "1" := one. + Local Infix "+" := add. Local Infix "*" := mul. + Local Infix "-" := sub. Local Infix "/" := div. + Local Notation "- x" := (opp x). + Local Notation "x ^ 2" := (x*x). Local Notation "x ^ 3" := (x*x^2). + Local Notation "'∞'" := unit : type_scope. + Local Notation "'∞'" := (inr tt) : core_scope. + Local Notation "2" := (1+1). Local Notation "3" := (1+2). + Local Notation "( x , y )" := (inl (pair x y)). + + Add Field WeierstrassCurveField : (Field.field_theory_for_stdlib_tactic (T:=F)). + Add Ring WeierstrassCurveRing : (Ring.ring_theory_for_stdlib_tactic (T:=F)). + + Context {a:F}. + Context {b:F}. + + (* the canonical definitions are in Spec *) + Definition onCurve (P:F*F + ∞) := match P with + | (x, y) => y^2 = x^3 + a*x + b + | ∞ => True + end. + Definition unifiedAdd' (P1' P2':F*F + ∞) : F*F + ∞ := + match P1', P2' with + | (x1, y1), (x2, y2) + => if x1 =? x2 then + if y2 =? -y1 then + ∞ + else ((3*x1^2+a)^2 / (2*y1)^2 - x1 - x1, + (2*x1+x1)*(3*x1^2+a) / (2*y1) - (3*x1^2+a)^3/(2*y1)^3-y1) + else + ((y2-y1)^2 / (x2-x1)^2 - x1 - x2, + (2*x1+x2)*(y2-y1) / (x2-x1) - (y2-y1)^3 / (x2-x1)^3 - y1) + | ∞, ∞ => ∞ + | ∞, _ => P2' + | _, ∞ => P1' + end. + + Lemma unifiedAdd'_onCurve : forall P1 P2, + onCurve P1 -> onCurve P2 -> onCurve (unifiedAdd' P1 P2). + Proof. + unfold onCurve, unifiedAdd'; intros [[x1 y1]|] [[x2 y2]|] H1 H2; + break_match; trivial; setoid_subst_rel eq; only_two_square_roots; super_nsatz. + Qed. +End Pre. |