diff options
author | Jason Gross <jgross@mit.edu> | 2017-10-13 22:06:48 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2017-10-18 23:01:29 -0400 |
commit | 28144db80f4b4b1a852af2504b18d7cc0e43c4ba (patch) | |
tree | 0dd5a54d287e1949002a0e8d7794b3f545704695 | |
parent | 17a4dec48a6cfe9dc83dd1670dec6e700ce7cac0 (diff) |
Fix synthesis output record
The bounds checking on the reflective pipeline won't go through without
carries, so when synthesizing word-based operations, always carry.
-rw-r--r-- | src/Specific/Framework/OutputType.v | 223 |
1 files changed, 139 insertions, 84 deletions
diff --git a/src/Specific/Framework/OutputType.v b/src/Specific/Framework/OutputType.v index 88653d980..50739c824 100644 --- a/src/Specific/Framework/OutputType.v +++ b/src/Specific/Framework/OutputType.v @@ -7,6 +7,7 @@ Require Import Crypto.Compilers.Tuple. Require Import Crypto.Compilers.ExprInversion. Require Import Crypto.Compilers.Z.Syntax.Util. Require Import Crypto.Compilers.Z.Syntax. +Require Import Crypto.Compilers.Tuple. Require Import Crypto.Specific.Framework.RawCurveParameters. Require Import Crypto.Specific.Framework.ArithmeticSynthesis.Base. Require Import Crypto.Util.Notations. @@ -16,94 +17,148 @@ Local Coercion Z.to_nat : Z >-> nat. Local Notation interp_flat_type := (interp_flat_type interp_base_type). Section gen. - Context (curve : RawCurveParameters.CurveParameters) - (b : base_type). - - Definition m := Z.to_pos (curve.(s) - Associational.eval curve.(c))%Z. - Definition rT := ((Tbase b)^curve.(sz))%ctype. - Definition T' := (interp_flat_type rT). - Definition RT := (Unit -> rT)%ctype. - Definition wt := (wt_gen m curve.(sz)). - Definition encode : F m -> Expr RT - := fun v var - => Abs - (fun _ - => SmartPairf - (flat_interp_untuple - (T:=Tbase _) - (Tuple.map - (fun v => Op (OpConst v) TT) - (@Positional.Fencode wt curve.(sz) m modulo div v)))). - Definition decode : T' -> F m - := fun v => @Positional.Fdecode - wt (curve.(sz)) m - (Tuple.map interpToZ (flat_interp_tuple (T:=Tbase _) v)). - - Record SynthesisOutputOn := + Context (curve : RawCurveParameters.CurveParameters). + + Section gen_base_type. + Context (b : base_type). + + Definition m := Z.to_pos (curve.(s) - Associational.eval curve.(c))%Z. + Definition rT := ((Tbase b)^curve.(sz))%ctype. + Definition T' := (interp_flat_type rT). + Definition RT := (Unit -> rT)%ctype. + Definition wt := (wt_gen m curve.(sz)). + Definition encode : F m -> Expr RT + := fun v var + => Abs + (fun _ + => SmartPairf + (flat_interp_untuple + (T:=Tbase _) + (Tuple.map + (fun v => Op (OpConst v) TT) + (@Positional.Fencode wt curve.(sz) m modulo div v)))). + Definition decode : T' -> F m + := fun v => @Positional.Fdecode + wt (curve.(sz)) m + (Tuple.map interpToZ (flat_interp_tuple (T:=Tbase _) v)). + End gen_base_type. + + Local Notation TW := (TWord (Z.log2_up curve.(bitwidth))). + Local Notation RTZ := (RT TZ). + Local Notation rTZ := (rT TZ). + Local Notation RTW := (RT TW). + Local Notation rTW := (rT TW). + + Record SynthesisOutput := { - zero : Expr RT; - one : Expr RT; - add : Expr (rT * rT -> rT); (* does not include carry *) - sub : Expr (rT * rT -> rT); (* does not include carry *) - mul : Expr (rT * rT -> rT); (* includes carry *) - square : Expr (rT -> rT); (* includes carry *) - opp : Expr (rT -> rT); (* does not include carry *) - carry : Expr (rT -> rT); - carry_add : Expr (rT * rT -> rT) - := (carry ∘ add)%expr; - carry_sub : Expr (rT * rT -> rT) - := (carry ∘ sub)%expr; - carry_opp : Expr (rT -> rT) - := (carry ∘ opp)%expr; - - P : T' -> Prop; - - encode_valid : forall v, _; - encode_sig := fun v => exist P (Interp (encode v) tt) (encode_valid v); - encode_correct : forall v, decode (Interp (encode v) tt) = v; - - decode_sig := fun v : sig P => decode (proj1_sig v); - - zero_correct : zero = encode 0%F; (* which equality to use here? *) - one_correct : one = encode 1%F; (* which equality to use here? *) - zero_sig := encode_sig 0%F; - one_sig := encode_sig 1%F; - - opp_valid : forall x, P x -> P (Interp carry_opp x); - opp_sig := fun x => exist P _ (@opp_valid (proj1_sig x) (proj2_sig x)); - - add_valid : forall x y, P x -> P y -> P (Interp carry_add (x, y)); - add_sig := fun x y => exist P _ (@add_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); - - sub_valid : forall x y, P x -> P y -> P (Interp carry_sub (x, y)); - sub_sig := fun x y => exist P _ (@sub_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); - - mul_valid : forall x y, P x -> P y -> P (Interp mul (x, y)); - mul_sig := fun x y => exist P _ (@mul_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); - - square_correct : forall x, P x -> Interp square x = Interp mul (x, x); - - T := { v : T' | P v }; - eqT : T -> T -> Prop - := fun x y => eq (decode (proj1_sig x)) (decode (proj1_sig y)); - ring : @Hierarchy.ring - T eqT zero_sig one_sig opp_sig add_sig sub_sig mul_sig; - encode_homomorphism + zeroZ : Expr RTZ; + oneZ : Expr RTZ; + addZ : Expr (rTZ * rTZ -> rTZ); (* does not include carry *) + subZ : Expr (rTZ * rTZ -> rTZ); (* does not include carry *) + carry_mulZ : Expr (rTZ * rTZ -> rTZ); (* includes carry *) + carry_squareZ : Expr (rTZ -> rTZ); (* includes carry *) + oppZ : Expr (rTZ -> rTZ); (* does not include carry *) + carryZ : Expr (rTZ -> rTZ); + carry_addZ : Expr (rTZ * rTZ -> rTZ) + := (carryZ ∘ addZ)%expr; + carry_subZ : Expr (rTZ * rTZ -> rTZ) + := (carryZ ∘ subZ)%expr; + carry_oppZ : Expr (rTZ -> rTZ) + := (carryZ ∘ oppZ)%expr; + + zeroW : Expr RTW; + oneW : Expr RTW; + carry_addW : Expr (rTW * rTW -> rTW); (* includes carry *) + carry_subW : Expr (rTW * rTW -> rTW); (* includes carry *) + carry_mulW : Expr (rTW * rTW -> rTW); (* includes carry *) + carry_squareW : Expr (rTW -> rTW); (* includes carry *) + carry_oppW : Expr (rTW -> rTW); (* does not include carry *) + + PZ : T' TZ -> Prop; + PW : T' TW -> Prop + := fun v => PZ (tuple_map (A:=Tbase TW) (B:=Tbase TZ) (SmartVarfMap (@interpToZ)) v); + + encodeZ_valid : forall v, _; + encodeZ_sig := fun v => exist PZ (Interp (encode TZ v) tt) (encodeZ_valid v); + encodeZ_correct : forall v, decode TZ (Interp (encode TZ v) tt) = v; + + decodeZ_sig := fun v : sig PZ => decode TZ (proj1_sig v); + + zeroZ_correct : zeroZ = encode _ 0%F; (* which equality to use here? *) + oneZ_correct : oneZ = encode _ 1%F; (* which equality to use here? *) + zeroZ_sig := encodeZ_sig 0%F; + oneZ_sig := encodeZ_sig 1%F; + + oppZ_valid : forall x, PZ x -> PZ (Interp carry_oppZ x); + oppZ_sig := fun x => exist PZ _ (@oppZ_valid (proj1_sig x) (proj2_sig x)); + + addZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_addZ (x, y)); + addZ_sig := fun x y => exist PZ _ (@addZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); + + subZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_subZ (x, y)); + subZ_sig := fun x y => exist PZ _ (@subZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); + + mulZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_mulZ (x, y)); + mulZ_sig := fun x y => exist PZ _ (@mulZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); + + squareZ_correct : forall x, PZ x -> Interp carry_squareZ x = Interp carry_mulZ (x, x); + + + encodeW_valid : forall v, _; + encodeW_sig := fun v => exist PW (Interp (encode TW v) tt) (encodeW_valid v); + encodeW_correct : forall v, decode TW (Interp (encode TW v) tt) = v; + + decodeW_sig := fun v : sig PW => decode TW (proj1_sig v); + + zeroW_correct : zeroW = encode _ 0%F; (* which equality to use here? *) + oneW_correct : oneW = encode _ 1%F; (* which equality to use here? *) + zeroW_sig := encodeW_sig 0%F; + oneW_sig := encodeW_sig 1%F; + + oppW_valid : forall x, PW x -> PW (Interp carry_oppW x); + oppW_sig := fun x => exist PW _ (@oppW_valid (proj1_sig x) (proj2_sig x)); + + addW_valid : forall x y, PW x -> PW y -> PW (Interp carry_addW (x, y)); + addW_sig := fun x y => exist PW _ (@addW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); + + subW_valid : forall x y, PW x -> PW y -> PW (Interp carry_subW (x, y)); + subW_sig := fun x y => exist PW _ (@subW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); + + mulW_valid : forall x y, PW x -> PW y -> PW (Interp carry_mulW (x, y)); + mulW_sig := fun x y => exist PW _ (@mulW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y)); + + squareW_correct : forall x, PW x -> Interp carry_squareW x = Interp carry_mulW (x, x); + + T_Z := { v : T' TZ | PZ v }; + eqTZ : T_Z -> T_Z -> Prop + := fun x y => eq (decode _ (proj1_sig x)) (decode _ (proj1_sig y)); + ringZ : @Hierarchy.ring + T_Z eqTZ zeroZ_sig oneZ_sig oppZ_sig addZ_sig subZ_sig mulZ_sig; + encodeZ_homomorphism + : @Ring.is_homomorphism + (F m) eq 1%F F.add F.mul + T_Z eqTZ oneZ_sig addZ_sig mulZ_sig + encodeZ_sig; + decodeZ_homomorphism : @Ring.is_homomorphism + T_Z eqTZ oneZ_sig addZ_sig mulZ_sig (F m) eq 1%F F.add F.mul - T eqT one_sig add_sig mul_sig - encode_sig; - decode_homomorphism + decodeZ_sig; + + T_W := { v : T' TW | PW v }; + eqTW : T_W -> T_W -> Prop + := fun x y => eq (decode _ (proj1_sig x)) (decode _ (proj1_sig y)); + ringW : @Hierarchy.ring + T_W eqTW zeroW_sig oneW_sig oppW_sig addW_sig subW_sig mulW_sig; + encodeW_homomorphism : @Ring.is_homomorphism - T eqT one_sig add_sig mul_sig (F m) eq 1%F F.add F.mul - decode_sig + T_W eqTW oneW_sig addW_sig mulW_sig + encodeW_sig; + decodeW_homomorphism + : @Ring.is_homomorphism + T_W eqTW oneW_sig addW_sig mulW_sig + (F m) eq 1%F F.add F.mul + decodeW_sig }. End gen. - - -Record SynthesisOutput (curve : RawCurveParameters.CurveParameters) := - { - onZ : SynthesisOutputOn curve TZ; - onWord : SynthesisOutputOn curve (TWord (Z.log2_up curve.(bitwidth))) - }. |