From 0697253755e9816f5599fbf77c04b5f3db795e16 Mon Sep 17 00:00:00 2001 From: jadep Date: Sun, 14 May 2017 15:53:46 -0400 Subject: make freeze use the correct versions of add_get_carry and zselect --- src/Arithmetic/Core.v | 78 ++++++++++++++++++--------------------------------- 1 file changed, 28 insertions(+), 50 deletions(-) (limited to 'src/Arithmetic/Core.v') diff --git a/src/Arithmetic/Core.v b/src/Arithmetic/Core.v index 06d0409e7..f294c06fc 100644 --- a/src/Arithmetic/Core.v +++ b/src/Arithmetic/Core.v @@ -248,6 +248,7 @@ Require Import Crypto.Algebra.Nsatz. Require Import Crypto.Util.Decidable Crypto.Util.LetIn. Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma. Require Import Crypto.Util.CPSUtil Crypto.Util.Prod. +Require Import Crypto.Util.ZUtil.Zselect. Require Import Crypto.Arithmetic.PrimeFieldTheorems. Require Import Crypto.Util.Tactics.BreakMatch. Require Import Crypto.Util.Tactics.UniquePose. @@ -833,39 +834,43 @@ Module B. End EvalHelpers. Section Select. - Context {weight : nat -> Z} - {select_single : Z -> Z -> Z} - {select_single_correct : forall cond x, - select_single cond x = if dec (cond = 0) then 0 else x} - . + Context {weight : nat -> Z}. - Definition select_cps {n} cond (p : tuple Z n) {T} (f:_->T) := - Tuple.map_cps (select_single cond) p f. + Definition select_cps {n} (mask cond:Z) (p:tuple Z n) + {T} (f:tuple Z n->T) := + dlet t := Z.zselect cond 0 mask in Tuple.map_cps (runtime_and t) p f. - Definition select {n} cond p := @select_cps n cond p _ id. - Lemma select_id {n} cond p T f : - @select_cps n cond p T f = f (select cond p). + Definition select {n} mask cond p := @select_cps n mask cond p _ id. + Lemma select_id {n} mask cond p T f : + @select_cps n mask cond p T f = f (select mask cond p). Proof. - cbv [select_cps select]. autorewrite with uncps push_id. - reflexivity. + cbv [select select_cps Let_In]; autorewrite with uncps push_id; + reflexivity. Qed. Hint Opaque select : uncps. - Hint Rewrite @select_id : uncps. - Lemma eval_select {n} cond p : - eval weight (@select n cond p) = if dec (cond = 0) then 0 else eval weight p. + Lemma map_and_0 {n} (p:tuple Z n) : Tuple.map (Z.land 0) p = zeros n. Proof. - cbv [select select_cps]; autorewrite with uncps push_id. - induction n; [destruct p|]. - { break_match; reflexivity. } - { rewrite (Tuple.subst_left_append p). - rewrite Tuple.map_left_append, !eval_left_append. - rewrite select_single_correct, IHn. - break_match; ring. } - Qed. Hint Rewrite @eval_select : push_basesystem_eval. + induction n; [destruct p; reflexivity | ]. + rewrite (Tuple.subst_append p), Tuple.map_append, Z.land_0_l, IHn. + reflexivity. + Qed. + + Lemma eval_select {n} mask cond x (H:Tuple.map (Z.land mask) x = x) : + B.Positional.eval weight (@select n mask cond x) = + if dec (cond = 0) then 0 else B.Positional.eval weight x. + Proof. + cbv [select select_cps Let_In]. + autorewrite with uncps push_id. + rewrite Z.zselect_correct; break_match. + { rewrite map_and_0. apply B.Positional.eval_zeros. } + { change runtime_and with Z.land. rewrite H; reflexivity. } + Qed. + End Select. End Positional. + Hint Unfold Positional.add_cps Positional.mul_cps @@ -942,33 +947,6 @@ Section DivMod. Qed. End DivMod. -Section ZSelect. - - Definition mask width cond := - if dec (cond = 0) then 0 else Z.ones width. - - Definition zselect bitwidth (cond x : Z) : Z := - if (dec (x <= 0)) - then (if dec (cond = 0) then 0 else x) - else (let width := Z.max (Z.log2 x + 1) bitwidth in - dlet t := mask width cond in x &' t). - - Lemma zselect_correct bw cond x : - zselect bw cond x = if dec (cond = 0) then 0 else x. - Proof. - cbv [zselect mask Let_In]; break_match; - rewrite ?Z.land_0_r; try reflexivity; [ ]. - pose proof (Z.log2_nonneg x). - pose proof (Z.log2_spec x) as Hlog2. - rewrite <-Z.add_1_r in Hlog2. - apply Z.max_case_strong; intros; rewrite Z.land_ones by omega; - apply Z.mod_small; split; try omega. - apply Z.lt_le_trans with (m:=2 ^ (Z.log2 x + 1)); - [|apply Z.pow_le_mono_r]; omega. - Qed. - -End ZSelect. - Import B. Ltac basesystem_partial_evaluation_RHS := -- cgit v1.2.3