From de3ec0210ea1d40e2e796591c9a192711e79a03f Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 2 May 2018 09:55:42 +0200 Subject: Move straightline and prefancy stuff above barrett reduction --- src/Experiments/SimplyTypedArithmetic.v | 1136 +++++++++++++++---------------- 1 file changed, 567 insertions(+), 569 deletions(-) (limited to 'src/Experiments/SimplyTypedArithmetic.v') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 9618588d4..a1fec63ff 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -7917,231 +7917,600 @@ fun var : type -> Type => End X25519_32. *) -Module BarrettReduction. - (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) - Section Generic. - Context {T} (rep : T -> Z -> Prop) - (k : Z) (k_pos : 0 < k) - (low : T -> Z) - (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k) - (shiftr : T -> Z -> T) - (shiftr_correct : forall a x n, - rep a x -> - 0 <= n <= k -> - rep (shiftr a n) (x / 2 ^ n)) - (mul_high : T -> T -> Z -> T) - (mul_high_correct : forall a b x y x0y1, - rep a x -> - rep b y -> - 2 ^ k <= x < 2^(k+1) -> - 0 <= y < 2^(k+1) -> - x0y1 = x mod 2 ^ k * (y / 2 ^ k) -> - rep (mul_high a b x0y1) (x * y / 2 ^ k)) - (mul : Z -> Z -> T) - (mul_correct : forall x y, - 0 <= x < 2^k -> - 0 <= y < 2^k -> - rep (mul x y) (x * y)) - (sub : T -> T -> T) - (sub_correct : forall a b x y, - rep a x -> - rep b y -> - 0 <= x - y < 2^k * 2^k -> - rep (sub a b) (x - y)) - (cond_sub1 : T -> Z -> Z) - (cond_sub1_correct : forall a x y, - rep a x -> - 0 <= x < 2 * y -> - 0 <= y < 2 ^ k -> - cond_sub1 a y = if (x Z -> Z) - (cond_sub2_correct : forall x y, cond_sub2 x y = if (x Type}. + Context {dummy_var : forall t, var t}. - Definition qt := - dlet_nd muSelect := muSelect in (* makes sure muSelect is not inlined in the output *) - dlet_nd q1 := shiftr xt (k - 1) in - dlet_nd twoq := mul_high mut q1 muSelect in - shiftr twoq 1. - Definition reduce := - dlet_nd r2 := mul (low qt) M in - dlet_nd r := sub xt r2 in - dlet_nd q3 := cond_sub1 r M in - cond_sub2 q3 M. + Let uexpr t := @Uncurried.expr.expr ident.ident var t. - Lemma looser_bound : M * 2 ^ k < 2 ^ (2*k). - Proof. clear -M_range M_nz x_range k_pos; rewrite <-Z.add_diag, Z.pow_add_r; nia. Qed. + Section with_ident. + Context {ident : type.type -> type.type -> Type}. + Inductive expr : type.type -> Type := + | Scalar {t} : scalar t -> expr t + | LetInAppIdentZ {s d} : zrange -> ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d + | LetInAppIdentZZ {s d} : zrange * zrange -> ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d + with scalar : type.type -> Type := + | Var t : var t -> scalar t + | TT : scalar (type.type_primitive type.unit) + | Pair {a b} : scalar a -> scalar b -> scalar (a * b) + | Cast : zrange -> scalar type.Z -> scalar type.Z + | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) + | Fst {a b} : scalar (a * b) -> scalar a + | Snd {a b} : scalar (a * b) -> scalar b + | Shiftr : Z -> scalar type.Z -> scalar type.Z + | Shiftl : Z -> scalar type.Z -> scalar type.Z + | Land : Z -> scalar type.Z -> scalar type.Z + | Primitive {t} : type.interp (type.type_primitive t) -> scalar t + . + End with_ident. - Lemma pow_2k_eq : 2 ^ (2*k) = 2 ^ (k - 1) * 2 ^ (k + 1). - Proof. clear -k_pos; rewrite <-Z.pow_add_r by omega. f_equal; ring. Qed. + Fixpoint dummy t : @expr ident.ident t := Scalar (Var t (dummy_var t)). - Lemma mu_bounds : 2 ^ k <= mu < 2^(k+1). - Proof. - pose proof looser_bound. - subst mu. split. - { apply Z.div_le_lower_bound; omega. } - { apply Z.div_lt_upper_bound; try omega. - rewrite pow_2k_eq; apply Z.mul_lt_mono_pos_r; auto with zarith. } - Qed. + Definition scalar_of_uncurried_ident {s d} (idc : ident.ident s d) + : scalar s -> option (scalar d) := + match idc in ident.ident s d return scalar (ident:=ident.ident) s -> option (scalar d) with + | ident.Z.cast r => fun args => Some (Cast r args) + | ident.Z.cast2 r => fun args => Some (Cast2 r args) + | @ident.fst A B => fun args => Some (Fst args) + | @ident.snd A B => fun args => Some (Snd args) + | ident.Z.shiftr n => fun args => Some (Shiftr n args) + | ident.Z.shiftl n => fun args => Some (Shiftl n args) + | ident.Z.land n => fun args => Some (Land n args) + | @ident.primitive p x => fun _ => Some (Primitive x) + | _ => fun _ => None + end. - Lemma shiftr_x_bounds : 0 <= x / 2 ^ (k - 1) < 2^(k+1). - Proof. - pose proof looser_bound. - split; [ solve [Z.zero_bounds] | ]. - apply Z.div_lt_upper_bound; auto with zarith. - rewrite <-pow_2k_eq. omega. - Qed. - Hint Resolve shiftr_x_bounds. + Fixpoint scalar_of_uncurried {t} (e : uexpr t) : option (scalar t) := + match e in Uncurried.expr.expr t return option (scalar t) with + | expr.Var t v as e => Some (Var t v) + | expr.TT as e => Some TT + | expr.Pair A B a b + => match scalar_of_uncurried a, scalar_of_uncurried b with + | Some x, Some y => Some (Pair x y) + | _, _ => None + end + | expr.AppIdent _ _ idc args + => match scalar_of_uncurried args with + | Some x => scalar_of_uncurried_ident idc x + | None => None + end + | _ => None + end. - Ltac solve_rep := eauto using shiftr_correct, mul_high_correct, mul_correct, sub_correct with omega. + Fixpoint range_type t : Type := + match t with + | type.type_primitive type.Z => zrange + | type.prod x y => range_type x * range_type y + | _ => unit + end. - Let q := mu * (x / 2 ^ (k - 1)) / 2 ^ (k + 1). + Definition invert_cast {t} (e : uexpr t) + : option (range_type t * uexpr t) := + match invert_AppIdent e with + | Some (existT s (idc, x)) => + (match idc in ident.ident s t return uexpr s -> option (range_type t * uexpr t) with + | ident.Z.cast r => fun x => Some (r, x) + | ident.Z.cast2 r => fun x => Some (r, x) + | _ => fun _ => None + end) x + | None => None + end. - Lemma q_correct : rep qt q . - Proof. - pose proof mu_bounds. cbv [qt]; subst q. - rewrite Z.pow_add_r, <-Z.div_div by Z.zero_bounds. - solve_rep. - Qed. - Hint Resolve q_correct. + (* ident.Let_In @@ (cast r x) => r, x *) + Definition invert_LetInCast {tx tC} (args : uexpr (tx * (tx -> tC))) + : option (range_type tx * uexpr tx * uexpr (tx -> tC)) := + match invert_Pair args with + | Some (x, e) => + match invert_cast x with + | Some (r, x') => Some (r, x', e) + | None => None + end + | None => None + end. - Lemma x_mod_small : x mod 2 ^ (k - 1) <= M. - Proof. transitivity (2 ^ (k - 1)); auto with zarith. Qed. - Hint Resolve x_mod_small. + Definition invert_LetInAppIdent {tx tC} (args : uexpr (tx * (tx -> tC))) + : option { s : type.type & (range_type tx * ident.ident s tx * scalar s * (var tx -> uexpr tC))%type } := + match invert_LetInCast args with + | Some (r, x, e) => + match invert_AppIdent x with + | Some (existT s idc_x') => + match scalar_of_uncurried (snd idc_x') with + | Some x'' => + match invert_Abs e with + | Some k => Some (existT _ s (r, fst idc_x', x'', k)) + | None => None + end + | None => None + end + | None => None + end + | None => None + end. + + Definition mk_LetInAppIdent {s d t} (default : expr t) + : range_type d -> ident.ident s d -> scalar s -> (var d -> expr t) -> expr t := + match d as d0 return range_type d0 -> ident.ident s d0 -> scalar s -> (var d0 -> expr t) -> expr t with + | type.type_primitive type.Z => + fun r idc x k => @LetInAppIdentZ ident.ident s t r idc x k + | type.prod type.Z type.Z => + fun r idc x k => @LetInAppIdentZZ ident.ident s t r idc x k + | _ => fun _ _ _ _ => default + end. - Lemma q_bounds : 0 <= q < 2 ^ k. - Proof. - pose proof looser_bound. pose proof x_mod_small. pose proof mu_bounds. - split; subst q; [ solve [Z.zero_bounds] | ]. - edestruct q_nice_strong with (n:=M) as [? Hqnice]; - try rewrite Hqnice; auto; try omega; [ ]. - apply Z.le_lt_trans with (m:= x / M). - { break_match; omega. } - { apply Z.div_lt_upper_bound; omega. } - Qed. + Definition of_uncurried_step {t} (e : uexpr t) + (of_uncurried : forall {t}, uexpr t -> expr t) + : expr t -> expr t := + match e in Uncurried.expr.expr t return expr t -> expr t with + | AppIdent s d idc args => + (match idc in ident.ident s d return uexpr s -> expr d -> expr d with + | ident.Let_In tx tC => + fun args default => + match invert_LetInAppIdent args return expr tC with + | Some (existT s (r, idc, x, k)) => + @mk_LetInAppIdent s tx tC default r idc x (fun y : var tx => of_uncurried (k y)) + | None => default + end + | ident.Z.cast r => + fun (args : uexpr _) default => + match invert_AppIdent args with + | Some (existT s idc_x') => + match scalar_of_uncurried (snd idc_x') with + | Some x'' => + @mk_LetInAppIdent s type.Z type.Z default r (fst idc_x') x'' (fun y => Scalar (Var _ y)) + | None => default + end + | None => default + end + | ident.Z.cast2 r => + fun (args : uexpr _) default => + match invert_AppIdent args with + | Some (existT s idc_x') => + match scalar_of_uncurried (snd idc_x') with + | Some x'' => + @mk_LetInAppIdent s (type.Z*type.Z) (type.Z*type.Z) default r (fst idc_x') x'' (fun y => Scalar (Var _ y)) + | None => default + end + | None => default + end + | _ => fun _ default => default + end) args + | _ as e => + (fun default => + match scalar_of_uncurried e with + | Some s => Scalar s + | None => default + end) + end. - Lemma two_conditional_subtracts : - forall a x, - rep a x -> - 0 <= x < 2 * M -> - cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M. - Proof. - intros. - erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). - break_match; Z.ltb_to_lt; try lia; discriminate. - Qed. - - Lemma r_bounds : 0 <= x - q * M < 2 * M. - Proof. - pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small. - subst q mu; split. - { Z.zero_bounds. apply qn_small; omega. } - { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } - Qed. + (* TODO : uses fuel; ideally want a cleaner termination proof *) + Fixpoint of_uncurried (fuel : nat) {t} (e : uexpr t) + : expr t := + match fuel with + | S fuel' => of_uncurried_step e (@of_uncurried fuel') (dummy t) + | O => dummy t + end. - Lemma reduce_correct : reduce = x mod M. - Proof. - pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. - assert (2 * M < 2^k * 2^k) by nia. - rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). - cbv [reduce Let_In]. - erewrite low_correct by eauto. Z.rewrite_mod_small. - erewrite two_conditional_subtracts by solve_rep. - rewrite !cond_sub2_correct. - subst q; reflexivity. - Qed. - End Generic. + End with_var. + End expr. - Section BarrettReduction. - Context (k : Z) (k_bound : 2 <= k). - Context (M muLow : Z). - Context (M_pos : 0 < M) - (muLow_eq : muLow + 2^k = 2^(2*k) / M) - (muLow_bounds : 0 <= muLow < 2^k) - (M_bound1 : 2 ^ (k - 1) < M < 2^k) - (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). + Fixpoint depth {var t} (dummy_var : forall t, var t) (e : @Uncurried.expr.expr ident.ident var t) : nat := + match e with + | Uncurried.expr.Var _ _ => O + | Uncurried.expr.TT => O + | Uncurried.expr.AppIdent _ _ idc args => S (depth dummy_var args) + | Uncurried.expr.App _ _ f x => S (Nat.max (depth dummy_var f) (depth dummy_var x)) + | Uncurried.expr.Pair _ _ x y => S (Nat.max (depth dummy_var x) (depth dummy_var y)) + | Uncurried.expr.Abs _ _ f => S (depth dummy_var (f (dummy_var _))) + end. - Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). - Context (nout : nat) (Hnout : nout = 2%nat). - Let w := weight k 1. - Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed. - Let props : @weight_properties w := wprops k 1 k_range. + (* TODO : Can I avoid having dummy_var appear here? *) + Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_var : expr.expr d + := + match invert_Abs (e var) with + | Some f => expr.of_uncurried (dummy_var:=dummy_var) (depth dummy_var (f x)) (f x) + | None => expr.dummy (dummy_var:=dummy_var) d + end. - Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. +End Straightline. - Definition low (t : list Z) : Z := nth_default 0 t 0. - Definition high (t : list Z) : Z := nth_default 0 t 1. - Definition represents (t : list Z) (x : Z) := - t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k. +Module StraightlineTest. + Definition test : Expr (type.Z -> type.Z) := + fun var => + Abs + (fun (x : var type.Z) => + AppIdent (var:=var) ident.Let_In + (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) + (Abs (fun x : var type.Z => expr.Var x)))). - Lemma represents_eq t x : - represents t x -> t = [x mod 2^k; x / 2^k]. - Proof. cbv [represents]; tauto. Qed. + Eval vm_compute in (Straightline.of_Expr test). - Lemma represents_length t x : represents t x -> length t = 2%nat. - Proof. cbv [represents]; intuition. subst t; reflexivity. Qed. + Definition test_mul : Expr (type.Z -> type.Z) := + fun var => + Abs + (fun (x : var type.Z) => + AppIdent (var:=var) ident.Let_In + (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) + (Abs (fun y : var type.Z => + AppIdent ident.Let_In + (Pair (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent ident.Z.mul (Pair (AppIdent (@ident.primitive type.Z 12) TT) (Var y)))) + (Abs (fun z : var type.Z => (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (ident.Z.shiftr 3) (Var z))))) + ))))). + Eval vm_compute in (Straightline.of_Expr test_mul). +End StraightlineTest. - Lemma represents_low t x : - represents t x -> low t = x mod 2^k. - Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. +(* Convert straightline code to code that uses only a certain set of identifiers *) +Module PreFancy. + Section with_var. + Import Straightline.expr. + Context {var : type -> Type} (dummy_var : forall t, var t) (log2wordsize : Z) + (constant_to_scalar : forall ident, Z -> option (@scalar var ident type.Z)). + Local Notation Z := (type.type_primitive type.Z). + Let wordsize := 2 ^ log2wordsize. + Let half_bits := log2wordsize / 2. + Let half_wordsize := 2 ^ half_bits. - Lemma represents_high t x : - represents t x -> high t = x / 2^k. - Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + Inductive ident : type -> type -> Type := + | add : ident (Z * Z) (Z * Z) + | addc : ident (Z * Z * Z) (Z * Z) + | mulll : ident (Z * Z) Z + | mullh : ident (Z * Z) Z + | mulhl : ident (Z * Z) Z + | mulhh : ident (Z * Z) Z + | sub : ident (Z * Z) (Z * Z) + | shiftr : BinInt.Z -> ident Z Z + | shiftl : BinInt.Z -> ident Z Z + | sel : ident (Z * Z * Z) Z + | addm : ident (Z * Z * Z) Z + . + Let dummy t : @expr var ident t := Scalar (Var _ (dummy_var t)). - Lemma represents_low_range t x : - represents t x -> 0 <= x mod 2^k < 2^k. - Proof. auto with zarith. Qed. + Definition invert_lower' {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Cast r (Land n x) => + if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? 2^half_bits-1) + then Some x + else None + | _ => None + end. - Lemma represents_high_range t x : - represents t x -> 0 <= x / 2^k < 2^k. - Proof. - destruct 1 as [? [? ?] ]; intros. - auto using Z.div_lt_upper_bound with zarith. - Qed. - Hint Resolve represents_length represents_low_range represents_high_range. + Definition invert_upper' {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Cast r (Shiftr n x) => + if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? half_bits) + then Some x + else None + | _ => None + end. - Lemma represents_range t x : - represents t x -> 0 <= x < 2^k*2^k. - Proof. cbv [represents]; tauto. Qed. + Definition invert_lower {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Primitive type.Z x => + match constant_to_scalar ident x with + | Some y => invert_lower' y + | None => None + end + | _ => invert_lower' e + end. - Lemma represents_id x : - 0 <= x < 2^k * 2^k -> - represents [x mod 2^k; x / 2^k] x. - Proof. - intros; cbv [represents]; autorewrite with cancel_pair. - Z.rewrite_mod_small; tauto. - Qed. + Definition invert_upper {t} (e : @scalar var ident t) : + option (@scalar var ident Z) := + match e in scalar t return option (@scalar var ident Z) with + | Primitive type.Z x => + match constant_to_scalar ident x with + | Some y => invert_upper' y + | None => None + end + | _ => invert_upper' e + end. - Local Ltac push_rep := - repeat match goal with - | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H) - | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H) - | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption - | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption - end. - Definition shiftr (t : list Z) (n : Z) : list Z := - [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n]. + Definition of_straightline_ident {s d} (idc : ident.ident s d) + : forall t, range_type d -> @scalar var ident s -> (var d -> @expr var ident t) -> @expr var ident t := + match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with + | ident.Z.add_get_carry_concrete w => + fun t r x f => + if w =? wordsize + then LetInAppIdentZZ r add x f + else dummy _ + | ident.Z.add_with_get_carry_concrete w => + fun t r x f => + if w =? wordsize + then LetInAppIdentZZ r addc x f + else dummy _ + | ident.Z.sub_get_borrow_concrete w => + fun t r x f => + if w =? wordsize + then LetInAppIdentZZ r sub x f + else dummy _ + | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n) + | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n) + | ident.Z.zselect => fun _ r => LetInAppIdentZ r sel + | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm + | ident.Z.mul => + fun t r x f => + match x return expr t with + | Pair _ _ x0 x1 => + match invert_lower x0, invert_lower x1 with + | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f + | Some y0, None => + match invert_upper x1 with + | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f + | None => dummy _ + end + | None, Some y1 => + match invert_upper x0 with + | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f + | None => dummy _ + end + | None, None => + match invert_upper x0, invert_upper x1 with + | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f + | _,_ => dummy _ + end + end + | _ => dummy _ + end + | _ => fun t _ _ _ => dummy t + end. - Lemma shiftr_represents a i x : - represents a x -> - 0 <= i <= k -> - represents (shiftr a i) (x / 2 ^ i). - Proof. - cbv [shiftr]; intros; push_rep. - match goal with H : _ |- _ => pose proof (represents_range _ _ H) end. - assert (0 < 2 ^ i) by auto with zarith. - assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia. - assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by + Fixpoint of_straightline_scalar {t} (s : @scalar var ident.ident t) + : @scalar var ident t := + match s with + | Var _ v => Var _ v + | TT => TT + | Pair _ _ x y => Pair (of_straightline_scalar x) (of_straightline_scalar y) + | Cast r x => Cast r (of_straightline_scalar x) + | Cast2 r x => Cast2 r (of_straightline_scalar x) + | Fst _ _ x => Fst (of_straightline_scalar x) + | Snd _ _ x => Snd (of_straightline_scalar x) + | Shiftr n x => Shiftr n (of_straightline_scalar x) + | Shiftl n x => Shiftl n (of_straightline_scalar x) + | Land n x => Land n (of_straightline_scalar x) + | Primitive _ x => Primitive x + end. + + Fixpoint of_straightline {t} (e : @expr var ident.ident t) + : @expr var ident t := + match e with + | Scalar _ s => Scalar (of_straightline_scalar s) + | LetInAppIdentZ _ t r idc x f => + of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) + | LetInAppIdentZZ _ t r idc x f => + of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) + end. + End with_var. +End PreFancy. + +Module BarrettReduction. + (* TODO : generalize to multi-word and operate on (list Z) instead of T; maybe stop taking ops as context variables *) + Section Generic. + Context {T} (rep : T -> Z -> Prop) + (k : Z) (k_pos : 0 < k) + (low : T -> Z) + (low_correct : forall a x, rep a x -> low a = x mod 2 ^ k) + (shiftr : T -> Z -> T) + (shiftr_correct : forall a x n, + rep a x -> + 0 <= n <= k -> + rep (shiftr a n) (x / 2 ^ n)) + (mul_high : T -> T -> Z -> T) + (mul_high_correct : forall a b x y x0y1, + rep a x -> + rep b y -> + 2 ^ k <= x < 2^(k+1) -> + 0 <= y < 2^(k+1) -> + x0y1 = x mod 2 ^ k * (y / 2 ^ k) -> + rep (mul_high a b x0y1) (x * y / 2 ^ k)) + (mul : Z -> Z -> T) + (mul_correct : forall x y, + 0 <= x < 2^k -> + 0 <= y < 2^k -> + rep (mul x y) (x * y)) + (sub : T -> T -> T) + (sub_correct : forall a b x y, + rep a x -> + rep b y -> + 0 <= x - y < 2^k * 2^k -> + rep (sub a b) (x - y)) + (cond_sub1 : T -> Z -> Z) + (cond_sub1_correct : forall a x y, + rep a x -> + 0 <= x < 2 * y -> + 0 <= y < 2 ^ k -> + cond_sub1 a y = if (x Z -> Z) + (cond_sub2_correct : forall x y, cond_sub2 x y = if (x + 0 <= x < 2 * M -> + cond_sub2 (cond_sub1 a M) M = cond_sub2 (cond_sub2 x M) M. + Proof. + intros. + erewrite !cond_sub2_correct, !cond_sub1_correct by (eassumption || omega). + break_match; Z.ltb_to_lt; try lia; discriminate. + Qed. + + Lemma r_bounds : 0 <= x - q * M < 2 * M. + Proof. + pose proof looser_bound. pose proof q_bounds. pose proof x_mod_small. + subst q mu; split. + { Z.zero_bounds. apply qn_small; omega. } + { apply r_small_strong; rewrite ?Z.pow_1_r; auto; omega. } + Qed. + + Lemma reduce_correct : reduce = x mod M. + Proof. + pose proof looser_bound. pose proof r_bounds. pose proof q_bounds. + assert (2 * M < 2^k * 2^k) by nia. + rewrite barrett_reduction_small with (k:=k) (m:=mu) (offset:=1) (b:=2) by (auto; omega). + cbv [reduce Let_In]. + erewrite low_correct by eauto. Z.rewrite_mod_small. + erewrite two_conditional_subtracts by solve_rep. + rewrite !cond_sub2_correct. + subst q; reflexivity. + Qed. + End Generic. + + Section BarrettReduction. + Context (k : Z) (k_bound : 2 <= k). + Context (M muLow : Z). + Context (M_pos : 0 < M) + (muLow_eq : muLow + 2^k = 2^(2*k) / M) + (muLow_bounds : 0 <= muLow < 2^k) + (M_bound1 : 2 ^ (k - 1) < M < 2^k) + (M_bound2: 2 * (2 ^ (2 * k) mod M) <= 2 ^ (k + 1) - (muLow + 2^k)). + + Context (n:nat) (Hn_nz: n <> 0%nat) (n_le_k : Z.of_nat n <= k). + Context (nout : nat) (Hnout : nout = 2%nat). + Let w := weight k 1. + Local Lemma k_range : 0 < 1 <= k. Proof. omega. Qed. + Let props : @weight_properties w := wprops k 1 k_range. + + Hint Rewrite Positional.eval_nil Positional.eval_snoc : push_eval. + + Definition low (t : list Z) : Z := nth_default 0 t 0. + Definition high (t : list Z) : Z := nth_default 0 t 1. + Definition represents (t : list Z) (x : Z) := + t = [x mod 2^k; x / 2^k] /\ 0 <= x < 2^k * 2^k. + + Lemma represents_eq t x : + represents t x -> t = [x mod 2^k; x / 2^k]. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_length t x : represents t x -> length t = 2%nat. + Proof. cbv [represents]; intuition. subst t; reflexivity. Qed. + + Lemma represents_low t x : + represents t x -> low t = x mod 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + + Lemma represents_high t x : + represents t x -> high t = x / 2^k. + Proof. cbv [represents]; intros; rewrite (represents_eq t x) by auto; reflexivity. Qed. + + Lemma represents_low_range t x : + represents t x -> 0 <= x mod 2^k < 2^k. + Proof. auto with zarith. Qed. + + Lemma represents_high_range t x : + represents t x -> 0 <= x / 2^k < 2^k. + Proof. + destruct 1 as [? [? ?] ]; intros. + auto using Z.div_lt_upper_bound with zarith. + Qed. + Hint Resolve represents_length represents_low_range represents_high_range. + + Lemma represents_range t x : + represents t x -> 0 <= x < 2^k*2^k. + Proof. cbv [represents]; tauto. Qed. + + Lemma represents_id x : + 0 <= x < 2^k * 2^k -> + represents [x mod 2^k; x / 2^k] x. + Proof. + intros; cbv [represents]; autorewrite with cancel_pair. + Z.rewrite_mod_small; tauto. + Qed. + + Local Ltac push_rep := + repeat match goal with + | H : represents ?t ?x |- _ => unique pose proof (represents_low_range _ _ H) + | H : represents ?t ?x |- _ => unique pose proof (represents_high_range _ _ H) + | H : represents ?t ?x |- _ => rewrite (represents_low t x) in * by assumption + | H : represents ?t ?x |- _ => rewrite (represents_high t x) in * by assumption + end. + + Definition shiftr (t : list Z) (n : Z) : list Z := + [Z.rshi (2^k) (high t) (low t) n; Z.rshi (2^k) 0 (high t) n]. + + Lemma shiftr_represents a i x : + represents a x -> + 0 <= i <= k -> + represents (shiftr a i) (x / 2 ^ i). + Proof. + cbv [shiftr]; intros; push_rep. + match goal with H : _ |- _ => pose proof (represents_range _ _ H) end. + assert (0 < 2 ^ i) by auto with zarith. + assert (x < 2 ^ i * 2 ^ k * 2 ^ k) by nia. + assert (0 <= x / 2 ^ k / 2 ^ i < 2 ^ k) by (split; Z.zero_bounds; auto using Z.div_lt_upper_bound with zarith). repeat match goal with | _ => rewrite Z.rshi_correct by auto with zarith @@ -8919,377 +9288,6 @@ Module P256_32. End P256_32. *) -Require Import Coq.Program.Wf. - -Module Straightline. - Module expr. - Section with_var. - Context {var : type.type -> Type}. - Context {dummy_var : forall t, var t}. - - Let uexpr t := @Uncurried.expr.expr ident.ident var t. - - Section with_ident. - Context {ident : type.type -> type.type -> Type}. - Inductive expr : type.type -> Type := - | Scalar {t} : scalar t -> expr t - | LetInAppIdentZ {s d} : zrange -> ident s (type.Z) -> scalar s -> (var (type.Z) -> expr d) -> expr d - | LetInAppIdentZZ {s d} : zrange * zrange -> ident s (type.Z*type.Z) -> scalar s -> (var (type.Z*type.Z) -> expr d) -> expr d - with scalar : type.type -> Type := - | Var t : var t -> scalar t - | TT : scalar (type.type_primitive type.unit) - | Pair {a b} : scalar a -> scalar b -> scalar (a * b) - | Cast : zrange -> scalar type.Z -> scalar type.Z - | Cast2 : zrange * zrange -> scalar (type.Z*type.Z) -> scalar (type.Z*type.Z) - | Fst {a b} : scalar (a * b) -> scalar a - | Snd {a b} : scalar (a * b) -> scalar b - | Shiftr : Z -> scalar type.Z -> scalar type.Z - | Shiftl : Z -> scalar type.Z -> scalar type.Z - | Land : Z -> scalar type.Z -> scalar type.Z - | Primitive {t} : type.interp (type.type_primitive t) -> scalar t - . - End with_ident. - - Fixpoint dummy t : @expr ident.ident t := Scalar (Var t (dummy_var t)). - - Definition scalar_of_uncurried_ident {s d} (idc : ident.ident s d) - : scalar s -> option (scalar d) := - match idc in ident.ident s d return scalar (ident:=ident.ident) s -> option (scalar d) with - | ident.Z.cast r => fun args => Some (Cast r args) - | ident.Z.cast2 r => fun args => Some (Cast2 r args) - | @ident.fst A B => fun args => Some (Fst args) - | @ident.snd A B => fun args => Some (Snd args) - | ident.Z.shiftr n => fun args => Some (Shiftr n args) - | ident.Z.shiftl n => fun args => Some (Shiftl n args) - | ident.Z.land n => fun args => Some (Land n args) - | @ident.primitive p x => fun _ => Some (Primitive x) - | _ => fun _ => None - end. - - Fixpoint scalar_of_uncurried {t} (e : uexpr t) : option (scalar t) := - match e in Uncurried.expr.expr t return option (scalar t) with - | expr.Var t v as e => Some (Var t v) - | expr.TT as e => Some TT - | expr.Pair A B a b - => match scalar_of_uncurried a, scalar_of_uncurried b with - | Some x, Some y => Some (Pair x y) - | _, _ => None - end - | expr.AppIdent _ _ idc args - => match scalar_of_uncurried args with - | Some x => scalar_of_uncurried_ident idc x - | None => None - end - | _ => None - end. - - Fixpoint range_type t : Type := - match t with - | type.type_primitive type.Z => zrange - | type.prod x y => range_type x * range_type y - | _ => unit - end. - - Definition invert_cast {t} (e : uexpr t) - : option (range_type t * uexpr t) := - match invert_AppIdent e with - | Some (existT s (idc, x)) => - (match idc in ident.ident s t return uexpr s -> option (range_type t * uexpr t) with - | ident.Z.cast r => fun x => Some (r, x) - | ident.Z.cast2 r => fun x => Some (r, x) - | _ => fun _ => None - end) x - | None => None - end. - - (* ident.Let_In @@ (cast r x) => r, x *) - Definition invert_LetInCast {tx tC} (args : uexpr (tx * (tx -> tC))) - : option (range_type tx * uexpr tx * uexpr (tx -> tC)) := - match invert_Pair args with - | Some (x, e) => - match invert_cast x with - | Some (r, x') => Some (r, x', e) - | None => None - end - | None => None - end. - - Definition invert_LetInAppIdent {tx tC} (args : uexpr (tx * (tx -> tC))) - : option { s : type.type & (range_type tx * ident.ident s tx * scalar s * (var tx -> uexpr tC))%type } := - match invert_LetInCast args with - | Some (r, x, e) => - match invert_AppIdent x with - | Some (existT s idc_x') => - match scalar_of_uncurried (snd idc_x') with - | Some x'' => - match invert_Abs e with - | Some k => Some (existT _ s (r, fst idc_x', x'', k)) - | None => None - end - | None => None - end - | None => None - end - | None => None - end. - - Definition mk_LetInAppIdent {s d t} (default : expr t) - : range_type d -> ident.ident s d -> scalar s -> (var d -> expr t) -> expr t := - match d as d0 return range_type d0 -> ident.ident s d0 -> scalar s -> (var d0 -> expr t) -> expr t with - | type.type_primitive type.Z => - fun r idc x k => @LetInAppIdentZ ident.ident s t r idc x k - | type.prod type.Z type.Z => - fun r idc x k => @LetInAppIdentZZ ident.ident s t r idc x k - | _ => fun _ _ _ _ => default - end. - - Definition of_uncurried_step {t} (e : uexpr t) - (of_uncurried : forall {t}, uexpr t -> expr t) - : expr t -> expr t := - match e in Uncurried.expr.expr t return expr t -> expr t with - | AppIdent s d idc args => - (match idc in ident.ident s d return uexpr s -> expr d -> expr d with - | ident.Let_In tx tC => - fun args default => - match invert_LetInAppIdent args return expr tC with - | Some (existT s (r, idc, x, k)) => - @mk_LetInAppIdent s tx tC default r idc x (fun y : var tx => of_uncurried (k y)) - | None => default - end - | ident.Z.cast r => - fun (args : uexpr _) default => - match invert_AppIdent args with - | Some (existT s idc_x') => - match scalar_of_uncurried (snd idc_x') with - | Some x'' => - @mk_LetInAppIdent s type.Z type.Z default r (fst idc_x') x'' (fun y => Scalar (Var _ y)) - | None => default - end - | None => default - end - | ident.Z.cast2 r => - fun (args : uexpr _) default => - match invert_AppIdent args with - | Some (existT s idc_x') => - match scalar_of_uncurried (snd idc_x') with - | Some x'' => - @mk_LetInAppIdent s (type.Z*type.Z) (type.Z*type.Z) default r (fst idc_x') x'' (fun y => Scalar (Var _ y)) - | None => default - end - | None => default - end - | _ => fun _ default => default - end) args - | _ as e => - (fun default => - match scalar_of_uncurried e with - | Some s => Scalar s - | None => default - end) - end. - - (* TODO : uses fuel; ideally want a cleaner termination proof *) - Fixpoint of_uncurried (fuel : nat) {t} (e : uexpr t) - : expr t := - match fuel with - | S fuel' => of_uncurried_step e (@of_uncurried fuel') (dummy t) - | O => dummy t - end. - - End with_var. - End expr. - - Fixpoint depth {var t} (dummy_var : forall t, var t) (e : @Uncurried.expr.expr ident.ident var t) : nat := - match e with - | Uncurried.expr.Var _ _ => O - | Uncurried.expr.TT => O - | Uncurried.expr.AppIdent _ _ idc args => S (depth dummy_var args) - | Uncurried.expr.App _ _ f x => S (Nat.max (depth dummy_var f) (depth dummy_var x)) - | Uncurried.expr.Pair _ _ x y => S (Nat.max (depth dummy_var x) (depth dummy_var y)) - | Uncurried.expr.Abs _ _ f => S (depth dummy_var (f (dummy_var _))) - end. - - (* TODO : Can I avoid having dummy_var appear here? *) - Definition of_Expr {s d} (e : Expr (s->d)) (var : type -> Type) (x:var s) dummy_var : expr.expr d - := - match invert_Abs (e var) with - | Some f => expr.of_uncurried (dummy_var:=dummy_var) (depth dummy_var (f x)) (f x) - | None => expr.dummy (dummy_var:=dummy_var) d - end. - -End Straightline. - -Module StraightlineTest. - Definition test : Expr (type.Z -> type.Z) := - fun var => - Abs - (fun (x : var type.Z) => - AppIdent (var:=var) ident.Let_In - (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) - (Abs (fun x : var type.Z => expr.Var x)))). - - Eval vm_compute in (Straightline.of_Expr test). - - Definition test_mul : Expr (type.Z -> type.Z) := - fun var => - Abs - (fun (x : var type.Z) => - AppIdent (var:=var) ident.Let_In - (Pair (AppIdent (var:=var) (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (var:=var) (ident.Z.shiftr 8) (Var x))) - (Abs (fun y : var type.Z => - AppIdent ident.Let_In - (Pair (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent ident.Z.mul (Pair (AppIdent (@ident.primitive type.Z 12) TT) (Var y)))) - (Abs (fun z : var type.Z => (AppIdent (ident.Z.cast r[0~>4294967295]%zrange) (AppIdent (ident.Z.shiftr 3) (Var z))))) - ))))). - Eval vm_compute in (Straightline.of_Expr test_mul). -End StraightlineTest. - -(* Convert straightline code to code that uses only a certain set of identifiers *) -Module PreFancy. - Section with_var. - Import Straightline.expr. - Context {var : type -> Type} (dummy_var : forall t, var t) (log2wordsize : Z) - (constant_to_scalar : forall ident, Z -> option (@scalar var ident type.Z)). - Local Notation Z := (type.type_primitive type.Z). - Let wordsize := 2 ^ log2wordsize. - Let half_bits := log2wordsize / 2. - Let half_wordsize := 2 ^ half_bits. - - Inductive ident : type -> type -> Type := - | add : ident (Z * Z) (Z * Z) - | addc : ident (Z * Z * Z) (Z * Z) - | mulll : ident (Z * Z) Z - | mullh : ident (Z * Z) Z - | mulhl : ident (Z * Z) Z - | mulhh : ident (Z * Z) Z - | sub : ident (Z * Z) (Z * Z) - | shiftr : BinInt.Z -> ident Z Z - | shiftl : BinInt.Z -> ident Z Z - | sel : ident (Z * Z * Z) Z - | addm : ident (Z * Z * Z) Z - . - Let dummy t : @expr var ident t := Scalar (Var _ (dummy_var t)). - - Definition invert_lower' {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Cast r (Land n x) => - if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? 2^half_bits-1) - then Some x - else None - | _ => None - end. - - Definition invert_upper' {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Cast r (Shiftr n x) => - if (lower r =? 0) && (upper r =? (half_wordsize - 1)) && (n =? half_bits) - then Some x - else None - | _ => None - end. - - Definition invert_lower {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Primitive type.Z x => - match constant_to_scalar ident x with - | Some y => invert_lower' y - | None => None - end - | _ => invert_lower' e - end. - - Definition invert_upper {t} (e : @scalar var ident t) : - option (@scalar var ident Z) := - match e in scalar t return option (@scalar var ident Z) with - | Primitive type.Z x => - match constant_to_scalar ident x with - | Some y => invert_upper' y - | None => None - end - | _ => invert_upper' e - end. - - - Definition of_straightline_ident {s d} (idc : ident.ident s d) - : forall t, range_type d -> @scalar var ident s -> (var d -> @expr var ident t) -> @expr var ident t := - match idc in ident.ident s d return forall t, range_type d -> scalar s -> (var d -> @expr var ident t) -> @expr var ident t with - | ident.Z.add_get_carry_concrete w => - fun t r x f => - if w =? wordsize - then LetInAppIdentZZ r add x f - else dummy _ - | ident.Z.add_with_get_carry_concrete w => - fun t r x f => - if w =? wordsize - then LetInAppIdentZZ r addc x f - else dummy _ - | ident.Z.sub_get_borrow_concrete w => - fun t r x f => - if w =? wordsize - then LetInAppIdentZZ r sub x f - else dummy _ - | ident.Z.shiftr n => fun _ r => LetInAppIdentZ r (shiftr n) - | ident.Z.shiftl n => fun _ r => LetInAppIdentZ r (shiftl n) - | ident.Z.zselect => fun _ r => LetInAppIdentZ r sel - | ident.Z.add_modulo => fun _ r => LetInAppIdentZ r addm - | ident.Z.mul => - fun t r x f => - match x return expr t with - | Pair _ _ x0 x1 => - match invert_lower x0, invert_lower x1 with - | Some y0, Some y1 => LetInAppIdentZ r mulll (Pair y0 y1) f - | Some y0, None => - match invert_upper x1 with - | Some y1 => LetInAppIdentZ r mullh (Pair y0 y1) f - | None => dummy _ - end - | None, Some y1 => - match invert_upper x0 with - | Some y0 => LetInAppIdentZ r mulhl (Pair y0 y1) f - | None => dummy _ - end - | None, None => - match invert_upper x0, invert_upper x1 with - | Some y0, Some y1 => LetInAppIdentZ r mulhh (Pair y0 y1) f - | _,_ => dummy _ - end - end - | _ => dummy _ - end - | _ => fun t _ _ _ => dummy t - end. - - Fixpoint of_straightline_scalar {t} (s : @scalar var ident.ident t) - : @scalar var ident t := - match s with - | Var _ v => Var _ v - | TT => TT - | Pair _ _ x y => Pair (of_straightline_scalar x) (of_straightline_scalar y) - | Cast r x => Cast r (of_straightline_scalar x) - | Cast2 r x => Cast2 r (of_straightline_scalar x) - | Fst _ _ x => Fst (of_straightline_scalar x) - | Snd _ _ x => Snd (of_straightline_scalar x) - | Shiftr n x => Shiftr n (of_straightline_scalar x) - | Shiftl n x => Shiftl n (of_straightline_scalar x) - | Land n x => Land n (of_straightline_scalar x) - | Primitive _ x => Primitive x - end. - - Fixpoint of_straightline {t} (e : @expr var ident.ident t) - : @expr var ident t := - match e with - | Scalar _ s => Scalar (of_straightline_scalar s) - | LetInAppIdentZ _ t r idc x f => - of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) - | LetInAppIdentZZ _ t r idc x f => - of_straightline_ident idc t r (of_straightline_scalar x) (fun y => of_straightline (f y)) - end. - End with_var. -End PreFancy. - Module MontgomeryReduction. Section MontRed'. Context (N R N' R' : Z). -- cgit v1.2.3