diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/BoundedArithmetic/ArchitectureToZLike.v | 16 | ||||
-rw-r--r-- | src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 21 | ||||
-rw-r--r-- | src/BoundedArithmetic/DoubleBounded.v | 26 | ||||
-rw-r--r-- | src/BoundedArithmetic/DoubleBoundedProofs.v | 34 | ||||
-rw-r--r-- | src/ModularArithmetic/Montgomery/ZBounded.v | 16 | ||||
-rw-r--r-- | src/Reflection/Named/Compile.v | 65 | ||||
-rw-r--r-- | src/Reflection/Named/ContextOn.v | 16 | ||||
-rw-r--r-- | src/Reflection/Named/DeadCodeElimination.v | 70 | ||||
-rw-r--r-- | src/Reflection/Named/EstablishLiveness.v | 109 | ||||
-rw-r--r-- | src/Reflection/Named/RegisterAssign.v | 124 | ||||
-rw-r--r-- | src/Reflection/Named/Syntax.v | 200 | ||||
-rw-r--r-- | src/Reflection/TestCase.v | 15 | ||||
-rw-r--r-- | src/Specific/FancyMachine256/Barrett.v | 137 | ||||
-rw-r--r-- | src/Specific/FancyMachine256/Core.v | 380 | ||||
-rw-r--r-- | src/Specific/FancyMachine256/Montgomery.v | 148 |
15 files changed, 1050 insertions, 327 deletions
diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index 3388ece78..939265c1e 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -4,6 +4,7 @@ Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BoundedArithmetic.DoubleBounded. Require Import Crypto.ModularArithmetic.ZBounded. Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.LetIn. Local Open Scope Z_scope. @@ -12,21 +13,26 @@ Section fancy_machine_p256_montgomery_foundation. Local Notation n := (2 * n_over_two). Context (ops : fancy_machine.instructions n) (modulus : Z). - Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) + Local Instance ZLikeOps_of_ArchitectureBoundedOps_Factored (smaller_bound_exp : Z) + ldi_modulus ldi_0 : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := { LargeT := tuple fancy_machine.W 2; SmallT := fancy_machine.W; - modulus_digits := ldi modulus; + modulus_digits := ldi_modulus; decode_large := decode; decode_small := decode; Mod_SmallBound v := fst v; DivBy_SmallBound v := snd v; DivBy_SmallerBound v := if smaller_bound_exp =? n then snd v - else shrd (snd v) (fst v) smaller_bound_exp; + else dlet v := v in shrd (snd v) (fst v) smaller_bound_exp; Mul x y := muldw x y; CarryAdd x y := adc x y false; CarrySubSmall x y := subc x y false; - ConditionalSubtract b x := let v := selc b (ldi modulus) (ldi 0) in snd (subc x v false); - ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }. + ConditionalSubtract b x := let v := selc b (ldi_modulus) (ldi_0) in snd (subc x v false); + ConditionalSubtractModulus y := addm y (ldi_0) (ldi_modulus) }. + + Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) + : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := + @ZLikeOps_of_ArchitectureBoundedOps_Factored smaller_bound_exp (ldi modulus) (ldi 0). End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v index 804296374..aca1753b7 100644 --- a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -8,6 +8,7 @@ Require Import Crypto.BoundedArithmetic.ArchitectureToZLike. Require Import Crypto.ModularArithmetic.ZBounded. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ZUtil Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. Local Open Scope nat_scope. Local Open Scope Z_scope. @@ -37,7 +38,7 @@ Section fancy_machine_p256_montgomery_foundation. pose proof (decode_range x) end. Local Ltac unfolder_t := - progress unfold LargeT, SmallT, modulus_digits, decode_large, decode_small, Mod_SmallBound, DivBy_SmallBound, DivBy_SmallerBound, Mul, CarryAdd, CarrySubSmall, ConditionalSubtract, ConditionalSubtractModulus, ZLikeOps_of_ArchitectureBoundedOps in *. + progress unfold LargeT, SmallT, modulus_digits, decode_large, decode_small, Mod_SmallBound, DivBy_SmallBound, DivBy_SmallerBound, Mul, CarryAdd, CarrySubSmall, ConditionalSubtract, ConditionalSubtractModulus, ZLikeOps_of_ArchitectureBoundedOps, ZLikeOps_of_ArchitectureBoundedOps_Factored in *. Local Ltac saturate_context_step := match goal with | _ => unique assert (0 <= 2 * n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] @@ -53,6 +54,8 @@ Section fancy_machine_p256_montgomery_foundation. Local Ltac post_t_step := match goal with | _ => reflexivity + | _ => progress subst + | _ => progress unfold Let_In | _ => progress autorewrite with zsimplify_const | [ |- fst ?x = (?a <=? ?b) :> bool ] => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); @@ -69,13 +72,16 @@ Section fancy_machine_p256_montgomery_foundation. Local Ltac post_t := repeat post_t_step. Local Ltac t := pre_t; post_t. - Global Instance ZLikeProperties_of_ArchitectureBoundedOps + Global Instance ZLikeProperties_of_ArchitectureBoundedOps_Factored {arith : fancy_machine.arithmetic ops} + ldi_modulus ldi_0 + (Hldi_modulus : ldi_modulus = ldi modulus) + (Hldi_0 : ldi_0 = ldi 0) (modulus_in_range : 0 <= modulus < 2^n) (smaller_bound_exp : Z) (smaller_bound_smaller : 0 <= smaller_bound_exp <= n) (n_pos : 0 < n) - : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp) + : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps_Factored ops modulus smaller_bound_exp ldi_modulus ldi_0) := { large_valid v := True; medium_valid v := 0 <= decode_large v < 2^n * 2^smaller_bound_exp; small_valid v := True }. @@ -107,4 +113,13 @@ Section fancy_machine_p256_montgomery_foundation. { abstract t. } { abstract t. } Defined. + + Global Instance ZLikeProperties_of_ArchitectureBoundedOps + {arith : fancy_machine.arithmetic ops} + (modulus_in_range : 0 <= modulus < 2^n) + (smaller_bound_exp : Z) + (smaller_bound_smaller : 0 <= smaller_bound_exp <= n) + (n_pos : 0 < n) + : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp) + := ZLikeProperties_of_ArchitectureBoundedOps_Factored _ _ eq_refl eq_refl modulus_in_range _ smaller_bound_smaller n_pos. End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index b624c5082..5cf48cd3b 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -6,6 +6,7 @@ Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. Local Open Scope nat_scope. Local Open Scope Z_scope. @@ -27,10 +28,14 @@ Section ripple_carry_definitions. : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k := match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with | O => f - | S k' => fun xss yss carry => let '(xs, x) := eta xss in + | S k' => fun xss yss carry => dlet xss := xss in + dlet yss := yss in + let '(xs, x) := eta xss in let '(ys, y) := eta yss in - let '(carry, zs) := eta (@ripple_carry_tuple' _ f k' xs ys carry) in - let '(carry, z) := eta (f x y carry) in + dlet addv := (@ripple_carry_tuple' _ f k' xs ys carry) in + let '(carry, zs) := eta addv in + dlet fxy := (f x y carry) in + let '(carry, z) := eta fxy in (carry, (zs, z)) end. @@ -75,11 +80,16 @@ Section tuple2. {ldi : load_immediate W}. Definition mul_double (a b : W) : tuple W 2 - := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in - let tmp := mulhwhl a b in - let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in - let tmp := mulhwhl b a in - let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + := dlet a := a in + dlet b := b in + let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in + dlet out := out in + dlet tmp := mulhwhl a b in + dlet addv := (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + let '(_, out) := eta addv in + dlet tmp := mulhwhl b a in + dlet addv := (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + let '(_, out) := eta addv in out. (** Require a dummy [decoder] for these instances to allow diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v index 53ac59d00..8fae01f9f 100644 --- a/src/BoundedArithmetic/DoubleBoundedProofs.v +++ b/src/BoundedArithmetic/DoubleBoundedProofs.v @@ -12,6 +12,7 @@ Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.Tactics. Require Import Crypto.Util.Notations. +Require Import Crypto.Util.LetIn. Import ListNotations. Local Open Scope list_scope. @@ -235,6 +236,31 @@ Global Instance decode_mul_double : forall x y, tuple_decoder (muldw x y) <~=~> (decode x * decode y)%Z := proj1 decode_mul_double_iff _. + +Lemma ripple_carry_tuple_SS' {T} f k xss yss carry + : @ripple_carry_tuple T f (S (S k)) xss yss carry + = dlet xss := xss in + dlet yss := yss in + let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + dlet addv := (@ripple_carry_tuple _ f (S k) xs ys carry) in + let '(carry, zs) := eta addv in + dlet fxy := (f x y carry) in + let '(carry, z) := eta fxy in + (carry, (zs, z)). +Proof. reflexivity. Qed. + +(* This turns a goal like [x = Let_In p (fun v => let '(x, y) := f v + in x + y)] into a goal like [x = fst (f p) + snd (f p)]. Note that + it inlines [Let_In] as well as destructuring lets. *) +Local Ltac eta_expand := + repeat match goal with + | _ => progress unfold Let_In + | [ |- context[let '(x, y) := ?e in _] ] + => rewrite (surjective_pairing e) + | _ => rewrite <- !surjective_pairing + end. + Lemma ripple_carry_tuple_SS {T} f k xss yss carry : @ripple_carry_tuple T f (S (S k)) xss yss carry = let '(xs, x) := eta xss in @@ -242,7 +268,11 @@ Lemma ripple_carry_tuple_SS {T} f k xss yss carry let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in let '(carry, z) := eta (f x y carry) in (carry, (zs, z)). -Proof. reflexivity. Qed. +Proof. + rewrite ripple_carry_tuple_SS'. + eta_expand. + reflexivity. +Qed. Lemma carry_is_good (n z0 z1 k : Z) : 0 <= n -> @@ -414,7 +444,7 @@ Section tuple2. Proof. assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative. assert (0 <= half_n) by omega. - unfold mul_double. + unfold mul_double; eta_expand. push_decode; autorewrite with simpl_tuple_decoder; simplify_projections. autorewrite with zsimplify Zshift_to_pow push_Zpow. rewrite !spread_left_from_shift_half_correct. diff --git a/src/ModularArithmetic/Montgomery/ZBounded.v b/src/ModularArithmetic/Montgomery/ZBounded.v index 2da20ddcf..97bcf87b9 100644 --- a/src/ModularArithmetic/Montgomery/ZBounded.v +++ b/src/ModularArithmetic/Montgomery/ZBounded.v @@ -10,6 +10,7 @@ Require Import Crypto.ModularArithmetic.ZBounded. Require Import Crypto.BaseSystem. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.LetIn. Require Import Crypto.Util.Notations. Local Open Scope small_zlike_scope. @@ -22,6 +23,19 @@ Section montgomery. (modulus'_valid : small_valid modulus') (modulus_nonzero : modulus <> 0). + (** pull out a common subexpression *) + Local Ltac cse := + let RHS := match goal with |- _ = ?decode ?RHS /\ _ => RHS end in + let v := fresh in + match RHS with + | context[?e] => not is_var e; set (v := e) at 1 2; test clearbody v + end; + revert v; + match goal with + | [ |- let v := ?val in ?LHS = ?decode ?RHS /\ ?P ] + => change (LHS = decode (dlet v := val in RHS) /\ P) + end. + Definition partial_reduce : forall v : LargeT, { partial_reduce : SmallT | large_valid v @@ -38,6 +52,7 @@ Section montgomery. rewrite <- partial_reduce_alt_eq by omega. cbv [Montgomery.Z.partial_reduce Montgomery.Z.partial_reduce_alt Montgomery.Z.prereduce]. pull_zlike_decode. + cse. subst pr; split; [ reflexivity | exact _ ]. Defined. @@ -58,6 +73,7 @@ Section montgomery. rewrite <- partial_reduce_alt_eq by omega. cbv [Montgomery.Z.partial_reduce Montgomery.Z.partial_reduce_alt Montgomery.Z.prereduce]. pull_zlike_decode. + cse. subst pr; split; [ reflexivity | exact _ ]. Defined. diff --git a/src/Reflection/Named/Compile.v b/src/Reflection/Named/Compile.v new file mode 100644 index 000000000..e37815597 --- /dev/null +++ b/src/Reflection/Named/Compile.v @@ -0,0 +1,65 @@ +(** * PHOAS → Named Representation of Gallina *) +Require Import Crypto.Reflection.Named.Syntax. +Require Import Crypto.Reflection.Named.NameUtil. +Require Import Crypto.Reflection.Syntax. + +Local Notation eta x := (fst x, snd x). + +Local Open Scope ctype_scope. +Local Open Scope nexpr_scope. +Local Open Scope expr_scope. +Section language. + Context (base_type_code : Type) + (interp_base_type : base_type_code -> Type) + (op : flat_type base_type_code -> flat_type base_type_code -> Type) + (Name : Type). + + Local Notation flat_type := (flat_type base_type_code). + Local Notation type := (type base_type_code). + Let Tbase := @Tbase base_type_code. + Local Coercion Tbase : base_type_code >-> Syntax.flat_type. + Local Notation interp_type := (interp_type interp_base_type). + Local Notation interp_flat_type := (interp_flat_type_gen interp_base_type). + Local Notation exprf := (@exprf base_type_code interp_base_type op (fun _ => Name)). + Local Notation expr := (@expr base_type_code interp_base_type op (fun _ => Name)). + Local Notation nexprf := (@Named.exprf base_type_code interp_base_type op Name). + Local Notation nexpr := (@Named.expr base_type_code interp_base_type op Name). + + Fixpoint ocompilef {t} (e : exprf t) (ls : list (option Name)) {struct e} + : option (nexprf t) + := match e in @Syntax.exprf _ _ _ _ t return option (nexprf t) with + | Const _ x => Some (Named.Const x) + | Var _ x => Some (Named.Var x) + | Op _ _ op args => option_map (Named.Op op) (@ocompilef _ args ls) + | LetIn tx ex _ eC + => match @ocompilef _ ex nil, split_onames tx ls with + | Some x, (Some n, ls')%core + => option_map (fun C => Named.LetIn tx n x C) (@ocompilef _ (eC n) ls') + | _, _ => None + end + | Pair _ ex _ ey => match @ocompilef _ ex nil, @ocompilef _ ey nil with + | Some x, Some y => Some (Named.Pair x y) + | _, _ => None + end + end. + + Fixpoint ocompile {t} (e : expr t) (ls : list (option Name)) {struct e} + : option (nexpr t) + := match e in @Syntax.expr _ _ _ _ t return option (nexpr t) with + | Return _ x => option_map Named.Return (ocompilef x ls) + | Abs _ _ f + => match ls with + | cons (Some n) ls' + => option_map (Named.Abs n) (@ocompile _ (f n) ls') + | _ => None + end + end. + + Definition compilef {t} (e : exprf t) (ls : list Name) := @ocompilef t e (List.map (@Some _) ls). + Definition compile {t} (e : expr t) (ls : list Name) := @ocompile t e (List.map (@Some _) ls). +End language. + +Global Arguments ocompilef {_ _ _ _ _} e ls. +Global Arguments ocompile {_ _ _ _ _} e ls. +Global Arguments compilef {_ _ _ _ _} e ls. +Global Arguments compile {_ _ _ _ _} e ls. diff --git a/src/Reflection/Named/ContextOn.v b/src/Reflection/Named/ContextOn.v new file mode 100644 index 000000000..d32911283 --- /dev/null +++ b/src/Reflection/Named/ContextOn.v @@ -0,0 +1,16 @@ +(** * Transfer [Context] across an injection *) +Require Import Crypto.Reflection.Named.Syntax. + +Section language. + Context {base_type_code Name1 Name2 : Type} + (f : Name2 -> Name1) + (f_inj : forall x y, f x = f y -> x = y) + {var : base_type_code -> Type}. + + Definition ContextOn (Ctx : Context Name1 var) : Context Name2 var + := {| ContextT := Ctx; + lookupb ctx n t := lookupb ctx (f n) t; + extendb ctx n t v := extendb ctx (f n) v; + removeb ctx n t := removeb ctx (f n) t; + empty := empty |}. +End language. diff --git a/src/Reflection/Named/DeadCodeElimination.v b/src/Reflection/Named/DeadCodeElimination.v new file mode 100644 index 000000000..1b2fa3fc0 --- /dev/null +++ b/src/Reflection/Named/DeadCodeElimination.v @@ -0,0 +1,70 @@ +(** * PHOAS → Named Representation of Gallina *) +Require Import Coq.PArith.BinPos Coq.Lists.List. +Require Import Crypto.Reflection.Named.Syntax. +Require Import Crypto.Reflection.Named.Compile. +Require Import Crypto.Reflection.Named.RegisterAssign. +Require Import Crypto.Reflection.Named.EstablishLiveness. +Require Import Crypto.Reflection.CountLets. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.LetIn. + +Local Notation eta x := (fst x, snd x). + +Local Open Scope ctype_scope. +Local Open Scope nexpr_scope. +Local Open Scope expr_scope. +Section language. + Context (base_type_code : Type) + (interp_base_type : base_type_code -> Type) + (op : flat_type base_type_code -> flat_type base_type_code -> Type) + (Name : Type) + {Context : Context Name (fun _ : base_type_code => positive)}. + + Local Notation flat_type := (flat_type base_type_code). + Local Notation type := (type base_type_code). + Let Tbase := @Tbase base_type_code. + Local Coercion Tbase : base_type_code >-> Syntax.flat_type. + Local Notation interp_type := (interp_type interp_base_type). + Local Notation interp_flat_type := (interp_flat_type_gen interp_base_type). + Local Notation exprf := (@exprf base_type_code interp_base_type op (fun _ => Name)). + Local Notation expr := (@expr base_type_code interp_base_type op (fun _ => Name)). + Local Notation Expr := (@Expr base_type_code interp_base_type op). + (*Local Notation lexprf := (@Syntax.exprf base_type_code interp_base_type op (fun _ => list (option Name))). + Local Notation lexpr := (@Syntax.expr base_type_code interp_base_type op (fun _ => list (option Name))).*) + Local Notation nexprf := (@Named.exprf base_type_code interp_base_type op Name). + Local Notation nexpr := (@Named.expr base_type_code interp_base_type op Name). + + (*Definition get_live_namesf (names : list (option Name)) {t} (e : lexprf t) : list (option Name) + := filter_live_namesf + base_type_code interp_base_type op + (option Name) None + (fun x y => match x, y with + | Some x, _ => Some x + | _, Some y => Some y + | None, None => None + end) + nil names e. + Definition get_live_names (names : list (option Name)) {t} (e : lexpr t) : list (option Name) + := filter_live_names + base_type_code interp_base_type op + (option Name) None + (fun x y => match x, y with + | Some x, _ => Some x + | _, Some y => Some y + | None, None => None + end) + nil names e.*) + + Definition CompileAndEliminateDeadCode + {t} (e : Expr t) (ls : list Name) + : option (nexpr t) + := let e := compile (Name:=positive) (e _) (List.map Pos.of_nat (seq 1 (CountBinders e))) in + match e with + | Some e => Let_In (insert_dead_names None e ls) (* help vm_compute by factoring this out *) + (fun names => register_reassign Pos.eqb empty empty e names) + | None => None + end. +End language. + +Global Arguments CompileAndEliminateDeadCode {_ _ _ _ _ t} e ls. diff --git a/src/Reflection/Named/EstablishLiveness.v b/src/Reflection/Named/EstablishLiveness.v new file mode 100644 index 000000000..2301eb6a1 --- /dev/null +++ b/src/Reflection/Named/EstablishLiveness.v @@ -0,0 +1,109 @@ +(** * Compute a list of liveness values for each binding *) +Require Import Coq.Lists.List. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.Named.Syntax. +Require Import Crypto.Reflection.CountLets. +Require Import Crypto.Util.ListUtil. + +Local Notation eta x := (fst x, snd x). + +Local Open Scope ctype_scope. +Delimit Scope nexpr_scope with nexpr. + +Inductive liveness := live | dead. +Fixpoint merge_liveness (ls1 ls2 : list liveness) := + match ls1, ls2 with + | cons x xs, cons y ys + => cons match x, y with + | live, _ + | _, live + => live + | dead, dead + => dead + end + (@merge_liveness xs ys) + | nil, ls + | ls, nil + => ls + end. + +Section language. + Context (base_type_code : Type) + (interp_base_type : base_type_code -> Type) + (op : flat_type base_type_code -> flat_type base_type_code -> Type). + + Local Notation flat_type := (flat_type base_type_code). + Local Notation type := (type base_type_code). + Let Tbase := @Tbase base_type_code. + Local Coercion Tbase : base_type_code >-> Syntax.flat_type. + Local Notation interp_type := (interp_type interp_base_type). + Local Notation interp_flat_type := (interp_flat_type_gen interp_base_type). + Local Notation exprf := (@exprf base_type_code interp_base_type op). + Local Notation expr := (@expr base_type_code interp_base_type op). + + Section internal. + Context (Name : Type) + (OutName : Type) + {Context : Context Name (fun _ : base_type_code => list liveness)}. + + Definition compute_livenessf_step + (compute_livenessf : forall (ctx : Context) {t} (e : exprf Name t) (prefix : list liveness), list liveness) + (ctx : Context) + {t} (e : exprf Name t) (prefix : list liveness) + : list liveness + := match e with + | Const _ x => prefix + | Var t' name => match lookup ctx t' name with + | Some ls => ls + | _ => nil + end + | Op _ _ op args + => @compute_livenessf ctx _ args prefix + | LetIn tx n ex _ eC + => let lx := @compute_livenessf ctx _ ex prefix in + let lx := merge_liveness lx (prefix ++ repeat live (count_pairs tx)) in + let ctx := extend ctx n (SmartVal _ (fun _ => lx) tx) in + @compute_livenessf ctx _ eC (prefix ++ repeat dead (count_pairs tx)) + | Pair _ ex _ ey + => merge_liveness (@compute_livenessf ctx _ ex prefix) + (@compute_livenessf ctx _ ey prefix) + end. + + Fixpoint compute_livenessf ctx {t} e prefix + := @compute_livenessf_step (@compute_livenessf) ctx t e prefix. + + Fixpoint compute_liveness (ctx : Context) + {t} (e : expr Name t) (prefix : list liveness) + : list liveness + := match e with + | Return _ x => compute_livenessf ctx x prefix + | Abs src _ n f + => let prefix := prefix ++ (live::nil) in + let ctx := extendb (t:=src) ctx n prefix in + @compute_liveness ctx _ f prefix + end. + + Section insert_dead. + Context (default_out : option OutName). + + Fixpoint insert_dead_names_gen (ls : list liveness) (lsn : list OutName) + : list (option OutName) + := match ls with + | nil => nil + | cons live xs + => match lsn with + | cons n lsn' => Some n :: @insert_dead_names_gen xs lsn' + | nil => default_out :: @insert_dead_names_gen xs nil + end + | cons dead xs + => None :: @insert_dead_names_gen xs lsn + end. + Definition insert_dead_names {t} (e : expr Name t) + := insert_dead_names_gen (compute_liveness empty e nil). + End insert_dead. + End internal. +End language. + +Global Arguments compute_livenessf {_ _ _ _ _} ctx {t} e prefix. +Global Arguments compute_liveness {_ _ _ _ _} ctx {t} e prefix. +Global Arguments insert_dead_names {_ _ _ _ _ _} default_out {t} e lsn. diff --git a/src/Reflection/Named/RegisterAssign.v b/src/Reflection/Named/RegisterAssign.v new file mode 100644 index 000000000..5736d01a3 --- /dev/null +++ b/src/Reflection/Named/RegisterAssign.v @@ -0,0 +1,124 @@ +(** * Reassign registers *) +Require Import Coq.FSets.FMapPositive Coq.PArith.BinPos. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Reflection.Named.Syntax. +Require Import Crypto.Reflection.Named.NameUtil. +Require Import Crypto.Util.Decidable. + +Local Notation eta x := (fst x, snd x). + +Local Open Scope ctype_scope. +Delimit Scope nexpr_scope with nexpr. +Section language. + Context (base_type_code : Type) + (interp_base_type : base_type_code -> Type) + (op : flat_type base_type_code -> flat_type base_type_code -> Type). + + Local Notation flat_type := (flat_type base_type_code). + Local Notation type := (type base_type_code). + Let Tbase := @Tbase base_type_code. + Local Coercion Tbase : base_type_code >-> Syntax.flat_type. + Local Notation interp_type := (interp_type interp_base_type). + Local Notation interp_flat_type := (interp_flat_type_gen interp_base_type). + Local Notation exprf := (@exprf base_type_code interp_base_type op). + Local Notation expr := (@expr base_type_code interp_base_type op). + + Section internal. + Context (InName OutName : Type) + {InContext : Context InName (fun _ : base_type_code => OutName)} + {ReverseContext : Context OutName (fun _ : base_type_code => InName)} + (InName_beq : InName -> InName -> bool). + + Definition register_reassignf_step + (register_reassignf : forall (ctxi : InContext) (ctxr : ReverseContext) + {t} (e : exprf InName t) (new_names : list (option OutName)), + option (exprf OutName t)) + (ctxi : InContext) (ctxr : ReverseContext) + {t} (e : exprf InName t) (new_names : list (option OutName)) + : option (exprf OutName t) + := match e in Named.exprf _ _ _ _ t return option (exprf _ t) with + | Const _ x => Some (Const x) + | Var t' name => match lookupb ctxi name t' with + | Some new_name + => match lookupb ctxr new_name t' with + | Some name' + => if InName_beq name name' + then Some (Var new_name) + else None + | None => None + end + | None => None + end + | Op _ _ op args + => option_map (Op op) (@register_reassignf ctxi ctxr _ args new_names) + | LetIn tx n ex _ eC + => let '(n', new_names') := eta (split_onames tx new_names) in + match n', @register_reassignf ctxi ctxr _ ex nil with + | Some n', Some x + => let ctxi := extend ctxi n n' in + let ctxr := extend ctxr n' n in + option_map (LetIn tx n' x) (@register_reassignf ctxi ctxr _ eC new_names') + | _, _ + => let ctxi := remove ctxi n in + @register_reassignf ctxi ctxr _ eC new_names' + end + | Pair _ ex _ ey + => match @register_reassignf ctxi ctxr _ ex nil, @register_reassignf ctxi ctxr _ ey nil with + | Some x, Some y + => Some (Pair x y) + | _, _ => None + end + end. + Fixpoint register_reassignf ctxi ctxr {t} e new_names + := @register_reassignf_step (@register_reassignf) ctxi ctxr t e new_names. + + Fixpoint register_reassign (ctxi : InContext) (ctxr : ReverseContext) + {t} (e : expr InName t) (new_names : list (option OutName)) + : option (expr OutName t) + := match e in Named.expr _ _ _ _ t return option (expr _ t) with + | Return _ x => option_map Return (register_reassignf ctxi ctxr x new_names) + | Abs src _ n f + => let '(n', new_names') := eta (split_onames src new_names) in + match n' with + | Some n' + => let ctxi := extendb (t:=src) ctxi n n' in + let ctxr := extendb (t:=src) ctxr n' n in + option_map (Abs n') (@register_reassign ctxi ctxr _ f new_names') + | None => None + end + end. + End internal. + + Section context_pos. + Global Instance pos_context {decR : DecidableRel (@eq base_type_code)} + (var : base_type_code -> Type) : Context positive var + := { ContextT := PositiveMap.t { t : _ & var t }; + lookupb ctx key t := match PositiveMap.find key ctx with + | Some v => match dec (projT1 v = t) with + | left pf => Some match pf in (_ = t) return var t with + | eq_refl => projT2 v + end + | right _ => None + end + | None => None + end; + extendb ctx key t v := PositiveMap.add key (existT _ t v) ctx; + removeb ctx key t := if dec (option_map (@projT1 _ _) (PositiveMap.find key ctx) = Some t) + then PositiveMap.remove key ctx + else ctx; + empty := PositiveMap.empty _ }. + Global Instance pos_context_nd + (var : Type) + : Context positive (fun _ : base_type_code => var) + := { ContextT := PositiveMap.t var; + lookupb ctx key t := PositiveMap.find key ctx; + extendb ctx key t v := PositiveMap.add key v ctx; + removeb ctx key t := PositiveMap.remove key ctx; + empty := PositiveMap.empty _ }. + End context_pos. +End language. + +Global Arguments pos_context {_ _} var. +Global Arguments pos_context_nd : clear implicits. +Global Arguments register_reassign {_ _ _ _ _ _ _} _ ctxi ctxr {t} e _. +Global Arguments register_reassignf {_ _ _ _ _ _ _} _ ctxi ctxr {t} e _. diff --git a/src/Reflection/Named/Syntax.v b/src/Reflection/Named/Syntax.v new file mode 100644 index 000000000..70925c16b --- /dev/null +++ b/src/Reflection/Named/Syntax.v @@ -0,0 +1,200 @@ +(** * Named Representation of Gallina *) +Require Import Coq.Classes.RelationClasses. +Require Import Crypto.Reflection.Syntax. +Require Import Crypto.Util.PointedProp. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.Notations. + +Class Context {base_type_code} (Name : Type) (var : base_type_code -> Type) := + { ContextT : Type; + lookupb : ContextT -> Name -> forall {t : base_type_code}, option (var t); + extendb : ContextT -> Name -> forall {t : base_type_code}, var t -> ContextT; + removeb : ContextT -> Name -> base_type_code -> ContextT; + empty : ContextT }. +Coercion ContextT : Context >-> Sortclass. +Arguments ContextT {_ _ _ _}, {_ _ _} _. +Arguments lookupb {_ _ _ _} _ _ {_}, {_ _ _ _} _ _ _. +Arguments extendb {_ _ _ _} _ _ [_] _. +Arguments removeb {_ _ _ _} _ _ _. +Arguments empty {_ _ _ _}. + +Local Open Scope ctype_scope. +Local Open Scope expr_scope. +Delimit Scope nexpr_scope with nexpr. +Module Export Named. + Section language. + Context (base_type_code : Type) + (interp_base_type : base_type_code -> Type) + (op : flat_type base_type_code -> flat_type base_type_code -> Type) + (Name : Type). + + Local Notation flat_type := (flat_type base_type_code). + Local Notation type := (type base_type_code). + Let Tbase := @Tbase base_type_code. + Local Coercion Tbase : base_type_code >-> Syntax.flat_type. + Local Notation interp_type := (interp_type interp_base_type). + Local Notation interp_flat_type := (interp_flat_type_gen interp_base_type). + + + Section expr_param. + Section expr. + Inductive exprf : flat_type -> Type := + | Const {t : flat_type} : interp_type t -> exprf t + | Var {t : base_type_code} : Name -> exprf t + | Op {t1 tR} : op t1 tR -> exprf t1 -> exprf tR + | LetIn : forall {tx}, interp_flat_type_gen (fun _ => Name) tx -> exprf tx -> forall {tC}, exprf tC -> exprf tC + | Pair : forall {t1}, exprf t1 -> forall {t2}, exprf t2 -> exprf (Prod t1 t2). + Bind Scope nexpr_scope with exprf. + Inductive expr : type -> Type := + | Return {t} : exprf t -> expr t + | Abs {src dst} : Name -> expr dst -> expr (Arrow src dst). + Bind Scope nexpr_scope with expr. + Global Coercion Return : exprf >-> expr. + (** [SmartVar] is like [Var], except that it inserts + pair-projections and [Pair] as necessary to handle + [flat_type], and not just [base_type_code] *) + Definition SmartVar {t} : interp_flat_type_gen (fun _ => Name) t -> exprf t + := smart_interp_flat_map (f:=fun _ => Name) (g:=exprf) _ (fun t => Var) (fun A B x y => Pair x y). + Definition SmartConst {t} : interp_flat_type t -> @interp_flat_type_gen base_type_code exprf t + := smart_interp_flat_map (g:=@interp_flat_type_gen base_type_code exprf) _ (fun t => Const (t:=t)) (fun A B x y => pair x y). + End expr. + + Section with_context. + Context {var : base_type_code -> Type} + {Context : Context Name var}. + + Fixpoint extend (ctx : Context) {t : flat_type} + (n : interp_flat_type_gen (fun _ => Name) t) (v : interp_flat_type_gen var t) + : Context + := match t return interp_flat_type_gen (fun _ => Name) t -> interp_flat_type_gen var t -> Context with + | Syntax.Tbase t => fun n v => extendb ctx n v + | Prod A B => fun n v + => let ctx := @extend ctx A (fst n) (fst v) in + let ctx := @extend ctx B (snd n) (snd v) in + ctx + end n v. + + Fixpoint remove (ctx : Context) {t : flat_type} + (n : interp_flat_type_gen (fun _ => Name) t) + : Context + := match t return interp_flat_type_gen (fun _ => Name) t -> Context with + | Syntax.Tbase t => fun n => removeb ctx n t + | Prod A B => fun n + => let ctx := @remove ctx A (fst n) in + let ctx := @remove ctx B (snd n) in + ctx + end n. + + Definition lookup (ctx : Context) {t} + : interp_flat_type_gen (fun _ => Name) t -> option (interp_flat_type_gen var t) + := smart_interp_flat_map + base_type_code + (g := fun t => option (interp_flat_type_gen var t)) + (fun t v => lookupb ctx v) + (fun A B x y => match x, y with + | Some x', Some y' => Some (x', y')%core + | _, _ => None + end). + + Section wf. + Fixpoint wff (ctx : Context) {t} (e : exprf t) : option pointed_Prop + := match e with + | Const _ x => Some trivial + | Var t n => match lookupb ctx n t return bool with + | Some _ => true + | None => false + end + | Op _ _ op args => @wff ctx _ args + | LetIn _ n ex _ eC => @wff ctx _ ex /\ inject (forall v, prop_of_option (@wff (extend ctx n v) _ eC)) + | Pair _ ex _ ey => @wff ctx _ ex /\ @wff ctx _ ey + end%option_pointed_prop. + + Fixpoint wf (ctx : Context) {t} (e : expr t) : Prop + := match e with + | Return _ x => prop_of_option (wff ctx x) + | Abs src _ n f => forall v, @wf (extend ctx (t:=src) n v) _ f + end. + End wf. + + Section interp_gen. + Context (output_interp_flat_type : flat_type -> Type) + (interp_const : forall t, interp_flat_type t -> output_interp_flat_type t) + (interp_var : forall t, var t -> output_interp_flat_type t) + (interp_op : forall src dst, op src dst -> output_interp_flat_type src -> output_interp_flat_type dst) + (interp_let : forall (tx : flat_type) (ex : output_interp_flat_type tx) + tC (eC : interp_flat_type_gen var tx -> output_interp_flat_type tC), + output_interp_flat_type tC) + (interp_pair : forall (tx : flat_type) (ex : output_interp_flat_type tx) + (ty : flat_type) (ey : output_interp_flat_type ty), + output_interp_flat_type (Prod tx ty)). + + Fixpoint interp_genf (ctx : Context) {t} (e : exprf t) + : prop_of_option (wff ctx e) -> output_interp_flat_type t + := match e in exprf t return prop_of_option (wff ctx e) -> output_interp_flat_type t with + | Const _ x => fun _ => interp_const _ x + | Var t' x => match lookupb ctx x t' as v + return prop_of_option (match v return bool with + | Some _ => true + | None => false + end) + -> output_interp_flat_type t' + with + | Some v => fun _ => interp_var _ v + | None => fun bad => match bad : False with end + end + | Op _ _ op args => fun good => @interp_op _ _ op (@interp_genf ctx _ args good) + | LetIn _ n ex _ eC => fun good => let goodxC := proj1 (@prop_of_option_and _ _) good in + let x := @interp_genf ctx _ ex (proj1 goodxC) in + interp_let _ x _ (fun x => @interp_genf (extend ctx n x) _ eC (proj2 goodxC x)) + | Pair _ ex _ ey => fun good => let goodxy := proj1 (@prop_of_option_and _ _) good in + let x := @interp_genf ctx _ ex (proj1 goodxy) in + let y := @interp_genf ctx _ ey (proj2 goodxy) in + interp_pair _ x _ y + end. + End interp_gen. + End with_context. + + Section with_val_context. + Context (Context : Context Name interp_base_type) + (interp_op : forall src dst, op src dst -> interp_flat_type src -> interp_flat_type dst). + Definition interpf + : forall (ctx : Context) {t} (e : exprf t), + prop_of_option (wff ctx e) -> interp_flat_type t + := @interp_genf + interp_base_type Context interp_flat_type + (fun _ x => x) (fun _ x => x) interp_op (fun _ y _ f => let x := y in f x) + (fun _ x _ y => (x, y)%core). + + Fixpoint interp (ctx : Context) {t} (e : expr t) + : wf ctx e -> interp_type t + := match e in expr t return wf ctx e -> interp_type t with + | Return _ x => interpf ctx x + | Abs _ _ n f => fun good v => @interp _ _ f (good v) + end. + End with_val_context. + End expr_param. + End language. +End Named. + +Global Arguments Const {_ _ _ _ _} _. +Global Arguments Var {_ _ _ _ _} _. +Global Arguments SmartVar {_ _ _ _ _} _. +Global Arguments SmartConst {_ _ _ _ _} _. +Global Arguments Op {_ _ _ _ _ _} _ _. +Global Arguments LetIn {_ _ _ _} _ {_} _ {_} _. +Global Arguments Pair {_ _ _ _ _} _ {_} _. +Global Arguments Return {_ _ _ _ _} _. +Global Arguments Abs {_ _ _ _ _ _} _ _. +Global Arguments extend {_ _ _ _} ctx {_} _ _. +Global Arguments remove {_ _ _ _} ctx {_} _. +Global Arguments lookup {_ _ _ _} ctx {_} _, {_ _ _ _} ctx _ _. +Global Arguments wff {_ _ _ _ _ _} ctx {t} _. +Global Arguments wf {_ _ _ _ _ _} ctx {t} _. +Global Arguments interp_genf {_ _ _ _ var _} _ _ _ _ _ _ {ctx t} _ _. +Global Arguments interpf {_ _ _ _ _ interp_op ctx t} _ _. +Global Arguments interp {_ _ _ _ _ interp_op ctx t} _ _. + +Notation "'slet' x := A 'in' b" := (LetIn _ x A%nexpr b%nexpr) : nexpr_scope. +Notation "'λn' x .. y , t" := (Abs x .. (Abs y t%nexpr) .. ) : nexpr_scope. +Notation "( x , y , .. , z )" := (Pair .. (Pair x%nexpr y%nexpr) .. z%nexpr) : nexpr_scope. diff --git a/src/Reflection/TestCase.v b/src/Reflection/TestCase.v index 5e671beb8..31ca20fec 100644 --- a/src/Reflection/TestCase.v +++ b/src/Reflection/TestCase.v @@ -1,3 +1,7 @@ +Require Import Coq.PArith.BinPos Coq.Lists.List. +Require Import Crypto.Reflection.Named.Syntax. +Require Import Crypto.Reflection.Named.Compile. +Require Import Crypto.Reflection.Named.RegisterAssign. Require Import Crypto.Reflection.Syntax. Require Export Crypto.Reflection.Reify. Require Import Crypto.Reflection.InputSyntax. @@ -77,7 +81,7 @@ Abort. Definition example_expr : Syntax.Expr base_type interp_base_type op (Arrow Tnat (Arrow Tnat (Tflat _ tnat))). Proof. - let x := Reify (fun z w => let x := 1 in let y := 1 in (let a := 1 in let '(c, d) := (2, 3) in a + x + (x + x) + (x + x) - (x + x) - a + c + d) + y + z + w)%nat in + let x := Reify (fun z w => let unused := 1 + 1 in let x := 1 in let y := 1 in (let a := 1 in let '(c, d) := (2, 3) in a + x + (x + x) + (x + x) - (x + x) - a + c + d) + y + z + w)%nat in exact x. Defined. @@ -144,3 +148,12 @@ End cse. Definition example_expr_simplified := Eval vm_compute in InlineConst (Linearize example_expr). Compute CSE example_expr_simplified. + +Definition example_expr_compiled + := Eval vm_compute in + match Named.Compile.compile (example_expr_simplified _) (List.map Pos.of_nat (seq 1 20)) as v return match v with Some _ => _ | _ => _ end with + | Some v => v + | None => True + end. + +Compute register_reassign Pos.eqb empty empty example_expr_compiled (Some 1%positive :: Some 2%positive :: None :: List.map (@Some _) (List.map Pos.of_nat (seq 3 20))). diff --git a/src/Specific/FancyMachine256/Barrett.v b/src/Specific/FancyMachine256/Barrett.v index 1683522e3..c96fcc37f 100644 --- a/src/Specific/FancyMachine256/Barrett.v +++ b/src/Specific/FancyMachine256/Barrett.v @@ -20,11 +20,6 @@ Section expression. Context (H : 0 <= m < 2^256). Let H' : 0 <= 250 <= 256. omega. Qed. Let H'' : 0 < 250. omega. Qed. - Let props' := ZLikeProperties_of_ArchitectureBoundedOps ops m H 250 H' H''. - Let ops' := (ZLikeOps_of_ArchitectureBoundedOps ops m 250). - Local Existing Instances props' ops'. - Local Notation fst' := (@fst fancy_machine.W fancy_machine.W). - Local Notation snd' := (@snd fancy_machine.W fancy_machine.W). Local Notation SmallT := (@ZBounded.SmallT (2 ^ 256) (2 ^ 250) m (@ZLikeOps_of_ArchitectureBoundedOps 128 ops m _)). Definition ldi' : load_immediate SmallT := _. @@ -38,21 +33,29 @@ Section expression. rewrite μ_good; apply μ_range. Qed. - Definition pre_f v - := (@barrett_reduce m b k μ offset m_pos base_pos μ_good offset_nonneg k_big_enough m_small m_large ops' props' μ' I μ'_eq (fst' v, snd' v)). + Let props' + ldi_modulus ldi_0 Hldi_modulus Hldi_0 + := ZLikeProperties_of_ArchitectureBoundedOps_Factored ops m ldi_modulus ldi_0 Hldi_modulus Hldi_0 H 250 H' H''. + + Definition pre_f' ldi_modulus ldi_0 ldi_μ Hldi_modulus Hldi_0 (Hldi_μ : ldi_μ = ldi' μ) + := (fun v => (@barrett_reduce m b k μ offset m_pos base_pos μ_good offset_nonneg k_big_enough m_small m_large _ (props' ldi_modulus ldi_0 Hldi_modulus Hldi_0) ldi_μ I (eq_trans (f_equal _ Hldi_μ) μ'_eq) (fst v, snd v))). + + Definition pre_f := pre_f' _ _ _ eq_refl eq_refl eq_refl. Local Arguments μ' / . Local Arguments ldi' / . + Local Arguments DoubleBounded.mul_double / . + Local Opaque Let_In Let_In_pf. Definition expression' := Eval simpl in - (fun v => proj1_sig (pre_f v)). + (fun v => pflet ldi_modulus, Hldi_modulus := fancy_machine.ldi m in + pflet ldi_μ, Hldi_μ := fancy_machine.ldi μ in + pflet ldi_0, Hldi_0 := fancy_machine.ldi 0 in + proj1_sig (pre_f' ldi_modulus ldi_0 ldi_μ Hldi_modulus Hldi_0 Hldi_μ v)). + Local Transparent Let_In Let_In_pf. Definition expression - := Eval cbv beta iota delta [expression' fst snd] in - fun v => let RegMod := fancy_machine.ldi m in - let RegMu := fancy_machine.ldi μ in - let RegZero := fancy_machine.ldi 0 in - expression' v. + := Eval cbv beta iota delta [expression' fst snd Let_In Let_In_pf] in expression'. Definition expression_eq v (H : 0 <= _ < _) : fancy_machine.decode (expression v) = _ := proj1 (proj2_sig (pre_f v) H). @@ -69,60 +72,88 @@ Section reflected. Definition rexpression_simple := Eval vm_compute in rexpression. + (*Compute DefaultRegisters rexpression_simple.*) + + Definition registers + := [RegMod; RegMuLow; x; xHigh; RegMod; RegMuLow; RegZero; tmp; q; qHigh; scratch+3; + SpecialCarryBit; q; SpecialCarryBit; qHigh; scratch+3; SpecialCarryBit; q; SpecialCarryBit; qHigh; tmp; + scratch+3; SpecialCarryBit; tmp; scratch+3; SpecialCarryBit; tmp; SpecialCarryBit; tmp; q; out]. + + Definition compiled_syntax + := Eval lazy in AssembleSyntax rexpression_simple registers. + Context (m μ : Z) (props : fancy_machine.arithmetic ops). Let result (v : tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple m μ (fst v) (snd v). + Let assembled_result (v : tuple fancy_machine.W 2) : fancy_machine.W := Core.Interp compiled_syntax m μ (fst v) (snd v). Theorem sanity : result = expression ops m μ. Proof. reflexivity. Qed. - Theorem correctness - (b : Z := 2) - (k : Z := 253) - (offset : Z := 3) - (H0 : 0 < m) - (H1 : μ = b^(2 * k) / m) - (H2 : 3 * m <= b^(k + offset)) - (H3 : b^(k - offset) <= m + 1) - (H4 : 0 <= m < 2^(k + offset)) - (H5 : 0 <= b^(2 * k) / m < b^(k + offset)) - (v : tuple fancy_machine.W 2) - (H6 : 0 <= decode v < b^(2 * k)) - : fancy_machine.decode (result v) = decode v mod m. + Theorem assembled_sanity : assembled_result = expression ops m μ. Proof. - rewrite sanity; destruct v. - apply expression_eq; assumption. + reflexivity. Qed. -End reflected. -Definition compiled_syntax - := Eval vm_compute in - (fun ops => AssembleSyntax ops (rexpression_simple _) (@RegMod :: @RegMuLow :: nil)%list). + Section correctness. + Let b : Z := 2. + Let k : Z := 253. + Let offset : Z := 3. + Context (H0 : 0 < m) + (H1 : μ = b^(2 * k) / m) + (H2 : 3 * m <= b^(k + offset)) + (H3 : b^(k - offset) <= m + 1) + (H4 : 0 <= m < 2^(k + offset)) + (H5 : 0 <= b^(2 * k) / m < b^(k + offset)) + (v : tuple fancy_machine.W 2) + (H6 : 0 <= decode v < b^(2 * k)). + Theorem correctness : fancy_machine.decode (result v) = decode v mod m. + Proof. + rewrite sanity; destruct v. + apply expression_eq; assumption. + Qed. + Theorem assembled_correctness : fancy_machine.decode (assembled_result v) = decode v mod m. + Proof. + rewrite assembled_sanity; destruct v. + apply expression_eq; assumption. + Qed. + End correctness. +End reflected. Print compiled_syntax. (* compiled_syntax = -fun (_ : fancy_machine.instructions (2 * 128)) (var : base_type -> Type) => -λ x x0 : var TW, -c.Rshi(x1, x0, x, 250), -c.Mul128(x2, c.UpperHalf(x1), c.UpperHalf(RegMuLow)), -c.Mul128(x3, c.UpperHalf(x1), c.LowerHalf(RegMuLow)), -c.Mul128(x4, c.LowerHalf(x1), c.LowerHalf(RegMuLow)), -c.Add(x6, x4, c.LeftShifted{x3, 128}), -c.Addc(x8, x2, c.RightShifted{x3, 128}), -c.Mul128(x9, c.UpperHalf(RegMuLow), c.LowerHalf(x1)), -c.Add(_, x6, c.LeftShifted{x9, 128}), -c.Addc(x13, x8, c.RightShifted{x9, 128}), -c.Mul128(x14, c.LowerHalf(x13), c.LowerHalf(RegMod)), -c.Mul128(x15, c.UpperHalf(x13), c.LowerHalf(RegMod)), -c.Add(x17, x14, c.LeftShifted{x15, 128}), -c.Mul128(x18, c.UpperHalf(RegMod), c.LowerHalf(x13)), -c.Add(x20, x17, c.LeftShifted{x18, 128}), -c.Sub(x22, x, x20), -c.Addm(x23, x22, RegZero), -c.Addm(x24, x23, RegZero), -Return x24 - : fancy_machine.instructions (2 * 128) -> forall var : base_type -> Type, syntax +fun ops : fancy_machine.instructions (2 * 128) => +λn RegMod RegMuLow x xHigh, +slet RegMod := RegMod in +slet RegMuLow := RegMuLow in +slet RegZero := ldi 0 in +c.Rshi(tmp, xHigh, x, 250), +c.Mul128(q, c.LowerHalf(tmp), c.LowerHalf(RegMuLow)), +c.Mul128(qHigh, c.UpperHalf(tmp), c.UpperHalf(RegMuLow)), +c.Mul128(scratch+3, c.UpperHalf(tmp), c.LowerHalf(RegMuLow)), +c.Add(q, q, c.LeftShifted{scratch+3, 128}), +c.Addc(qHigh, qHigh, c.RightShifted{scratch+3, 128}), +c.Mul128(scratch+3, c.UpperHalf(RegMuLow), c.LowerHalf(tmp)), +c.Add(q, q, c.LeftShifted{scratch+3, 128}), +c.Addc(qHigh, qHigh, c.RightShifted{scratch+3, 128}), +c.Mul128(tmp, c.LowerHalf(qHigh), c.LowerHalf(RegMod)), +c.Mul128(scratch+3, c.UpperHalf(qHigh), c.LowerHalf(RegMod)), +c.Add(tmp, tmp, c.LeftShifted{scratch+3, 128}), +c.Mul128(scratch+3, c.UpperHalf(RegMod), c.LowerHalf(qHigh)), +c.Add(tmp, tmp, c.LeftShifted{scratch+3, 128}), +c.Sub(tmp, x, tmp), +c.Addm(q, tmp, RegZero), +c.Addm(out, q, RegZero), +Return out + : forall ops : fancy_machine.instructions (2 * 128), + expr base_type + (fun v : base_type => + match v with + | TZ => Z + | Tbool => bool + | TW => let (W, _, _, _, _, _, _, _, _, _, _, _, _, _) := ops in W + end) op Register (TZ -> TZ -> TW -> TW -> Tbase TW)%ctype *) diff --git a/src/Specific/FancyMachine256/Core.v b/src/Specific/FancyMachine256/Core.v index 2392f7ede..64372b8f9 100644 --- a/src/Specific/FancyMachine256/Core.v +++ b/src/Specific/FancyMachine256/Core.v @@ -1,11 +1,17 @@ (** * A Fancy Machine with 256-bit registers *) Require Import Coq.Classes.RelationClasses Coq.Classes.Morphisms. -Require Export Coq.ZArith.ZArith. +Require Import Coq.PArith.BinPos Coq.micromega.Psatz. +Require Export Coq.ZArith.ZArith Coq.Lists.List. +Require Import Crypto.Util.Decidable. Require Export Crypto.BoundedArithmetic.Interface. Require Export Crypto.BoundedArithmetic.ArchitectureToZLike. Require Export Crypto.BoundedArithmetic.ArchitectureToZLikeProofs. Require Export Crypto.Util.Tuple. Require Import Crypto.Util.Option Crypto.Util.Sigma Crypto.Util.Prod. +Require Export Crypto.Reflection.Named.Syntax. +Require Import Crypto.Reflection.Named.DeadCodeElimination. +Require Import Crypto.Reflection.CountLets. +Require Import Crypto.Reflection.Named.ContextOn. Require Export Crypto.Reflection.Syntax. Require Import Crypto.Reflection.Linearize. Require Import Crypto.Reflection.Inline. @@ -13,6 +19,9 @@ Require Import Crypto.Reflection.CommonSubexpressionElimination. Require Export Crypto.Reflection.Reify. Require Export Crypto.Util.ZUtil. Require Export Crypto.Util.Notations. +Require Import Crypto.Util.ListUtil. +Require Export Crypto.Util.LetIn. +Export ListNotations. Open Scope Z_scope. Local Notation eta x := (fst x, snd x). @@ -25,6 +34,8 @@ Section reflection. Local Set Boolean Equality Schemes. Local Set Decidable Equality Schemes. Inductive base_type := TZ | Tbool | TW. + Global Instance dec_base_type : DecidableRel (@eq base_type) + := base_type_eq_dec. Definition interp_base_type (v : base_type) : Type := match v with | TZ => Z @@ -94,6 +105,27 @@ Section reflection. end. Definition CSE {t} e := @CSE base_type SConstT op_code base_type_beq SConstT_beq op_code_beq internal_base_type_dec_bl interp_base_type op symbolicify_const symbolicify_op t e (fun _ => nil). + + Inductive inline_option := opt_inline | opt_default | opt_noinline. + + Definition postprocess var t (e : @exprf base_type interp_base_type op var t) + : @inline_directive base_type interp_base_type op var t + := let opt : inline_option + := match e with + | Op _ _ OPshl _ => opt_inline + | Op _ _ OPshr _ => opt_inline + | _ => opt_default + end in + match opt with + | opt_noinline => no_inline e + | opt_default => default_inline e + | opt_inline => match t as t' return exprf _ _ _ t' -> inline_directive t' with + | Tbase _ => fun e => inline e + | _ => fun e => default_inline e + end e + end. + + Definition Inline {t} e := @InlineConstGen base_type interp_base_type op postprocess t e. End reflection. Ltac base_reify_op op op_head ::= @@ -121,7 +153,7 @@ Ltac base_reify_type T ::= Ltac Reify' e := Reify.Reify' base_type (interp_base_type _) op e. Ltac Reify e := let v := Reify.Reify base_type (interp_base_type _) op e in - constr:(CSE _ (InlineConst (Linearize v))). + constr:(Inline _ ((*CSE _*) (InlineConst (Linearize v)))). (*Ltac Reify_rhs := Reify.Reify_rhs base_type (interp_base_type _) op (interp_op _).*) (** ** Raw Syntax Trees *) @@ -132,226 +164,172 @@ Ltac Reify e := [string] identifiers and using them for pretty-printing... It might also be possible to verify this layer, too, by adding a partial interpretation function... *) -Section syn. - Context {var : base_type -> Type}. - Inductive syntax := - | RegPInv - | RegMod - | RegMuLow - | RegZero - | cConstZ : Z -> syntax - | cConstBool : bool -> syntax - | cLowerHalf : syntax -> syntax - | cUpperHalf : syntax -> syntax - | cLeftShifted : syntax -> Z -> syntax - | cRightShifted : syntax -> Z -> syntax - | cVar : var TW -> syntax - | cVarC : var Tbool -> syntax - | cBind : syntax -> (var TW -> syntax) -> syntax - | cBindCarry : syntax -> (var Tbool -> var TW -> syntax) -> syntax - | cMul128 : syntax -> syntax -> syntax - | cRshi : syntax -> syntax -> Z -> syntax - | cSelc : var Tbool -> syntax -> syntax -> syntax - | cAddc : var Tbool -> syntax -> syntax -> syntax - | cAddm : syntax -> syntax -> syntax - | cAdd : syntax -> syntax -> syntax - | cSub : syntax -> syntax -> syntax - | cPair : syntax -> syntax -> syntax - | cAbs {t} : (var t -> syntax) -> syntax - | cINVALID {T} (_ : T). -End syn. -Notation "'Return' x" := (cVar x) (at level 200). -Notation "'c.Mul128' ( x , A , B ) , b" := - (cBind (cMul128 A B) (fun x => b)) - (at level 200, b at level 200, format "'c.Mul128' ( x , A , B ) , '//' b"). +Local Set Decidable Equality Schemes. +Local Set Boolean Equality Schemes. + +Inductive Register := +| RegPInv | RegMod | RegMuLow | RegZero +| y | t1 | t2 | lo | hi | out | src1 | src2 | tmp | q | qHigh | x | xHigh +| SpecialCarryBit +| scratch | scratchplus (n : nat). + +Notation "'scratch+' n" := (scratchplus n) (format "'scratch+' n", at level 10). + +Definition pos_of_Register (r : Register) := + match r with + | RegPInv => 1 + | RegMod => 2 + | RegMuLow => 3 + | RegZero => 4 + | y => 5 + | t1 => 6 + | t2 => 7 + | lo => 8 + | hi => 9 + | out => 10 + | src1 => 11 + | src2 => 12 + | tmp => 13 + | q => 14 + | qHigh => 15 + | x => 16 + | xHigh => 17 + | SpecialCarryBit => 18 + | scratch => 19 + | scratchplus n => 19 + Pos.of_nat (S n) + end%positive. + +Lemma pos_of_Register_inj x y : pos_of_Register x = pos_of_Register y -> x = y. +Proof. + unfold pos_of_Register; repeat break_match; subst; + try rewrite Pos.add_cancel_l; try rewrite Nat2Pos.inj_iff; + try solve [ simpl; congruence | intros; exfalso; lia ]. +Qed. + +Global Instance RegisterContext {var : base_type -> Type} : Context Register var + := ContextOn pos_of_Register (RegisterAssign.pos_context var). + +Definition syntax {ops : fancy_machine.instructions (2 * 128)} + := Named.expr base_type (interp_base_type ops) op Register. + +Class wf_empty {ops} {var} {t} (e : Named.expr base_type (interp_base_type ops) op Register t) + := mk_wf_empty : @Named.wf base_type (interp_base_type ops) op Register var _ empty t e. +Global Hint Extern 0 (wf_empty _) => vm_compute; intros; constructor : typeclass_instances. + +(** Assemble a well-typed easily interpretable expression into a + syntax tree we can use for pretty-printing. *) +Section assemble. + Context {ops : fancy_machine.instructions (2 * 128)}. + + Definition AssembleSyntax' {t} (e : Expr base_type (interp_base_type _) op t) (ls : list Register) + : option (syntax t) + := CompileAndEliminateDeadCode e ls. + Definition AssembleSyntax {t} e ls (res := @AssembleSyntax' t e ls) + := match res return match res with None => _ | _ => _ end with + | Some v => v + | None => I + end. + + Definition dummy_registers (n : nat) : list Register + := List.map scratchplus (seq 0 n). + Definition DefaultRegisters {t} (e : Expr base_type (interp_base_type _) op t) : list Register + := dummy_registers (CountBinders e). + + Definition DefaultAssembleSyntax {t} e := @AssembleSyntax t e (DefaultRegisters e). + + Definition Interp {t} e {wf : wf_empty e} + := @Named.interp base_type (interp_base_type _) op Register _ (interp_op _) empty t e wf. +End assemble. + +Export Reflection.Named.Syntax. +Open Scope nexpr_scope. +Open Scope ctype_scope. +Open Scope type_scope. +Open Scope core_scope. + +Notation Return x := (Var x). +Notation ldi z := (Op OPldi (Const z%Z)). +Notation "'slet' x := A 'in' b" := (LetIn _ x (Op OPldi (Var A%nexpr)) b%nexpr) : nexpr_scope. +Notation "'c.Rshi' ( x , A , B , C ) , b" := + (LetIn _ x (Op OPshrd (Pair (Pair (Var A) (Var B)) (Const C%Z))) b) + (at level 200, b at level 200, format "'c.Rshi' ( x , A , B , C ) , '//' b"). +Notation "'c.Mul128' ( x , 'c.UpperHalf' ( A ) , 'c.UpperHalf' ( B ) ) , b" := + (LetIn _ x (Op OPmulhwhh (Pair (Var A) (Var B))) b) + (at level 200, b at level 200, format "'c.Mul128' ( x , 'c.UpperHalf' ( A ) , 'c.UpperHalf' ( B ) ) , '//' b"). +Notation "'c.Mul128' ( x , 'c.UpperHalf' ( A ) , 'c.LowerHalf' ( B ) ) , b" := + (LetIn _ x (Op OPmulhwhl (Pair (Var A) (Var B))) b) + (at level 200, b at level 200, format "'c.Mul128' ( x , 'c.UpperHalf' ( A ) , 'c.LowerHalf' ( B ) ) , '//' b"). +Notation "'c.Mul128' ( x , 'c.LowerHalf' ( A ) , 'c.LowerHalf' ( B ) ) , b" := + (LetIn _ x (Op OPmulhwll (Pair (Var A) (Var B))) b) + (at level 200, b at level 200, format "'c.Mul128' ( x , 'c.LowerHalf' ( A ) , 'c.LowerHalf' ( B ) ) , '//' b"). +Notation "'c.LeftShifted' { x , v }" := + (Op OPshl (Pair (Var x) (Const v%Z))) + (at level 200, format "'c.LeftShifted' { x , v }"). +Notation "'c.RightShifted' { x , v }" := + (Op OPshr (Pair (Var x) (Const v%Z))) + (at level 200, format "'c.RightShifted' { x , v }"). + +Notation "'c.Add' ( x , A , B ) , b" := + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair A B) (Const false))) b) + (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' b"). Notation "'c.Add' ( x , A , B ) , b" := - (cBindCarry (cAdd A B) (fun _ x => b)) + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair (Var A) B) (Const false))) b) (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' b"). Notation "'c.Add' ( x , A , B ) , b" := - (cBindCarry (cAdd (cVar A) B) (fun _ x => b)) + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair A (Var B)) (Const false))) b) (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" := - (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c A1 B1) (fun _ x1 => b))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" := - (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun _ x1 => b))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" := - (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c A1 B1) (fun _ x1 => b))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" := - (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun _ x1 => b))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" := - (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c A1 B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b)))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" := - (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c A1 B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b)))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" := - (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b)))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" := - (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b)))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b"). -Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" := - (cBindCarry (cAdd (cVar A) (cVar B)) (fun c x => cBindCarry (cAddc c (cVar A1) (cVar B1)) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b)))) - (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b"). +Notation "'c.Add' ( x , A , B ) , b" := + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair (Var A) (Var B)) (Const false))) b) + (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' b"). +Notation "'c.Addc' ( x , A , B ) , b" := + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair A B) (Var SpecialCarryBit))) b) + (at level 200, b at level 200, format "'c.Addc' ( x , A , B ) , '//' b"). +Notation "'c.Addc' ( x , A , B ) , b" := + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair (Var A) B) (Var SpecialCarryBit))) b) + (at level 200, b at level 200, format "'c.Addc' ( x , A , B ) , '//' b"). +Notation "'c.Addc' ( x , A , B ) , b" := + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair A (Var B)) (Var SpecialCarryBit))) b) + (at level 200, b at level 200, format "'c.Addc' ( x , A , B ) , '//' b"). +Notation "'c.Addc' ( x , A , B ) , b" := + (LetIn _ (pair SpecialCarryBit x) (Op OPadc (Pair (Pair (Var A) (Var B)) (Var SpecialCarryBit))) b) + (at level 200, b at level 200, format "'c.Addc' ( x , A , B ) , '//' b"). Notation "'c.Sub' ( x , A , B ) , b" := - (cBindCarry (cSub A B) (fun _ x => b)) + (LetIn _ (pair SpecialCarryBit x) (Op OPsubc (Pair (Pair A B) (Const false))) b) (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b"). Notation "'c.Sub' ( x , A , B ) , b" := - (cBindCarry (cSub (cVar A) B) (fun _ x => b)) + (LetIn _ (pair SpecialCarryBit x) (Op OPsubc (Pair (Pair (Var A) B) (Const false))) b) (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b"). Notation "'c.Sub' ( x , A , B ) , b" := - (cBindCarry (cSub A (cVar B)) (fun _ x => b)) + (LetIn _ (pair SpecialCarryBit x) (Op OPsubc (Pair (Pair A (Var B)) (Const false))) b) (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b"). Notation "'c.Sub' ( x , A , B ) , b" := - (cBindCarry (cSub (cVar A) (cVar B)) (fun _ x => b)) + (LetIn _ (pair SpecialCarryBit x) (Op OPsubc (Pair (Pair (Var A) (Var B)) (Const false))) b) (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b"). Notation "'c.Addm' ( x , A , B ) , b" := - (cBind (cAddm A B) (fun x => b)) + (LetIn _ x (Op OPaddm (Pair (Pair A B) (Var RegMod))) b) (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b"). Notation "'c.Addm' ( x , A , B ) , b" := - (cBind (cAddm A (cVar B)) (fun x => b)) + (LetIn _ x (Op OPaddm (Pair (Pair (Var A) B) (Var RegMod))) b) (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b"). Notation "'c.Addm' ( x , A , B ) , b" := - (cBind (cAddm (cVar A) B) (fun x => b)) + (LetIn _ x (Op OPaddm (Pair (Pair A (Var B)) (Var RegMod))) b) (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b"). Notation "'c.Addm' ( x , A , B ) , b" := - (cBind (cAddm (cVar A) (cVar B)) (fun x => b)) + (LetIn _ x (Op OPaddm (Pair (Pair (Var A) (Var B)) (Var RegMod))) b) (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b"). -Notation "'c.Rshi' ( x , A , B , C ) , b" := - (cBind (cRshi (cVar A) (cVar B) C) (fun x => b)) - (at level 200, b at level 200, format "'c.Rshi' ( x , A , B , C ) , '//' b"). - -Notation "'c.LowerHalf' ( x )" := - (cLowerHalf x) - (at level 200, format "'c.LowerHalf' ( x )"). -Notation "'c.LowerHalf' ( x )" := - (cLowerHalf (cVar x)) - (at level 200, format "'c.LowerHalf' ( x )"). -Notation "'c.UpperHalf' ( x )" := - (cUpperHalf x) - (at level 200, format "'c.UpperHalf' ( x )"). -Notation "'c.UpperHalf' ( x )" := - (cUpperHalf (cVar x)) - (at level 200, format "'c.UpperHalf' ( x )"). -Notation "'c.LeftShifted' { x , v }" := - (cLeftShifted x v) - (at level 200, format "'c.LeftShifted' { x , v }"). -Notation "'c.LeftShifted' { x , v }" := - (cLeftShifted (cVar x) v) - (at level 200, format "'c.LeftShifted' { x , v }"). -Notation "'c.RightShifted' { x , v }" := - (cRightShifted x v) - (at level 200, format "'c.RightShifted' { x , v }"). -Notation "'c.RightShifted' { x , v }" := - (cRightShifted (cVar x) v) - (at level 200, format "'c.RightShifted' { x , v }"). -Notation "'λ' x .. y , t" := (cAbs (fun x => .. (cAbs (fun y => t)) ..)) - (at level 200, x binder, y binder, right associativity). - -Definition Syntax := forall var, @syntax var. - -(** Assemble a well-typed easily interpretable expression into a - syntax tree we can use for pretty-printing. *) -Section assemble. - Context (ops : fancy_machine.instructions (2 * 128)). - - Section with_var. - Context {var : base_type -> Type}. - - Fixpoint assemble_syntax_const - {t} - : interp_flat_type_gen (interp_base_type _) t -> @syntax var - := match t return interp_flat_type_gen (interp_base_type _) t -> @syntax var with - | Tbase TZ => cConstZ - | Tbase Tbool => cConstBool - | Tbase t => fun _ => cINVALID t - | Prod A B => fun xy => cPair (@assemble_syntax_const A (fst xy)) - (@assemble_syntax_const B (snd xy)) - end. - - Definition assemble_syntaxf_step - (assemble_syntaxf : forall {t} (v : @Syntax.exprf base_type (interp_base_type _) op (fun _ => @syntax var) t), @syntax var) - {t} (v : @Syntax.exprf base_type (interp_base_type _) op (fun _ => @syntax var) t) : @syntax var. - Proof. - refine match v return @syntax var with - | Syntax.Const t x => assemble_syntax_const x - | Syntax.Var _ x => x - | Syntax.Op t1 tR op args - => let v := @assemble_syntaxf t1 args in - (* handle both associativities of pairs in 3-ary - operators, in case we ever change the - associativity *) - match op, v with - | OPldi , cConstZ 0 => RegZero - | OPldi , cConstZ v => cINVALID v - | OPldi , RegZero => RegZero - | OPldi , RegMod => RegMod - | OPldi , RegMuLow => RegMuLow - | OPldi , RegPInv => RegPInv - | OPshrd , cPair x (cPair y (cConstZ n)) => cRshi x y n - | OPshrd , cPair (cPair x y) (cConstZ n) => cRshi x y n - | OPshl , cPair w (cConstZ n) => cLeftShifted w n - | OPshr , cPair w (cConstZ n) => cRightShifted w n - | OPmkl , _ => cINVALID op - | OPadc , cPair (cPair x y) (cVarC c) => cAddc c x y - | OPadc , cPair x (cPair y (cVarC c)) => cAddc c x y - | OPadc , cPair (cPair x y) (cConstBool false) => cAdd x y - | OPadc , cPair x (cPair y (cConstBool false)) => cAdd x y - | OPsubc , cPair (cPair x y) (cConstBool false) => cSub x y - | OPsubc , cPair x (cPair y (cConstBool false)) => cSub x y - | OPmulhwll, cPair x y => cMul128 (cLowerHalf x) (cLowerHalf y) - | OPmulhwhl, cPair x y => cMul128 (cUpperHalf x) (cLowerHalf y) - | OPmulhwhh, cPair x y => cMul128 (cUpperHalf x) (cUpperHalf y) - | OPselc , cPair (cVarC c) (cPair x y) => cSelc c x y - | OPselc , cPair (cPair (cVarC c) x) y => cSelc c x y - | OPaddm , cPair x (cPair y RegMod) => cAddm x y - | OPaddm , cPair (cPair x y) RegMod => cAddm x y - | _, _ => cINVALID op - end - | Syntax.LetIn tx ex _ eC - => let ex' := @assemble_syntaxf _ ex in - let eC' := fun x => @assemble_syntaxf _ (eC x) in - let special := match ex' with - | RegZero as ex'' | RegMuLow as ex'' | RegMod as ex'' | RegPInv as ex'' - | cUpperHalf _ as ex'' | cLowerHalf _ as ex'' - | cLeftShifted _ _ as ex'' - | cRightShifted _ _ as ex'' - => Some ex'' - | _ => None - end in - match special, tx return (interp_flat_type_gen _ tx -> _) -> _ with - | Some x, Tbase _ => fun eC' => eC' x - | _, Tbase TW - => fun eC' => cBind ex' (fun x => eC' (cVar x)) - | _, Prod (Tbase Tbool) (Tbase TW) - => fun eC' => cBindCarry ex' (fun c x => eC' (cVarC c, cVar x)) - | _, _ - => fun _ => cINVALID (fun x : Prop => x) - end eC' - | Syntax.Pair _ ex _ ey - => cPair (@assemble_syntaxf _ ex) (@assemble_syntaxf _ ey) - end. - Defined. - - Fixpoint assemble_syntaxf {t} v {struct v} : @syntax var - := @assemble_syntaxf_step (@assemble_syntaxf) t v. - Fixpoint assemble_syntax {t} (v : @Syntax.expr base_type (interp_base_type _) op (fun _ => @syntax var) t) (args : list (@syntax var)) {struct v} - : @syntax var - := match v, args return @syntax var with - | Syntax.Return _ x, _ => assemble_syntaxf x - | Syntax.Abs _ _ f, nil => cAbs (fun x => @assemble_syntax _ (f (cVar x)) args) - | Syntax.Abs _ _ f, cons v vs => @assemble_syntax _ (f v) vs - end. - End with_var. - - Definition AssembleSyntax {t} (v : Syntax.Expr _ _ _ t) (args : list Syntax) : Syntax - := fun var => @assemble_syntax var t (v _) (List.map (fun f => f var) args). -End assemble. +Notation "'c.Selc' ( x , A , B ) , b" := + (LetIn _ x (Op OPselc (Pair (Pair (Var SpecialCarryBit) A) B)) b) + (at level 200, b at level 200, format "'c.Selc' ( x , A , B ) , '//' b"). +Notation "'c.Selc' ( x , A , B ) , b" := + (LetIn _ x (Op OPselc (Pair (Pair (Var SpecialCarryBit) (Var A)) B)) b) + (at level 200, b at level 200, format "'c.Selc' ( x , A , B ) , '//' b"). +Notation "'c.Selc' ( x , A , B ) , b" := + (LetIn _ x (Op OPselc (Pair (Pair (Var SpecialCarryBit) A) (Var B))) b) + (at level 200, b at level 200, format "'c.Selc' ( x , A , B ) , '//' b"). +Notation "'c.Selc' ( x , A , B ) , b" := + (LetIn _ x (Op OPselc (Pair (Pair (Var SpecialCarryBit) (Var A)) (Var B))) b) + (at level 200, b at level 200, format "'c.Selc' ( x , A , B ) , '//' b"). diff --git a/src/Specific/FancyMachine256/Montgomery.v b/src/Specific/FancyMachine256/Montgomery.v index a9a50f773..bcb25ffd6 100644 --- a/src/Specific/FancyMachine256/Montgomery.v +++ b/src/Specific/FancyMachine256/Montgomery.v @@ -6,16 +6,21 @@ Section expression. Context (ops : fancy_machine.instructions (2 * 128)) (props : fancy_machine.arithmetic ops) (modulus : Z) (m' : Z) (Hm : modulus <> 0) (H : 0 <= modulus < 2^256) (Hm' : 0 <= m' < 2^256). Let H' : 0 <= 256 <= 256. omega. Qed. Let H'' : 0 < 256. omega. Qed. - Let props' := ZLikeProperties_of_ArchitectureBoundedOps ops modulus H 256 H' H''. - Let ops' := (ZLikeOps_of_ArchitectureBoundedOps ops modulus 256). - Local Notation fst' := (@fst fancy_machine.W fancy_machine.W). - Local Notation snd' := (@snd fancy_machine.W fancy_machine.W). Definition ldi' : load_immediate (@ZBounded.SmallT (2 ^ 256) (2 ^ 256) modulus (@ZLikeOps_of_ArchitectureBoundedOps 128 ops modulus 256)) := _. Let isldi : is_load_immediate ldi' := _. - Definition pre_f := (fun v => (reduce_via_partial (2^256) modulus (props := props') (ldi' m') I Hm (fst' v, snd' v))). - Definition f := (fun v => proj1_sig (pre_f v)). + Let props' + ldi_modulus ldi_0 Hldi_modulus Hldi_0 + := ZLikeProperties_of_ArchitectureBoundedOps_Factored ops modulus ldi_modulus ldi_0 Hldi_modulus Hldi_0 H 256 H' H''. + Definition pre_f' ldi_modulus ldi_0 Hldi_modulus Hldi_0 lm' + := (fun v => (reduce_via_partial (2^256) modulus (props := props' ldi_modulus ldi_0 Hldi_modulus Hldi_0) lm' I Hm (fst v, snd v))). + Definition pre_f := pre_f' _ _ eq_refl eq_refl (ldi' m'). + + Definition f := (fun v => pflet ldi_modulus, Hldi_modulus := ldi' modulus in + dlet lm' := ldi' m' in + pflet ldi_0, Hldi_0 := ldi' 0 in + proj1_sig (pre_f' ldi_modulus ldi_0 Hldi_modulus Hldi_0 lm' v)). Local Arguments proj1_sig _ _ !_ / . Local Arguments ZBounded.CarryAdd / . @@ -24,18 +29,18 @@ Section expression. Local Arguments ZLikeOps_of_ArchitectureBoundedOps / . Local Arguments ZBounded.DivBy_SmallBound / . Local Arguments f / . - Local Arguments pre_f / . + Local Arguments pre_f' / . Local Arguments ldi' / . Local Arguments reduce_via_partial / . + Local Arguments DoubleBounded.mul_double / . + Local Opaque Let_In Let_In_pf. Definition expression' := Eval simpl in f. + Local Transparent Let_In Let_In_pf. Definition expression - := Eval cbv beta delta [expression' fst snd] in - fun v => let RegMod := fancy_machine.ldi modulus in - let RegPInv := fancy_machine.ldi m' in - let RegZero := fancy_machine.ldi 0 in - expression' v. + := Eval cbv beta delta [expression' fst snd Let_In Let_In_pf] in expression'. + Definition expression_eq v : fancy_machine.decode (expression v) = _ := proj1 (proj2_sig (pre_f v) I). Definition expression_correct @@ -43,7 +48,7 @@ Section expression. v Hv : fancy_machine.decode (expression v) = _ - := @ZBounded.reduce_via_partial_correct (2^256) modulus _ props' (ldi' m') I Hm R' HR0 HR1 v I Hv. + := @ZBounded.reduce_via_partial_correct (2^256) modulus _ (props' _ _ eq_refl eq_refl) (ldi' m') I Hm R' HR0 HR1 (fst v, snd v) I Hv. End expression. Section reflected. @@ -57,62 +62,97 @@ Section reflected. Definition rexpression_simple := Eval vm_compute in rexpression. + (*Compute DefaultRegisters rexpression_simple.*) + + Definition registers + := [RegMod; RegPInv; lo; hi; RegMod; RegPInv; RegZero; y; t1; SpecialCarryBit; y; + t1; SpecialCarryBit; y; t1; t2; scratch+3; SpecialCarryBit; t1; SpecialCarryBit; t2; + scratch+3; SpecialCarryBit; t1; SpecialCarryBit; t2; SpecialCarryBit; lo; SpecialCarryBit; hi; y; + SpecialCarryBit; lo; lo]. + + Definition compiled_syntax + := Eval lazy in AssembleSyntax rexpression_simple registers. + Context (modulus m' : Z) (props : fancy_machine.arithmetic ops). Let result (v : tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple modulus m' (fst v) (snd v). + Let assembled_result (v : tuple fancy_machine.W 2) : fancy_machine.W := Core.Interp compiled_syntax modulus m' (fst v) (snd v). + Theorem sanity : result = expression ops modulus m'. Proof. reflexivity. Qed. + Theorem assembled_sanity : assembled_result = expression ops modulus m'. + Proof. + reflexivity. + Qed. + Local Infix "≡₂₅₆" := (Z.equiv_modulo (2^256)). Local Infix "≡" := (Z.equiv_modulo modulus). - Theorem correctness - R' (* modular inverse of 2^256 *) - (H0 : modulus <> 0) - (H1 : 0 <= modulus < 2^256) - (H2 : 0 <= m' < 2^256) - (H3 : 2^256 * R' ≡ 1) - (H4 : modulus * m' ≡₂₅₆ -1) - (v : tuple fancy_machine.W 2) - (H5 : 0 <= decode v <= 2^256 * modulus) - : fancy_machine.decode (result v) = (decode v * R') mod modulus. - Proof. - replace m' with (fancy_machine.decode (fancy_machine.ldi m')) in H4 - by (apply decode_load_immediate; trivial; exact _). - rewrite sanity; destruct v; apply expression_correct; assumption. - Qed. + Section correctness. + Context R' (* modular inverse of 2^256 *) + (H0 : modulus <> 0) + (H1 : 0 <= modulus < 2^256) + (H2 : 0 <= m' < 2^256) + (H3 : 2^256 * R' ≡ 1) + (H4 : modulus * m' ≡₂₅₆ -1) + (v : tuple fancy_machine.W 2) + (H5 : 0 <= decode v <= 2^256 * modulus). + Theorem correctness + : fancy_machine.decode (result v) = (decode v * R') mod modulus. + Proof. + replace m' with (fancy_machine.decode (fancy_machine.ldi m')) + in H4 + by (apply decode_load_immediate; trivial; exact _). + rewrite sanity; destruct v; apply expression_correct; assumption. + Qed. + Theorem assembled_correctness + : fancy_machine.decode (assembled_result v) = (decode v * R') mod modulus. + Proof. + replace m' with (fancy_machine.decode (fancy_machine.ldi m')) + in H4 + by (apply decode_load_immediate; trivial; exact _). + rewrite assembled_sanity; destruct v; apply expression_correct; assumption. + Qed. + End correctness. End reflected. -Definition compiled_syntax - := Eval vm_compute in - (fun ops => AssembleSyntax ops (rexpression_simple _) (@RegMod :: @RegPInv :: nil)%list). - Print compiled_syntax. (* compiled_syntax = -fun (_ : fancy_machine.instructions (2 * 128)) (var : base_type -> Type) => -λ x x0 : var TW, -c.Mul128(x1, c.LowerHalf(x), c.LowerHalf(RegPInv)), -c.Mul128(x2, c.UpperHalf(x), c.LowerHalf(RegPInv)), -c.Add(x4, x1, c.LeftShifted{x2, 128}), -c.Mul128(x5, c.UpperHalf(RegPInv), c.LowerHalf(x)), -c.Add(x7, x4, c.LeftShifted{x5, 128}), -c.Mul128(x8, c.UpperHalf(x7), c.UpperHalf(RegMod)), -c.Mul128(x9, c.UpperHalf(x7), c.LowerHalf(RegMod)), -c.Mul128(x10, c.LowerHalf(x7), c.LowerHalf(RegMod)), -c.Add(x12, x10, c.LeftShifted{x9, 128}), -c.Addc(x14, x8, c.RightShifted{x9, 128}), -c.Mul128(x15, c.UpperHalf(RegMod), c.LowerHalf(x7)), -c.Add(x17, x12, c.LeftShifted{x15, 128}), -c.Addc(x19, x14, c.RightShifted{x15, 128}), -c.Add(_, x, x17), -c.Addc(x23, x0, x19), -c.Selc(x24, RegMod, RegZero), -c.Sub(x26, x23, x24), -c.Addm(x27, x26, RegZero), -Return x27 - : fancy_machine.instructions (2 * 128) -> forall var : base_type -> Type, syntax +fun ops : fancy_machine.instructions (2 * 128) => +λn RegMod RegPInv lo hi, +slet RegMod := RegMod in +slet RegPInv := RegPInv in +slet RegZero := ldi 0 in +c.Mul128(y, c.LowerHalf(lo), c.LowerHalf(RegPInv)), +c.Mul128(t1, c.UpperHalf(lo), c.LowerHalf(RegPInv)), +c.Add(y, y, c.LeftShifted{t1, 128}), +c.Mul128(t1, c.UpperHalf(RegPInv), c.LowerHalf(lo)), +c.Add(y, y, c.LeftShifted{t1, 128}), +c.Mul128(t1, c.LowerHalf(y), c.LowerHalf(RegMod)), +c.Mul128(t2, c.UpperHalf(y), c.UpperHalf(RegMod)), +c.Mul128(scratch+3, c.UpperHalf(y), c.LowerHalf(RegMod)), +c.Add(t1, t1, c.LeftShifted{scratch+3, 128}), +c.Addc(t2, t2, c.RightShifted{scratch+3, 128}), +c.Mul128(scratch+3, c.UpperHalf(RegMod), c.LowerHalf(y)), +c.Add(t1, t1, c.LeftShifted{scratch+3, 128}), +c.Addc(t2, t2, c.RightShifted{scratch+3, 128}), +c.Add(lo, lo, t1), +c.Addc(hi, hi, t2), +c.Selc(y, RegMod, RegZero), +c.Sub(lo, hi, y), +c.Addm(lo, lo, RegZero), +Return lo + : forall ops : fancy_machine.instructions (2 * 128), + expr base_type + (fun v : base_type => + match v with + | TZ => Z + | Tbool => bool + | TW => let (W, _, _, _, _, _, _, _, _, _, _, _, _, _) := ops in W + end) op Register (TZ -> TZ -> TW -> TW -> Tbase TW)%ctype *) |