aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2017-10-13 22:06:48 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2017-10-18 23:01:29 -0400
commit28144db80f4b4b1a852af2504b18d7cc0e43c4ba (patch)
tree0dd5a54d287e1949002a0e8d7794b3f545704695
parent17a4dec48a6cfe9dc83dd1670dec6e700ce7cac0 (diff)
Fix synthesis output record
The bounds checking on the reflective pipeline won't go through without carries, so when synthesizing word-based operations, always carry.
-rw-r--r--src/Specific/Framework/OutputType.v223
1 files changed, 139 insertions, 84 deletions
diff --git a/src/Specific/Framework/OutputType.v b/src/Specific/Framework/OutputType.v
index 88653d980..50739c824 100644
--- a/src/Specific/Framework/OutputType.v
+++ b/src/Specific/Framework/OutputType.v
@@ -7,6 +7,7 @@ Require Import Crypto.Compilers.Tuple.
Require Import Crypto.Compilers.ExprInversion.
Require Import Crypto.Compilers.Z.Syntax.Util.
Require Import Crypto.Compilers.Z.Syntax.
+Require Import Crypto.Compilers.Tuple.
Require Import Crypto.Specific.Framework.RawCurveParameters.
Require Import Crypto.Specific.Framework.ArithmeticSynthesis.Base.
Require Import Crypto.Util.Notations.
@@ -16,94 +17,148 @@ Local Coercion Z.to_nat : Z >-> nat.
Local Notation interp_flat_type := (interp_flat_type interp_base_type).
Section gen.
- Context (curve : RawCurveParameters.CurveParameters)
- (b : base_type).
-
- Definition m := Z.to_pos (curve.(s) - Associational.eval curve.(c))%Z.
- Definition rT := ((Tbase b)^curve.(sz))%ctype.
- Definition T' := (interp_flat_type rT).
- Definition RT := (Unit -> rT)%ctype.
- Definition wt := (wt_gen m curve.(sz)).
- Definition encode : F m -> Expr RT
- := fun v var
- => Abs
- (fun _
- => SmartPairf
- (flat_interp_untuple
- (T:=Tbase _)
- (Tuple.map
- (fun v => Op (OpConst v) TT)
- (@Positional.Fencode wt curve.(sz) m modulo div v)))).
- Definition decode : T' -> F m
- := fun v => @Positional.Fdecode
- wt (curve.(sz)) m
- (Tuple.map interpToZ (flat_interp_tuple (T:=Tbase _) v)).
-
- Record SynthesisOutputOn :=
+ Context (curve : RawCurveParameters.CurveParameters).
+
+ Section gen_base_type.
+ Context (b : base_type).
+
+ Definition m := Z.to_pos (curve.(s) - Associational.eval curve.(c))%Z.
+ Definition rT := ((Tbase b)^curve.(sz))%ctype.
+ Definition T' := (interp_flat_type rT).
+ Definition RT := (Unit -> rT)%ctype.
+ Definition wt := (wt_gen m curve.(sz)).
+ Definition encode : F m -> Expr RT
+ := fun v var
+ => Abs
+ (fun _
+ => SmartPairf
+ (flat_interp_untuple
+ (T:=Tbase _)
+ (Tuple.map
+ (fun v => Op (OpConst v) TT)
+ (@Positional.Fencode wt curve.(sz) m modulo div v)))).
+ Definition decode : T' -> F m
+ := fun v => @Positional.Fdecode
+ wt (curve.(sz)) m
+ (Tuple.map interpToZ (flat_interp_tuple (T:=Tbase _) v)).
+ End gen_base_type.
+
+ Local Notation TW := (TWord (Z.log2_up curve.(bitwidth))).
+ Local Notation RTZ := (RT TZ).
+ Local Notation rTZ := (rT TZ).
+ Local Notation RTW := (RT TW).
+ Local Notation rTW := (rT TW).
+
+ Record SynthesisOutput :=
{
- zero : Expr RT;
- one : Expr RT;
- add : Expr (rT * rT -> rT); (* does not include carry *)
- sub : Expr (rT * rT -> rT); (* does not include carry *)
- mul : Expr (rT * rT -> rT); (* includes carry *)
- square : Expr (rT -> rT); (* includes carry *)
- opp : Expr (rT -> rT); (* does not include carry *)
- carry : Expr (rT -> rT);
- carry_add : Expr (rT * rT -> rT)
- := (carry ∘ add)%expr;
- carry_sub : Expr (rT * rT -> rT)
- := (carry ∘ sub)%expr;
- carry_opp : Expr (rT -> rT)
- := (carry ∘ opp)%expr;
-
- P : T' -> Prop;
-
- encode_valid : forall v, _;
- encode_sig := fun v => exist P (Interp (encode v) tt) (encode_valid v);
- encode_correct : forall v, decode (Interp (encode v) tt) = v;
-
- decode_sig := fun v : sig P => decode (proj1_sig v);
-
- zero_correct : zero = encode 0%F; (* which equality to use here? *)
- one_correct : one = encode 1%F; (* which equality to use here? *)
- zero_sig := encode_sig 0%F;
- one_sig := encode_sig 1%F;
-
- opp_valid : forall x, P x -> P (Interp carry_opp x);
- opp_sig := fun x => exist P _ (@opp_valid (proj1_sig x) (proj2_sig x));
-
- add_valid : forall x y, P x -> P y -> P (Interp carry_add (x, y));
- add_sig := fun x y => exist P _ (@add_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
-
- sub_valid : forall x y, P x -> P y -> P (Interp carry_sub (x, y));
- sub_sig := fun x y => exist P _ (@sub_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
-
- mul_valid : forall x y, P x -> P y -> P (Interp mul (x, y));
- mul_sig := fun x y => exist P _ (@mul_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
-
- square_correct : forall x, P x -> Interp square x = Interp mul (x, x);
-
- T := { v : T' | P v };
- eqT : T -> T -> Prop
- := fun x y => eq (decode (proj1_sig x)) (decode (proj1_sig y));
- ring : @Hierarchy.ring
- T eqT zero_sig one_sig opp_sig add_sig sub_sig mul_sig;
- encode_homomorphism
+ zeroZ : Expr RTZ;
+ oneZ : Expr RTZ;
+ addZ : Expr (rTZ * rTZ -> rTZ); (* does not include carry *)
+ subZ : Expr (rTZ * rTZ -> rTZ); (* does not include carry *)
+ carry_mulZ : Expr (rTZ * rTZ -> rTZ); (* includes carry *)
+ carry_squareZ : Expr (rTZ -> rTZ); (* includes carry *)
+ oppZ : Expr (rTZ -> rTZ); (* does not include carry *)
+ carryZ : Expr (rTZ -> rTZ);
+ carry_addZ : Expr (rTZ * rTZ -> rTZ)
+ := (carryZ ∘ addZ)%expr;
+ carry_subZ : Expr (rTZ * rTZ -> rTZ)
+ := (carryZ ∘ subZ)%expr;
+ carry_oppZ : Expr (rTZ -> rTZ)
+ := (carryZ ∘ oppZ)%expr;
+
+ zeroW : Expr RTW;
+ oneW : Expr RTW;
+ carry_addW : Expr (rTW * rTW -> rTW); (* includes carry *)
+ carry_subW : Expr (rTW * rTW -> rTW); (* includes carry *)
+ carry_mulW : Expr (rTW * rTW -> rTW); (* includes carry *)
+ carry_squareW : Expr (rTW -> rTW); (* includes carry *)
+ carry_oppW : Expr (rTW -> rTW); (* does not include carry *)
+
+ PZ : T' TZ -> Prop;
+ PW : T' TW -> Prop
+ := fun v => PZ (tuple_map (A:=Tbase TW) (B:=Tbase TZ) (SmartVarfMap (@interpToZ)) v);
+
+ encodeZ_valid : forall v, _;
+ encodeZ_sig := fun v => exist PZ (Interp (encode TZ v) tt) (encodeZ_valid v);
+ encodeZ_correct : forall v, decode TZ (Interp (encode TZ v) tt) = v;
+
+ decodeZ_sig := fun v : sig PZ => decode TZ (proj1_sig v);
+
+ zeroZ_correct : zeroZ = encode _ 0%F; (* which equality to use here? *)
+ oneZ_correct : oneZ = encode _ 1%F; (* which equality to use here? *)
+ zeroZ_sig := encodeZ_sig 0%F;
+ oneZ_sig := encodeZ_sig 1%F;
+
+ oppZ_valid : forall x, PZ x -> PZ (Interp carry_oppZ x);
+ oppZ_sig := fun x => exist PZ _ (@oppZ_valid (proj1_sig x) (proj2_sig x));
+
+ addZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_addZ (x, y));
+ addZ_sig := fun x y => exist PZ _ (@addZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
+
+ subZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_subZ (x, y));
+ subZ_sig := fun x y => exist PZ _ (@subZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
+
+ mulZ_valid : forall x y, PZ x -> PZ y -> PZ (Interp carry_mulZ (x, y));
+ mulZ_sig := fun x y => exist PZ _ (@mulZ_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
+
+ squareZ_correct : forall x, PZ x -> Interp carry_squareZ x = Interp carry_mulZ (x, x);
+
+
+ encodeW_valid : forall v, _;
+ encodeW_sig := fun v => exist PW (Interp (encode TW v) tt) (encodeW_valid v);
+ encodeW_correct : forall v, decode TW (Interp (encode TW v) tt) = v;
+
+ decodeW_sig := fun v : sig PW => decode TW (proj1_sig v);
+
+ zeroW_correct : zeroW = encode _ 0%F; (* which equality to use here? *)
+ oneW_correct : oneW = encode _ 1%F; (* which equality to use here? *)
+ zeroW_sig := encodeW_sig 0%F;
+ oneW_sig := encodeW_sig 1%F;
+
+ oppW_valid : forall x, PW x -> PW (Interp carry_oppW x);
+ oppW_sig := fun x => exist PW _ (@oppW_valid (proj1_sig x) (proj2_sig x));
+
+ addW_valid : forall x y, PW x -> PW y -> PW (Interp carry_addW (x, y));
+ addW_sig := fun x y => exist PW _ (@addW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
+
+ subW_valid : forall x y, PW x -> PW y -> PW (Interp carry_subW (x, y));
+ subW_sig := fun x y => exist PW _ (@subW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
+
+ mulW_valid : forall x y, PW x -> PW y -> PW (Interp carry_mulW (x, y));
+ mulW_sig := fun x y => exist PW _ (@mulW_valid (proj1_sig x) (proj1_sig y) (proj2_sig x) (proj2_sig y));
+
+ squareW_correct : forall x, PW x -> Interp carry_squareW x = Interp carry_mulW (x, x);
+
+ T_Z := { v : T' TZ | PZ v };
+ eqTZ : T_Z -> T_Z -> Prop
+ := fun x y => eq (decode _ (proj1_sig x)) (decode _ (proj1_sig y));
+ ringZ : @Hierarchy.ring
+ T_Z eqTZ zeroZ_sig oneZ_sig oppZ_sig addZ_sig subZ_sig mulZ_sig;
+ encodeZ_homomorphism
+ : @Ring.is_homomorphism
+ (F m) eq 1%F F.add F.mul
+ T_Z eqTZ oneZ_sig addZ_sig mulZ_sig
+ encodeZ_sig;
+ decodeZ_homomorphism
: @Ring.is_homomorphism
+ T_Z eqTZ oneZ_sig addZ_sig mulZ_sig
(F m) eq 1%F F.add F.mul
- T eqT one_sig add_sig mul_sig
- encode_sig;
- decode_homomorphism
+ decodeZ_sig;
+
+ T_W := { v : T' TW | PW v };
+ eqTW : T_W -> T_W -> Prop
+ := fun x y => eq (decode _ (proj1_sig x)) (decode _ (proj1_sig y));
+ ringW : @Hierarchy.ring
+ T_W eqTW zeroW_sig oneW_sig oppW_sig addW_sig subW_sig mulW_sig;
+ encodeW_homomorphism
: @Ring.is_homomorphism
- T eqT one_sig add_sig mul_sig
(F m) eq 1%F F.add F.mul
- decode_sig
+ T_W eqTW oneW_sig addW_sig mulW_sig
+ encodeW_sig;
+ decodeW_homomorphism
+ : @Ring.is_homomorphism
+ T_W eqTW oneW_sig addW_sig mulW_sig
+ (F m) eq 1%F F.add F.mul
+ decodeW_sig
}.
End gen.
-
-
-Record SynthesisOutput (curve : RawCurveParameters.CurveParameters) :=
- {
- onZ : SynthesisOutputOn curve TZ;
- onWord : SynthesisOutputOn curve (TWord (Z.log2_up curve.(bitwidth)))
- }.