diff options
author | Jason Gross <jgross@mit.edu> | 2018-03-18 21:54:26 -0400 |
---|---|---|
committer | Jason Gross <jasongross9@gmail.com> | 2018-03-19 14:17:26 -0400 |
commit | 2ee5a1b54d1fe45f621e0f77f3446e348e4c1d19 (patch) | |
tree | 01764af2a901e0bf0a3260d79d41492c39389f5a /src | |
parent | 9a35ebe478cb3e621a7a4eabf4d88d007cc7128e (diff) |
Add support for Z*Z casts, get montred working
Diffstat (limited to 'src')
-rw-r--r-- | src/Experiments/SimplyTypedArithmetic.v | 835 |
1 files changed, 484 insertions, 351 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 3b935f3df..1b58ede58 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -2021,6 +2021,7 @@ Module Compilers. | Z_zselect : ident (Z * Z * Z) Z | Z_add_modulo : ident (Z * Z * Z) Z | Z_cast (range : zrange) : ident Z Z + | Z_cast2 (range : zrange * zrange) : ident (Z * Z) (Z * Z) . Notation curry0 f @@ -2043,6 +2044,11 @@ Module Compilers. Section gen. Context (cast_outside_of_range : zrange -> BinInt.Z -> BinInt.Z). + Definition cast (r : zrange) (x : BinInt.Z) + := if (lower r <=? x) && (x <=? upper r) + then x + else cast_outside_of_range r x. + (** Interpret identifiers where the behavior of [Z_cast] on a value that does not fit in the range is given by a context variable. (This allows us to treat [Z_cast] @@ -2089,9 +2095,8 @@ Module Compilers. | Z_sub_get_borrow_concrete s => curry2 (Z.sub_get_borrow s) | Z_zselect => curry3 Z.zselect | Z_add_modulo => curry3 Z.add_modulo - | Z_cast r => fun x => if (lower r <=? x) && (x <=? upper r) - then x - else cast_outside_of_range r x + | Z_cast r => cast r + | Z_cast2 (r1, r2) => fun '(x1, x2) => (cast r1 x1, cast r2 x2) end. End gen. @@ -2229,6 +2234,7 @@ Module Compilers. Notation zselect := Z_zselect. Notation add_modulo := Z_add_modulo. Notation cast := Z_cast. + Notation cast2 := Z_cast2. End Z. Module Nat. @@ -2637,6 +2643,16 @@ Module Compilers. | None => None end. + Definition invert_Z_cast2 (e : @expr var (type.Z * type.Z)) : option ((zrange * zrange) * @expr var (type.Z * type.Z)) + := match invert_AppIdent e with + | Some (existT s (idc, args)) + => match idc in ident s t return expr s -> option ((zrange * zrange) * expr (type.Z * type.Z)) with + | ident.Z_cast2 r => fun v => Some (r, v) + | _ => fun _ => None + end args + | None => None + end. + Local Notation list_expr := (fun t => match t return Type with | type.list T => list (expr T) @@ -3156,6 +3172,7 @@ Module Compilers. | ident.Z_zselect as idc | ident.Z_add_modulo as idc | ident.Z_cast _ as idc + | ident.Z_cast2 _ as idc => cps_of (Uncurried.expr.default.ident.interp idc) | ident.Z_mul_split_concrete s => cps_of (curry2 (Z.mul_split s)) @@ -3371,6 +3388,7 @@ Module Compilers. (ident.snd @@ (Var xyk)) @ ((idc : default.ident _ (type.Z * type.Z)) @@ (ident.fst @@ (Var xyk))) + | ident.Z_cast2 _ as idc | ident.Z_mul_split_concrete _ as idc | ident.Z_add_get_carry_concrete _ as idc | ident.Z_sub_get_borrow_concrete _ as idc @@ -3895,14 +3913,14 @@ Module Compilers. | type.type_primitive x => primitive.option.interp x | type.prod A B => interp A * interp B | type.arrow s d => interp s -> interp d - | type.list A => list (interp A) + | type.list A => option (list (interp A)) end. Fixpoint None {t : type} : interp t := match t with | type.type_primitive x => @primitive.option.None x | type.prod A B => (@None A, @None B) | type.arrow s d => fun _ => @None d - | type.list A => @nil (interp A) + | type.list A => Datatypes.None end. Fixpoint Some {t : type} : type.interp t -> interp t := match t with @@ -3911,7 +3929,7 @@ Module Compilers. => fun x : type.interp A * type.interp B => (@Some A (fst x), @Some B (snd x)) | type.arrow s d => fun _ _ => @None d - | type.list A => List.map (@Some A) + | type.list A => fun ls => Datatypes.Some (List.map (@Some A) ls) end. Fixpoint is_tighter_than {t} : interp t -> interp t -> bool := match t with @@ -3922,7 +3940,13 @@ Module Compilers. => @is_tighter_than A ra ra' && @is_tighter_than B rb rb' | type.arrow s d => fun _ _ => false | type.list A - => fold_andb_map (@is_tighter_than A) + => fun ls1 ls2 + => match ls1, ls2 with + | Datatypes.None, Datatypes.None => true + | Datatypes.Some _, Datatypes.None => true + | Datatypes.None, Datatypes.Some _ => false + | Datatypes.Some ls1, Datatypes.Some ls2 => fold_andb_map (@is_tighter_than A) ls1 ls2 + end end. Fixpoint is_bounded_by {t} : interp t -> Compilers.type.interp t -> bool := match t return interp t -> Compilers.type.interp t -> bool with @@ -3933,7 +3957,11 @@ Module Compilers. => @is_bounded_by A ra ra' && @is_bounded_by B rb rb' | type.arrow s d => fun _ _ => false | type.list A - => fold_andb_map (@is_bounded_by A) + => fun ls1 ls2 + => match ls1 with + | Datatypes.None => false + | Datatypes.Some ls1 => fold_andb_map (@is_bounded_by A) ls1 ls2 + end end. Lemma is_tighter_than_Some_is_bounded_by {t} r1 r2 val @@ -3954,7 +3982,10 @@ Module Compilers. | apply conj | Z.ltb_to_lt; omega | rewrite @fold_andb_map_map in * ]. - { revert r1 r2 val Htight Hbounds IHt. + { lazymatch goal with + | [ r1 : list (interp t), r2 : list (type.interp t), val : list (Compilers.type.interp t) |- _ ] + => revert r1 r2 val Htight Hbounds IHt + end; intros r1 r2 val; revert r1 r2 val. induction r1, r2, val; cbn; auto with nocore; try congruence; []. rewrite !Bool.andb_true_iff; intros; destruct_head'_and; split; eauto with nocore. } Qed. @@ -4000,12 +4031,18 @@ Module Compilers. | ident.Z_sub_get_borrow | ident.Z_modulo => fun _ => type.option.None - | ident.nil t => curry0 (@nil (type.option.interp t)) - | ident.cons t => curry2 (@Datatypes.cons (type.option.interp t)) + | ident.nil t => curry0 (Some (@nil (type.option.interp t))) + | ident.cons t => curry2 (fun a => option_map (@Datatypes.cons (type.option.interp t) a)) | ident.fst A B => @Datatypes.fst (type.option.interp A) (type.option.interp B) | ident.snd A B => @Datatypes.snd (type.option.interp A) (type.option.interp B) | ident.List_nth_default_concrete T d n - => fun ls => @nth_default (type.option.interp T) type.option.None ls n + => fun ls + => match ls with + | Datatypes.Some ls + => @nth_default (type.option.interp T) type.option.None ls n + | Datatypes.None + => type.option.None + end | ident.Z_shiftr _ as idc | ident.Z_shiftl _ as idc | ident.Z_opp as idc @@ -4029,6 +4066,16 @@ Module Compilers. | Some r => ZRange.intersection r range | None => range end + | ident.Z_cast2 (r1, r2) + => fun '((r1', r2') : option zrange * option zrange) + => (Some match r1' with + | Some r => ZRange.intersection r r1 + | None => r1 + end, + Some match r2' with + | Some r => ZRange.intersection r r2 + | None => r2 + end) | ident.Z_mul_split_concrete split_at => fun '((x, y) : option zrange * option zrange) => match x, y with @@ -4110,7 +4157,7 @@ Module Compilers. End DefaultValue. Module partial. - Notation Zdata := (option zrange). + Notation data := ZRange.type.option.interp. Section value. Context (var : type -> Type). Definition value_prestep (value : type -> Type) (t : type) @@ -4125,12 +4172,10 @@ Module Compilers. := match t return Type with | type.arrow _ _ as t => value_prestep value t - | type.type_primitive type.Z as t - => Zdata * @expr var t + value_prestep value t | type.prod _ _ as t | type.list _ as t | type.type_primitive _ as t - => @expr var t + value_prestep value t + => data t * @expr var t + value_prestep value t end%type. Fixpoint value (t : type) := value_step value t. @@ -4145,6 +4190,38 @@ Module Compilers. | type.arrow s d => fun _ => @value_default d | type.list A => inr (@nil (value A)) end. + + Fixpoint data_from_value {t} : value t -> data t + := match t return value t -> data t with + | type.arrow _ _ as t + => fun _ => ZRange.type.option.None + | type.prod A B as t + => fun v + => match v with + | inl (data, _) => data + | inr (a, b) + => (@data_from_value A a, @data_from_value B b) + end + | type.list A as t + => fun v + => match v with + | inl (data, _) => data + | inr ls + => Some (List.map (@data_from_value A) ls) + end + | type.type_primitive type.Z as t + => fun v + => match v with + | inl (data, _) => data + | inr v => Some r[v~>v]%zrange + end + | type.type_primitive _ as t + => fun v + => match v with + | inl (data, _) => data + | inr _ => ZRange.type.option.None + end + end. End value. Module expr. @@ -4154,9 +4231,19 @@ Module Compilers. : value var t -> @expr var t := match t return value var t -> expr t with | type.prod A B as t - => fun x : expr t + value var A * value var B + => fun x : (data A * data B) * expr t + value var A * value var B => match x with - | inl v => v + | inl ((da, db), v) + => match A, B return data A -> data B -> expr (A * B) -> expr (A * B) with + | type.Z, type.Z + => fun da db v + => match da, db with + | Some r1, Some r2 + => (ident.Z.cast2 (r1, r2)%core @@ v)%expr + | _, _ => v + end + | _, _ => fun _ _ v => v + end da db v | inr (a, b) => (@reify A a, @reify B b)%expr end | type.arrow s d @@ -4164,24 +4251,24 @@ Module Compilers. => Abs (fun x => @reify d (f (@reflect s (Var x)))) | type.list A as t - => fun x : expr t + list (value var A) + => fun x : _ * expr t + list (value var A) => match x with - | inl v => v + | inl (_, v) => v | inr v => reify_list (List.map (@reify A) v) end | type.type_primitive type.Z as t - => fun x : Zdata * expr t + type.interp t + => fun x : _ * expr t + type.interp t => match x with | inl (Some r, v) => ident.Z.cast r @@ v | inl (None, v) => v | inr v => ident.primitive v @@ TT end%core%expr | type.type_primitive _ as t - => fun x : expr t + type.interp t + => fun x : _ * expr t + type.interp t => match x with - | inl v => v + | inl (_, v) => v | inr v => ident.primitive v @@ TT - end%expr + end%core%expr end with reflect {t : type} : @expr var t -> value var t @@ -4191,28 +4278,39 @@ Module Compilers. => @reflect d (App f (@reify s x)) | type.prod A B as t => fun v : expr t - => let inr := @inr (expr t) (value_prestep (value var) t) in - let inl := @inl (expr t) (value_prestep (value var) t) in + => let inr := @inr (data t * expr t) (value_prestep (value var) t) in + let inl := @inl (data t * expr t) (value_prestep (value var) t) in match invert_Pair v with | Some (a, b) => inr (@reflect A a, @reflect B b) | None - => inl v + => inl + (match A, B return expr (A * B) -> data (A * B) * expr (A * B) with + | type.Z, type.Z + => fun v + => match invert_Z_cast2 v with + | Some (r, v) + => (ZRange.type.option.Some (t:=type.Z*type.Z) r, v) + | None + => (ZRange.type.option.None, v) + end + | _, _ => fun v => (ZRange.type.option.None, v) + end v) end | type.list A as t => fun v : expr t - => let inr := @inr (expr t) (value_prestep (value var) t) in - let inl := @inl (expr t) (value_prestep (value var) t) in + => let inr := @inr (data t * expr t) (value_prestep (value var) t) in + let inl := @inl (data t * expr t) (value_prestep (value var) t) in match reflect_list v with | Some ls => inr (List.map (@reflect A) ls) | None - => inl v + => inl (None, v) end | type.type_primitive type.Z as t => fun v : expr t - => let inr' := @inr (Zdata * expr t) (value_prestep (value var) t) in - let inl' := @inl (Zdata * expr t) (value_prestep (value var) t) in + => let inr' := @inr (data t * expr t) (value_prestep (value var) t) in + let inl' := @inl (data t * expr t) (value_prestep (value var) t) in match reflect_primitive v, invert_Z_cast v with | Some v, _ => inr' v | None, Some (r, v) => inl' (Some r, v) @@ -4220,11 +4318,11 @@ Module Compilers. end | type.type_primitive _ as t => fun v : expr t - => let inr := @inr (expr t) (value_prestep (value var) t) in - let inl := @inl (expr t) (value_prestep (value var) t) in + => let inr := @inr (data t * expr t) (value_prestep (value var) t) in + let inl := @inl (data t * expr t) (value_prestep (value var) t) in match reflect_primitive v with | Some v => inr v - | None => inl v + | None => inl (tt, v) end end. End reify. @@ -4239,6 +4337,8 @@ Module Compilers. | TT => true | AppIdent _ _ (ident.fst _ _) args => @is_var_like _ args | AppIdent _ _ (ident.snd _ _) args => @is_var_like _ args + | AppIdent _ _ (ident.Z.cast _) args => @is_var_like _ args + | AppIdent _ _ (ident.Z.cast2 _) args => @is_var_like _ args | Pair A B a b => @is_var_like A a && @is_var_like B b | AppIdent _ _ _ _ => false | App _ _ _ _ @@ -4252,7 +4352,7 @@ Module Compilers. | type.arrow _ _ => fun x f => f x | type.list T as t - => fun (x : expr t + list (value var T)) (f : expr t + list (value var T) -> value var tC) + => fun (x : data t * expr t + list (value var T)) (f : data t * expr t + list (value var T) -> value var tC) => match x with | inr ls => list_rect @@ -4267,7 +4367,7 @@ Module Compilers. | inl e => f (inl e) end | type.prod A B as t - => fun (x : expr t + value var A * value var B) (f : expr t + value var A * value var B -> value var tC) + => fun (x : data t * expr t + value var A * value var B) (f : data t * expr t + value var A * value var B -> value var tC) => match x with | inr (a, b) => @interp_let_in @@ -4276,11 +4376,12 @@ Module Compilers. => @interp_let_in _ B b (fun b => f (inr (a, b)))) - | inl e => partial.expr.reflect (expr_let y := e in partial.expr.reify (f (inl (Var y))))%expr + | inl (data, e) => partial.expr.reflect + (expr_let y := partial.expr.reify (t:=t) x in + partial.expr.reify (f (inl (data, Var y)%core)))%expr end - | type.type_primitive type.Z as t - => fun (x : Zdata * expr t + type.interp t) - (f : Zdata * expr t + type.interp t -> value var tC) + | type.type_primitive _ as t + => fun (x : data t * expr t + type.interp t) (f : data t * expr t + type.interp t -> value var tC) => match x with | inl (data, e) => if is_var_like e @@ -4288,74 +4389,73 @@ Module Compilers. else partial.expr.reflect (expr_let y := (partial.expr.reify (t:=t) x) in partial.expr.reify (f (inl (data, Var y)%core)))%expr - | inr v => f (inr v) - end - | type.type_primitive _ as t - => fun (x : expr t + type.interp t) (f : expr t + type.interp t -> value var tC) - => match x with - | inl e - => match invert_Var e with - | Some _ => f x - | None => partial.expr.reflect - (expr_let y := (partial.expr.reify (t:=t) x) in - partial.expr.reify (f (inl (Var y))))%expr - end | inr v => f (inr v) (* FIXME: do not substitute [S (big stuck term)] *) end end. + Let default_interp + {s d} + : ident s d -> value var s -> value var d + := match d return ident s d -> value var s -> value var d with + | type.arrow _ _ + => fun idc args => expr.reflect (AppIdent idc (expr.reify args)) + | _ + => fun idc args + => inl (ZRange.ident.option.interp idc (data_from_value var args), + AppIdent idc (expr.reify args)) + end. + (** do partial reduction on identifiers *) Definition interp {s d} (idc : ident s d) : value var (s -> d) := match idc in ident s d return value var (s -> d) with | ident.Let_In tx tC as idc - => fun (xf : expr (tx * (tx -> tC)) + value var tx * value var (tx -> tC)) + => fun (xf : data (tx * (tx -> tC)) * expr (tx * (tx -> tC)) + value var tx * value var (tx -> tC)) => match xf with | inr (x, f) => interp_let_in x f | _ => expr.reflect (AppIdent idc (expr.reify (t:=tx * (tx -> tC)) xf)) end | ident.nil t => fun _ => inr (@nil (value var t)) - | ident.primitive type.Z v - => fun _ => inr v | ident.primitive t v => fun _ => inr v | ident.cons t as idc - => fun (x_xs : expr (t * type.list t) + value var t * (expr (type.list t) + list (value var t))) - => match x_xs return expr (type.list t) + list (value var t) with + => fun (x_xs : data (t * type.list t) * expr (t * type.list t) + value var t * (data (type.list t) * expr (type.list t) + list (value var t))) + => match x_xs return data (type.list t) * expr (type.list t) + list (value var t) with | inr (x, inr xs) => inr (cons x xs) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=t * type.list t) x_xs)) + | _ + => default_interp idc x_xs end | ident.fst A B as idc - => fun x : expr (A * B) + value var A * value var B + => fun x : data (A * B) * expr (A * B) + value var A * value var B => match x with | inr x => fst x - | _ => expr.reflect (AppIdent idc (expr.reify (t:=A*B) x)) + | _ => default_interp idc x end | ident.snd A B as idc - => fun x : expr (A * B) + value var A * value var B + => fun x : data (A * B) * expr (A * B) + value var A * value var B => match x with | inr x => snd x - | _ => expr.reflect (AppIdent idc (expr.reify (t:=A*B) x)) + | _ => default_interp idc x end | ident.bool_rect T as idc - => fun (true_case_false_case_b : expr (T * T * type.bool) + (expr (T * T) + value var T * value var T) * (expr type.bool + bool)) + => fun (true_case_false_case_b : data (T * T * type.bool) * expr (T * T * type.bool) + (data (T * T) * expr (T * T) + value var T * value var T) * (data type.bool * expr type.bool + bool)) => match true_case_false_case_b with | inr (inr (true_case, false_case), inr b) => @bool_rect (fun _ => value var T) true_case false_case b - | _ => expr.reflect (AppIdent idc (expr.reify (t:=T*T*type.bool) true_case_false_case_b)) + | _ => default_interp idc true_case_false_case_b end | ident.nat_rect P as idc - => fun (O_case_S_case_n : expr (P * (type.nat * P -> P) * type.nat) + (expr (P * (type.nat * P -> P)) + value var P * value var (type.nat * P -> P)) * (expr type.nat + nat)) + => fun (O_case_S_case_n : _ * expr (P * (type.nat * P -> P) * type.nat) + (_ * expr (P * (type.nat * P -> P)) + value var P * value var (type.nat * P -> P)) * (_ * expr type.nat + nat)) => match O_case_S_case_n with | inr (inr (O_case, S_case), inr n) => @nat_rect (fun _ => value var P) O_case (fun n' rec => S_case (inr (inr n', rec))) n - | _ => expr.reflect (AppIdent idc (expr.reify (t:=P * (type.nat * P -> P) * type.nat) O_case_S_case_n)) + | _ => default_interp idc O_case_S_case_n end | ident.list_rect A P as idc - => fun (nil_case_cons_case_ls : expr (P * (A * type.list A * P -> P) * type.list A) + (expr (P * (A * type.list A * P -> P)) + value var P * value var (A * type.list A * P -> P)) * (expr (type.list A) + list (value var A))) + => fun (nil_case_cons_case_ls : _ * expr (P * (A * type.list A * P -> P) * type.list A) + (_ * expr (P * (A * type.list A * P -> P)) + value var P * value var (A * type.list A * P -> P)) * (_ * expr (type.list A) + list (value var A))) => match nil_case_cons_case_ls with | inr (inr (nil_case, cons_case), inr ls) => @list_rect @@ -4364,60 +4464,60 @@ Module Compilers. nil_case (fun x xs rec => cons_case (inr (inr (x, inr xs), rec))) ls - | _ => expr.reflect (AppIdent idc (expr.reify (t:=P * (A * type.list A * P -> P) * type.list A) nil_case_cons_case_ls)) + | _ => default_interp idc nil_case_cons_case_ls end | ident.List.nth_default type.Z as idc - => fun (default_ls_idx : expr (type.Z * type.list type.Z * type.nat) + (expr (type.Z * type.list type.Z) + (_ * expr type.Z + type.interp type.Z) * (expr (type.list type.Z) + list (value var type.Z))) * (expr type.nat + nat)) + => fun (default_ls_idx : _ * expr (type.Z * type.list type.Z * type.nat) + (_ * expr (type.Z * type.list type.Z) + (_ * expr type.Z + type.interp type.Z) * (_ * expr (type.list type.Z) + list (value var type.Z))) * (_ * expr type.nat + nat)) => match default_ls_idx with | inr (inr (default, inr ls), inr idx) => List.nth_default default ls idx | inr (inr (inr default, ls), inr idx) - => expr.reflect (AppIdent (ident.List.nth_default_concrete default idx) (expr.reify (t:=type.list type.Z) ls)) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=type.Z * type.list type.Z * type.nat) default_ls_idx)) + => default_interp (ident.List.nth_default_concrete default idx) ls + | _ => default_interp idc default_ls_idx end | ident.List.nth_default (type.type_primitive A) as idc - => fun (default_ls_idx : expr (A * type.list A * type.nat) + (expr (A * type.list A) + (expr A + type.interp A) * (expr (type.list A) + list (value var A))) * (expr type.nat + nat)) + => fun (default_ls_idx : _ * expr (A * type.list A * type.nat) + (_ * expr (A * type.list A) + (_ * expr A + type.interp A) * (_ * expr (type.list A) + list (value var A))) * (_ * expr type.nat + nat)) => match default_ls_idx with | inr (inr (default, inr ls), inr idx) => List.nth_default default ls idx | inr (inr (inr default, ls), inr idx) - => expr.reflect (AppIdent (ident.List.nth_default_concrete default idx) (expr.reify (t:=type.list A) ls)) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=A * type.list A * type.nat) default_ls_idx)) + => default_interp (ident.List.nth_default_concrete default idx) ls + | _ => default_interp idc default_ls_idx end | ident.List.nth_default A as idc - => fun (default_ls_idx : expr (A * type.list A * type.nat) + (expr (A * type.list A) + value var A * (expr (type.list A) + list (value var A))) * (expr type.nat + nat)) + => fun (default_ls_idx : _ * expr (A * type.list A * type.nat) + (_ * expr (A * type.list A) + value var A * (_ * expr (type.list A) + list (value var A))) * (_ * expr type.nat + nat)) => match default_ls_idx with | inr (inr (default, inr ls), inr idx) => List.nth_default default ls idx - | _ => expr.reflect (AppIdent idc (expr.reify (t:=A * type.list A * type.nat) default_ls_idx)) + | _ => default_interp idc default_ls_idx end | ident.List.nth_default_concrete A default idx as idc - => fun (ls : expr (type.list A) + list (value var A)) + => fun (ls : _ * expr (type.list A) + list (value var A)) => match ls with | inr ls => List.nth_default (expr.reflect (t:=A) (AppIdent (ident.primitive default) TT)) ls idx - | _ => expr.reflect (AppIdent idc (expr.reify (t:=type.list A) ls)) + | _ => default_interp idc ls end | ident.Z_mul_split as idc - => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + + (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) + => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) | inr (inr (inr x, y), z) - => expr.reflect (AppIdent (ident.Z.mul_split_concrete x) (expr.reify (t:=type.Z*type.Z) (inr (y, z)))) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z)) + => default_interp (ident.Z.mul_split_concrete x) (inr (y, z)) + | _ => default_interp idc x_y_z end | ident.Z_add_get_carry as idc - => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + + (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) + => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) | inr (inr (inr x, y), z) - => let default := expr.reflect (AppIdent (ident.Z.add_get_carry_concrete x) (expr.reify (t:=type.Z*type.Z) (inr (y, z)))) in + => let default := default_interp (ident.Z.add_get_carry_concrete x) (inr (y, z)) in match (y, z) with | (inr xx, inl e) | (inl e, inr xx) @@ -4426,45 +4526,45 @@ Module Compilers. else default | _ => default end - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z)) + | _ => default_interp idc x_y_z end | ident.Z_add_with_get_carry as idc - => fun (x_y_z_a : (expr (_ * _ * _ * _) + - (expr (_ * _ * _) + - (expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * + => fun (x_y_z_a : (_ * expr (_ * _ * _ * _) + + (_ * expr (_ * _ * _) + + (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type) - => match x_y_z_a return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => match x_y_z_a return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr (inr (inr x, inr y), inr z), inr a) => let result := ident.interp idc (x, y, z, a) in inr (inr (fst result), inr (snd result)) | inr (inr (inr (inr x, y), z), a) - => expr.reflect (AppIdent (ident.Z.add_with_get_carry_concrete x) (expr.reify (t:=type.Z*type.Z*type.Z) (inr (inr (y, z), a)))) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_*_) x_y_z_a)) + => default_interp (ident.Z.add_with_get_carry_concrete x) (inr (inr (y, z), a)) + | _ => default_interp idc x_y_z_a end | ident.Z_sub_get_borrow as idc - => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + + (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) + => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) | inr (inr (inr x, y), z) - => expr.reflect (AppIdent (ident.Z.sub_get_borrow_concrete x) (expr.reify (t:=type.Z*type.Z) (inr (y, z)))) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z)) + => default_interp (ident.Z.sub_get_borrow_concrete x) (inr (y, z)) + | _ => default_interp idc x_y_z end | ident.Z_mul_split_concrete _ as idc | ident.Z.sub_get_borrow_concrete _ as idc - => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => match x_y return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr x, inr y) => let result := ident.interp idc (x, y) in inr (inr (fst result), inr (snd result)) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) + | _ => default_interp idc x_y end | ident.Z.add_get_carry_concrete _ as idc - => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in - match x_y return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => let default := default_interp idc x_y in + match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr x, inr y) => let result := ident.interp idc (x, y) in inr (inr (fst result), inr (snd result)) @@ -4476,26 +4576,26 @@ Module Compilers. | _ => default end | ident.Z.add_with_get_carry_concrete _ as idc - => fun (x_y_z : (expr (type.Z * type.Z * type.Z) + - (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) - => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with + => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) + + (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type) + => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with | inr (inr (inr x, inr y), inr z) => let result := ident.interp idc (x, y, z) in inr (inr (fst result), inr (snd result)) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z)) + | _ => default_interp idc x_y_z end | ident.pred as idc | ident.Nat_succ as idc - => fun x : expr _ + type.interp _ - => match x return expr _ + type.interp _ with + => fun x : _ * expr _ + type.interp _ + => match x return _ * expr _ + type.interp _ with | inr x => inr (ident.interp idc x) - | inl x => expr.reflect (AppIdent idc x) + | _ => default_interp idc x end | ident.Z_of_nat as idc - => fun x : expr _ + type.interp _ + => fun x : _ * expr _ + type.interp _ => match x return _ * expr _ + type.interp _ with | inr x => inr (ident.interp idc x) - | inl x => expr.reflect (AppIdent idc x) + | _ => default_interp idc x end | ident.Z_opp as idc => fun x : _ * expr _ + type.interp _ @@ -4522,36 +4622,36 @@ Module Compilers. | ident.Z_eqb as idc | ident.Z_leb as idc | ident.Z_pow as idc - => fun (x_y : expr (_ * _) + (_ + type.interp _) * (_ + type.interp _)) + => fun (x_y : data (_ * _) * expr (_ * _) + (_ + type.interp _) * (_ + type.interp _)) => match x_y return _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) + | _ => default_interp idc x_y end | ident.Z_div as idc - => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in + => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => let default := default_interp idc x_y in match x_y return _ * expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | inr (x, inr y) => if Z.eqb y (2^Z.log2 y) - then expr.reflect (AppIdent (ident.Z.shiftr (Z.log2 y)) (expr.reify (t:=type.Z) x)) + then default_interp (ident.Z.shiftr (Z.log2 y)) x else default | _ => default end | ident.Z_modulo as idc - => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in + => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => let default := default_interp idc x_y in match x_y return _ * expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | inr (x, inr y) => if Z.eqb y (2^Z.log2 y) - then expr.reflect (AppIdent (ident.Z.land (y-1)) (expr.reify (t:=type.Z) x)) + then default_interp (ident.Z.land (y-1)) x else default | _ => default end | ident.Z_mul as idc - => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) - => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in + => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => let default := default_interp idc x_y in match x_y return _ * expr _ + type.interp _ with | inr (inr x, inr y) => inr (ident.interp idc (x, y)) | inr (inr x, inl (data, e) as y) @@ -4574,7 +4674,7 @@ Module Compilers. | inl _ => default end | ident.Z_add as idc - => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) => let default0 := AppIdent idc (expr.reify (t:=_*_) x_y) in let default := expr.reflect default0 in match x_y return _ * expr _ + type.interp _ with @@ -4608,7 +4708,7 @@ Module Compilers. | inl _ => default end | ident.Z_sub as idc - => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) => let default0 := AppIdent idc (expr.reify (t:=_*_) x_y) in let default := expr.reflect default0 in match x_y return _ * expr _ + type.interp _ with @@ -4641,11 +4741,11 @@ Module Compilers. end | ident.Z_zselect as idc | ident.Z_add_modulo as idc - => fun (x_y_z : (expr (_ * _ * _) + - (expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type) + => fun (x_y_z : (_ * expr (_ * _ * _) + + (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type) => match x_y_z return _ * expr _ + type.interp _ with | inr (inr (inr x, inr y), inr z) => inr (ident.interp idc (x, y, z)) - | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z)) + | _ => default_interp idc x_y_z end | ident.Z_cast r as idc => fun (x : _ * expr _ + type.interp _) @@ -4654,6 +4754,24 @@ Module Compilers. | inl (data, e) => inl (ZRange.ident.option.interp idc data, e) end + | ident.Z_cast2 (r1, r2) as idc + => fun (x : _ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) + => match x with + | inr (inr a, inr b) + => inr (inr (ident.interp (ident.Z.cast r1) a), + inr (ident.interp (ident.Z.cast r2) b)) + | inr (inr a, inl (r2', b)) + => inr (inr (ident.interp (ident.Z.cast r1) a), + inl (ZRange.ident.option.interp (ident.Z.cast r2) r2', b)) + | inr (inl (r1', a), inr b) + => inr (inl (ZRange.ident.option.interp (ident.Z.cast r1) r1', a), + inr (ident.interp (ident.Z.cast r2) b)) + | inr (inl (r1', a), inl (r2', b)) + => inr (inl (ZRange.ident.option.interp (ident.Z.cast r1) r1', a), + inl (ZRange.ident.option.interp (ident.Z.cast r2) r2', b)) + | inl (data, e) + => inl (ZRange.ident.option.interp idc data, e) + end end. End interp. End ident. @@ -4712,37 +4830,74 @@ Module Compilers. | type.type_primitive t => fun _ => id | type.prod A B => fun '((ra, rb) : ZRange.type.option.interp A * ZRange.type.option.interp B) - (e : expr _ + partial.value var A * partial.value var B) + (e : _ * expr _ + partial.value var A * partial.value var B) => match e with | inr (a, b) => inr (@extend_with_obounds A ra a, @extend_with_obounds B rb b) - | inl e + | inl ((dataa, datab), e) => if partial.ident.is_var_like e then inr (@extend_with_obounds A ra (partial.expr.reflect (AppIdent ident.fst e)), @extend_with_obounds B rb (partial.expr.reflect (AppIdent ident.snd e))) - else inl e + else inl + (match A, B return ZRange.type.option.interp A -> ZRange.type.option.interp B -> data A -> data B -> expr (A * B) -> data (A * B) * expr (A * B) with + | type.Z, type.Z + => fun ra rb da db e + => let da' + := match ra with + | Some ra + => ZRange.ident.option.interp + (ident.Z.cast ra) da + | None => da + end in + let db' + := match rb with + | Some rb + => ZRange.ident.option.interp + (ident.Z.cast rb) db + | None => db + end in + ((da', db'), e) + | _, _ + => fun _ _ da db e => ((da, db), e) + end ra rb dataa datab e) end | type.arrow s d => fun _ => id | type.list A - => fun (ls : Datatypes.list (ZRange.type.option.interp A)) - (e : expr _ + list (partial.value var A)) - => match e with - | inl e - => match A return (ZRange.type.option.interp A -> partial.value var A -> partial.value var A) - -> Datatypes.list (ZRange.type.option.interp A) - -> expr (type.list A) - -> partial.value var (type.list A) - with - | type.type_primitive A - => fun extend_with_obounds ls e - => inr (extend_list_expr_with_obounds - extend_with_obounds 0 ls e) - | A' - => fun _ _ e => inl e - end (@extend_with_obounds A) ls e - | inr e => inr (extend_concrete_list_with_obounds - (@extend_with_obounds A) ls e) + => fun (ls : option (Datatypes.list (ZRange.type.option.interp A))) + (e : data _ * expr _ + list (partial.value var A)) + => match ls with + | None => e + | Some ls + => + match e with + | inl (data, e) + => match A return (ZRange.type.option.interp A -> partial.value var A -> partial.value var A) + -> Datatypes.list (ZRange.type.option.interp A) + -> option (Datatypes.list (ZRange.type.option.interp A)) + -> expr (type.list A) + -> partial.value var (type.list A) + with + | type.type_primitive A + => fun extend_with_obounds ls data e + => match data with + | Some data + => inr + (extend_concrete_list_with_obounds + extend_with_obounds ls + (extend_list_expr_with_obounds + extend_with_obounds 0 data e)) + | None + => inr (extend_list_expr_with_obounds + extend_with_obounds 0 ls e) + end + | A' + (* N.B. We clobber the existing bounds here, rather than fusing them *) + => fun _ ls data e => inl (Some ls, e) + end (@extend_with_obounds A) ls data e + | inr e => inr (extend_concrete_list_with_obounds + (@extend_with_obounds A) ls e) + end end end. Definition extend_with_bounds {t} @@ -4761,10 +4916,10 @@ Module Compilers. | ident.Z_cast range => fun _ => Some range | ident.primitive type.Z v => fun _ => Some r[v~>v]%zrange - | ident.nil _ => fun _ => nil + | ident.nil _ => fun _ => Some nil | ident.cons t - => fun '((x, xs) : ZRange.type.option.interp t * list (ZRange.type.option.interp t)) - => cons x xs + => fun '((x, xs) : ZRange.type.option.interp t * option (list (ZRange.type.option.interp t))) + => option_map (cons x) xs | _ => fun _ => ZRange.type.option.None end. End ident. @@ -4798,7 +4953,10 @@ Module Compilers. Section partial_reduce. Context {var : type -> Type}. - Fixpoint partial_reduce' {t} (e : @expr (partial.value var) t) + Definition partial_reduce'_step + (partial_reduce' : forall {t} (e : @expr (partial.value var) t), + partial.value var t) + {t} (e : @expr (partial.value var) t) : partial.value var t := match e in expr.expr t return partial.value var t with | Var t v => v @@ -4808,6 +4966,9 @@ Module Compilers. | App s d f x => @partial_reduce' _ f (@partial_reduce' _ x) | Abs s d f => fun x => @partial_reduce' d (f x) end. + Fixpoint partial_reduce' {t} (e : @expr (partial.value var) t) + : partial.value var t + := @partial_reduce'_step (@partial_reduce') t e. Definition partial_reduce {t} (e : @expr (partial.value var) t) : @expr var t := partial.expr.reify (@partial_reduce' t e). @@ -4840,6 +5001,12 @@ Module Compilers. | Some r => AppIdent (ident.Z.cast r) | None => id end + | ident.Z_cast2 (r1, r2) + => match relax_zrange r1, relax_zrange r2 with + | Some r1, Some r2 + => AppIdent (ident.Z.cast2 (r1, r2)) + | Some _, None | None, Some _ | None, None => id + end | idc => AppIdent idc end. End relax. @@ -6293,7 +6460,7 @@ Ltac solve_rone := solve_rop rone_correct. Module PrintingNotations. Export ident. - Global Set Printing Width 100000. + (*Global Set Printing Width 100000.*) Open Scope zrange_scope. Notation "'uint256'" := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : zrange_scope. @@ -6303,6 +6470,8 @@ Module PrintingNotations. := (r[0 ~> 18446744073709551615]) : zrange_scope. Notation "'uint32'" := (r[0 ~> 4294967295]) : zrange_scope. + Notation "'bool'" + := (r[0 ~> 1]%zrange) : zrange_scope. Notation "ls [[ n ]]" := ((List.nth_default_concrete _ n @@ ls)%expr) (at level 30, format "ls [[ n ]]") : expr_scope. @@ -6310,12 +6479,16 @@ Module PrintingNotations. := ((ident.Z.cast range @@ (List.nth_default_concrete _ n @@ ls))%expr) (format "( range )( ls [[ n ]] )") : expr_scope. (*Notation "( range )( v )" := (ident.Z.cast range @@ v)%expr : expr_scope.*) + Notation "x *₂₅₆ y" + := (ident.Z.cast uint256 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. Notation "x *₁₂₈ y" := (ident.Z.cast uint128 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. Notation "x *₆₄ y" := (ident.Z.cast uint64 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. Notation "x *₃₂ y" := (ident.Z.cast uint32 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope. + Notation "x +₂₅₆ y" + := (ident.Z.cast uint256 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope. Notation "x +₁₂₈ y" := (ident.Z.cast uint128 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope. Notation "x +₆₄ y" @@ -6341,8 +6514,17 @@ Module PrintingNotations. := ((ident.Z.cast out_t @@ (ident.Z.land mask @@ v))%expr) (format "( ( out_t )( v ) & mask )") : expr_scope. - Notation "v ₁" := (ident.fst @@ v)%expr (at level 10, format "v ₁") : expr_scope. - Notation "v ₂" := (ident.snd @@ v)%expr (at level 10, format "v ₂") : expr_scope. + + Notation "x" := (ident.Z.cast _ @@ Var x)%expr (only printing, at level 9) : expr_scope. + Notation "x" := (ident.Z.cast2 _ @@ Var x)%expr (only printing, at level 9) : expr_scope. + Notation "v ₁" := (ident.fst @@ Var v)%expr (at level 10, format "v ₁") : expr_scope. + Notation "v ₂" := (ident.snd @@ Var v)%expr (at level 10, format "v ₂") : expr_scope. + Notation "v ₁" := (ident.Z.cast _ @@ (ident.fst @@ Var v))%expr (at level 10, format "v ₁") : expr_scope. + Notation "v ₂" := (ident.Z.cast _ @@ (ident.snd @@ Var v))%expr (at level 10, format "v ₂") : expr_scope. + Notation "v ₁" := (ident.Z.cast _ @@ (ident.fst @@ (ident.Z.cast2 _ @@ Var v)))%expr (at level 10, format "v ₁") : expr_scope. + Notation "v ₂" := (ident.Z.cast _ @@ (ident.snd @@ (ident.Z.cast2 _ @@ Var v)))%expr (at level 10, format "v ₂") : expr_scope. + + (*Notation "ls [[ n ]]" := (List.nth_default_concrete _ n @@ ls)%expr : expr_scope. Notation "( range )( v )" := (ident.Z.cast range @@ v)%expr : expr_scope. Notation "x *₁₂₈ y" @@ -6408,13 +6590,12 @@ Module PrintingNotations. (* TODO: come up with a better notation for arithmetic with carries that still distinguishes it from arithmetic without carries? *) Local Notation "'TwoPow256'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 (only parsing). - (*Notation "'ADD_256'" := (add_get_carry_concrete _ _ uint256 _ TwoPow256) : nexpr_scope. - Notation "'ADD_128'" := (add_get_carry_concrete _ _ uint128 _ TwoPow256) : nexpr_scope. - Notation "'ADDC_256'" := (add_with_get_carry_concrete _ _ _ uint256 _ TwoPow256) : nexpr_scope. - Notation "'SUB_256'" := (sub_get_borrow_concrete _ _ uint256 _ TwoPow256) : nexpr_scope. - Notation "'ADDM'" := (add_modulo _ _ _ uint256) : nexpr_scope. - Notation "'SELC'" := (zselect _ _ _ uint256) : nexpr_scope. - Notation "'MUL_256'" := (mul uint128 uint128 uint256) : nexpr_scope.*) + Notation "'ADD_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_get_carry_concrete TwoPow256 @@ (x, y)))%expr : expr_scope. + Notation "'ADD_128' ( x , y )" := (ident.Z.cast2 (uint128, bool)%core @@ (ident.Z.add_get_carry_concrete TwoPow256 @@ (x, y)))%expr : expr_scope. + Notation "'ADDC_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope. + Notation "'SUB_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_get_borrow_concrete TwoPow256 @@ (x, y)))%expr : expr_scope. + Notation "'ADDM' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.add_modulo @@ (x, y, z)))%expr : expr_scope. + Notation "'SELC' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.zselect @@ (x, y, z)))%expr : expr_scope. End PrintingNotations. (* @@ -7003,6 +7184,9 @@ Module MontgomeryReduction. = montred' N R N' w w_half n lo_hi) As montred_gen_correct. Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed. + Module Export ReifyHints. + Global Hint Extern 1 (_ = montred' _ _ _ _ _ _ _) => simple apply montred_gen_correct : reify_gen_cache. + End ReifyHints. Section rmontred. Context (N R N' : Z) @@ -7012,7 +7196,7 @@ Module MontgomeryReduction. Let bound := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. Definition relax_zrange_of_machine_wordsize - := relax_zrange_gen [machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize; 4 * machine_wordsize]%Z. + := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize; 4 * machine_wordsize]%Z. Local Arguments relax_zrange_of_machine_wordsize / . Let rw := rweight machine_wordsize. @@ -7033,7 +7217,7 @@ Module MontgomeryReduction. else res. Notation BoundsPipeline_correct in_bounds out_bounds op - := (fun (rop : Expr (type.reify_type_of op%function)) rv Hrop + := (fun rv (rop : Expr (type.reify_type_of op%function)) E Hrop HE => @Pipeline.BoundsPipeline_correct_trans true (* DCE *) relax_zrange @@ -7043,7 +7227,7 @@ Module MontgomeryReduction. in_bounds out_bounds op - Hrop rv) + Hrop rv E HE) (only parsing). Definition rmontred_correct @@ -7052,9 +7236,9 @@ Module MontgomeryReduction. bound (montred' N R N' (Interp rw) (Interp rw_half) 2). - Notation type_of_strip_2arrow := ((fun s s' (d : Prop) (_ : s -> s' -> d) => d) _ _ _). + Notation type_of_strip_5arrow := ((fun (d : Prop) (_ : forall A B C D E, d) => d) _). Definition rmontred_correctT rv : Prop - := exists rop, type_of_strip_2arrow (@rmontred_correct rop rv). + := type_of_strip_5arrow (@rmontred_correct rv). End rmontred. End MontgomeryReduction. @@ -7071,217 +7255,166 @@ Module Montgomery256. Derive montred256 SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256) As montred256_correct. - Proof. - eexists; eapply MontgomeryReduction.rmontred_correct with (machine_wordsize:=machine_wordsize). - Time do_inline_cache_reify ltac:(fun _ => idtac). - cbv [Pipeline.BoundsPipeline]. - set (k := PartialReduce _). - cbv [CheckedPartialReduceWithBounds1]. - Time set (k' := PartialReduceWithBounds1 _ _). - Time lazy in k'. - Print fold_left. - Time lazy; reflexivity. - Time lazy -[Let_In k']. - | (* Doing [lazy] is twice as slow as doing [lazy -[Let_In]; lazy]. - This is because the bounds pipeline does [dlet E : Expr := (reduced - thing) in let b := extract_bounds E in ...]. If we allow [lazy] to - unfold [Let_In] before it fully reduces the function (function, - because [Expr := forall var, @expr var]), then there is no sharing - between the partial reduction in bounds extraction and the partial - reduction in the return value. So we force [lazy] to fully reduce - the argument first, and only then permit [lazy] to inline it. This - is slightly slower than doing bounds analysis in a non-PHOAS - representation; we spend about 3%-5% of the overall time doing - bounds extraction, and fully reducing the bounds extraction - expression before plugging in arguments costs a bit more. However, - it's still reasonably fast, and the code is much simpler when - [Interp] always succeeds rather than returning [option]. *) - lazy -[Let_In]; lazy; reflexivity ]. - - Time solve_rmontred_nocache machine_wordsize. - eapply MontgomeryReduction.rmontred_correct. - cbv [MontgomeryReduction.rmontred]. - cbv [Pipeline.BoundsPipeline]. - cbv [CheckedPartialReduceWithBounds1]. - set (k := PartialReduceWithBounds1 _ _). - Timeout 10 Time lazy in k. - Time solve_rmontred(). Time Qed. + Proof. Time solve_rmontred machine_wordsize. Time Qed. Import PrintingNotations. - Open Scope nexpr_scope. + Open Scope expr_scope. Set Printing Width 100000. + Print montred256. - (* - expr_let 3 := (uint128)(fst @@ x_1 >> 128) in - expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in - expr_let 5 := MUL_256 @@ (x_3, (79228162514264337593543950337)) in - expr_let 7 := ((uint128)x_5 & 340282366920938463463374607431768211455) in - expr_let 8 := MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) in - expr_let 10 := ((uint128)x_8 & 340282366920938463463374607431768211455) in - expr_let 11 := (uint128)(x_10 << 128) in - expr_let 12 := (uint128)(x_7 << 128) in - expr_let 17 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in - expr_let 18 := ADD_128 @@ (x_11, x_12) in - expr_let 19 := ADD_256 @@ (x_17, fst @@ x_18) in - expr_let 43 := (uint128)(fst @@ x_19 >> 128) in - expr_let 44 := ((uint128)fst @@ x_19 & 340282366920938463463374607431768211455) in - expr_let 45 := MUL_256 @@ (x_43, (79228162514264337593543950335)) in - expr_let 46 := (uint128)(x_45 >> 128) in - expr_let 47 := ((uint128)x_45 & 340282366920938463463374607431768211455) in - expr_let 48 := MUL_256 @@ (x_44, (340282366841710300967557013911933812736)) in - expr_let 49 := (uint128)(x_48 >> 128) in - expr_let 50 := ((uint128)x_48 & 340282366920938463463374607431768211455) in - expr_let 51 := (uint128)(x_50 << 128) in - expr_let 52 := (uint128)(x_47 << 128) in - expr_let 57 := MUL_256 @@ (x_44, (79228162514264337593543950335)) in - expr_let 58 := ADD_128 @@ (x_51, x_52) in - expr_let 59 := ADD_256 @@ (x_57, fst @@ x_58) in - expr_let 60 := snd @@ x_59 +₁₂₈ snd @@ x_58 in - expr_let 67 := MUL_256 @@ (x_43, (340282366841710300967557013911933812736)) in - expr_let 69 := ADD_256 @@ (x_46, x_67) in - expr_let 70 := ADD_256 @@ (x_49, fst @@ x_69) in - expr_let 80 := ADD_256 @@ (x_60, fst @@ x_70) in - expr_let 83 := ADD_256 @@ (fst @@ x_1, fst @@ x_59) in - expr_let 84 := ADDC_256 @@ (snd @@ x_83, snd @@ x_1, fst @@ x_80) in - expr_let 85 := SELC @@ (snd @@ x_84, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in - expr_let 86 := fst @@ (SUB_256 @@ (fst @@ x_84, x_85)) in - ADDM @@ (x_86, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) - : expr uint256 - *) + (*montred256 = fun var : type -> Type => λ v : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, + expr_let v0 := (uint128)(v₁ >> 128) in + expr_let v1 := ((uint128)(v₁) & 340282366920938463463374607431768211455) in + expr_let v2 := ((uint128)(79228162514264337593543950337 *₂₅₆ v0) & 340282366920938463463374607431768211455) in + expr_let v3 := ((uint128)(340282366841710300986003757985643364352 *₂₅₆ v1) & 340282366920938463463374607431768211455) in + expr_let v4 := ADD_256 ((uint256)(v3 << 128), (uint256)(v2 << 128)) in + expr_let v5 := ADD_256 (79228162514264337593543950337 *₂₅₆ v1, v4₁) in + expr_let v6 := (uint128)(v5₁ >> 128) in + expr_let v7 := ((uint128)(v5₁) & 340282366920938463463374607431768211455) in + expr_let v8 := (uint128)(79228162514264337593543950335 *₂₅₆ v6 >> 128) in + expr_let v9 := ((uint128)(79228162514264337593543950335 *₂₅₆ v6) & 340282366920938463463374607431768211455) in + expr_let v10 := (uint128)(340282366841710300967557013911933812736 *₂₅₆ v7 >> 128) in + expr_let v11 := ((uint128)(340282366841710300967557013911933812736 *₂₅₆ v7) & 340282366920938463463374607431768211455) in + expr_let v12 := ADD_256 ((uint256)(v11 << 128), (uint256)(v9 << 128)) in + expr_let v13 := ADD_256 (79228162514264337593543950335 *₂₅₆ v7, v12₁) in + expr_let v14 := v13₂ +₁₂₈ v12₂ in + expr_let v15 := ADD_256 (v8, 340282366841710300967557013911933812736 *₂₅₆ v6) in + expr_let v16 := ADD_256 (v10, v15₁) in + expr_let v17 := ADD_256 (v14, v16₁) in + expr_let v18 := ADD_256 (v₁, v13₁) in + expr_let v19 := ADDC_256 (v18₂, v₂, v17₁) in + expr_let v20 := SELC (v19₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let v21 := Z.cast uint256 @@ (fst @@ SUB_256 (v19₁, v20)) in + ADDM (v21, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) + : Expr (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z) +*) End Montgomery256. (* Extra-specialized ad-hoc pretty-printing *) Module Montgomery256PrintingNotations. Export ident. - Export BoundsAnalysis.ident. - Export BoundsAnalysis.type.Notations. - Export BoundsAnalysis.Indexed.expr.Notations. - Export BoundsAnalysis.ident.Notations. - Import BoundsAnalysis.type. - Import BoundsAnalysis.Indexed.expr. - Import BoundsAnalysis.ident. - Open Scope btype_scope. + Open Scope expr_scope. + Open Scope ctype_scope. Notation "'RegMod'" := - (BoundsAnalysis.Indexed.expr.AppIdent - (primitive {| BoundsAnalysis.type.value := 115792089210356248762697446949407573530086143415290314195533631308867097853951; BoundsAnalysis.type.value_bounded := _ |}) - BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope. + (AppIdent + (primitive 115792089210356248762697446949407573530086143415290314195533631308867097853951) + TT) (only printing, at level 9) : expr_scope. Notation "'RegPinv'" := - (BoundsAnalysis.Indexed.expr.AppIdent - (primitive {| BoundsAnalysis.type.value := 115792089210356248768974548684794254293921932838497980611635986753331132366849; BoundsAnalysis.type.value_bounded := _ |}) - BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope. + (AppIdent + (primitive 115792089210356248768974548684794254293921932838497980611635986753331132366849) + TT) (only printing, at level 9) : expr_scope. Notation "'RegZero'" := - (BoundsAnalysis.Indexed.expr.AppIdent - (primitive {| BoundsAnalysis.type.value := 0; BoundsAnalysis.type.value_bounded := _ |}) - BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope. - Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : nexpr_scope. + (AppIdent + (primitive 0) + TT) (only printing, at level 9) : expr_scope. + Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : expr_scope. Notation "'Lower128{RegMod}'" := - (BoundsAnalysis.Indexed.expr.AppIdent - (primitive {| BoundsAnalysis.type.value := 79228162514264337593543950335; BoundsAnalysis.type.value_bounded := _ |}) - BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope. + (AppIdent + (primitive 79228162514264337593543950335) + TT) (only printing, at level 9) : expr_scope. Notation "'RegMod' '<<' '128'" := - (BoundsAnalysis.Indexed.expr.AppIdent - (primitive {| BoundsAnalysis.type.value := 340282366841710300967557013911933812736; BoundsAnalysis.type.value_bounded := _ |}) - BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9, format "'RegMod' '<<' '128'") : nexpr_scope. + (AppIdent + (primitive 340282366841710300967557013911933812736) + TT) (only printing, at level 9, format "'RegMod' '<<' '128'") : expr_scope. Notation "'Lower128{RegPinv}'" := - (BoundsAnalysis.Indexed.expr.AppIdent - (primitive {| BoundsAnalysis.type.value := 79228162514264337593543950337; BoundsAnalysis.type.value_bounded := _ |}) - BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope. + (AppIdent + (primitive 79228162514264337593543950337) + TT) (only printing, at level 9) : expr_scope. Notation "'RegPinv' '>>' '128'" := - (BoundsAnalysis.Indexed.expr.AppIdent - (primitive {| BoundsAnalysis.type.value := 340282366841710300986003757985643364352; BoundsAnalysis.type.value_bounded := _ |}) - BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9, format "'RegPinv' '>>' '128'") : nexpr_scope. + (AppIdent + (primitive 340282366841710300986003757985643364352) + TT) (only printing, at level 9, format "'RegPinv' '>>' '128'") : expr_scope. Notation "'uint256'" - := (BoundsAnalysis.type.ZBounded 0 115792089237316195423570985008687907853269984665640564039457584007913129639935) : btype_scope. + := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : ctype_scope. Notation "'uint128'" - := (BoundsAnalysis.type.ZBounded 0 340282366920938463463374607431768211455) : btype_scope. - Notation "$r n" := (BoundsAnalysis.Indexed.expr.Var _ n) (at level 10, format "$r n") : nexpr_scope. - Notation "$r n '_lo'" := (fst @@ (BoundsAnalysis.Indexed.expr.Var (BoundsAnalysis.type.prod _ _) n))%nexpr (at level 10, format "$r n _lo") : nexpr_scope. - Notation "$r n '_hi'" := (snd @@ (BoundsAnalysis.Indexed.expr.Var (BoundsAnalysis.type.prod _ _) n))%nexpr (at level 10, format "$r n _hi") : nexpr_scope. - Notation "'c.Mul128x128(' '$r' n ',' x ',' y ');' f" := - (expr_let n := mul _ _ uint256 @@ (x, y) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. - Notation "'c.Mul128x128(' '$r' n ',' x ',' y ')' '<<' count ';' f" := - (expr_let n := shiftl _ _ count @@ (mul _ _ uint256 @@ (x, y)) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : nexpr_scope. - Notation "'c.Add256(' '$r' n ',' x ',' y ');' f" := - (expr_let n := add_get_carry_concrete _ _ uint256 _ $R @@ (x, y) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add256(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. - Notation "'c.Add128(' '$r' n ',' x ',' y ');' f" := - (expr_let n := add_get_carry_concrete _ _ uint128 _ $R @@ (x, y) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. - Notation "'c.Add64(' '$r' n ',' x ',' y ');' f" := - (expr_let n := add _ _ uint128 @@ (x, y) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. - Notation "'c.Addc(' '$r' n ',' x ',' y ');' f" := - (expr_let n := add_with_get_carry_concrete _ _ _ uint256 _ $R @@ (_, x, y) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. - Notation "'c.Selc(' '$r' n ',' y ',' z ');' f" := - (expr_let n := zselect _ _ _ uint256 @@ (_, y, z) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$r' n ',' y ',' z ');' ']' '//' f") : nexpr_scope. - Notation "'c.Sub(' '$r' n ',' x ',' y ');' f" := - (expr_let n := fst @@ (sub_get_borrow_concrete _ _ uint256 _ $R @@ (x, y)) in - f)%nexpr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$r' n ',' x ',' y ');' '//' f") : nexpr_scope. + := (r[0 ~> 340282366920938463463374607431768211455]%zrange) : ctype_scope. + Notation "$ n" := (Var n) (at level 10, format "$ n") : expr_scope. + Notation "$ n" := (Z.cast _ @@ Var n) (at level 10, format "$ n") : expr_scope. + Notation "$ n '_lo'" := (fst @@ (Var n))%expr (at level 10, format "$ n _lo") : expr_scope. + Notation "$ n '_hi'" := (snd @@ (Var n))%expr (at level 10, format "$ n _hi") : expr_scope. + Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Var n)))%expr (at level 10, format "$ n _lo") : expr_scope. + Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Var n)))%expr (at level 10, format "$ n _hi") : expr_scope. + Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _lo") : expr_scope. + Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _hi") : expr_scope. + Notation "'c.Mul128x128(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast uint256 @@ (Z.mul @@ (x, y)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' f" := + (expr_let n := Z.cast _ @@ (Z.shiftl count @@ (Z.cast uint256 @@ (Z.mul @@ (x, y)))) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : expr_scope. + Notation "'c.Add256(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add256(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.Add128(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.Add64(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast uint128 @@ (Z.add @@ (x, y)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.Addc(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (_, x, y)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.Selc(' '$' n ',' y ',' z ');' f" := + (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (_, y, z)) in + f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' y ',' z ');' ']' '//' f") : expr_scope. + Notation "'c.Sub(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in + f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope. Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" := - (add_modulo _ _ _ uint256 @@ (x, y, z))%nexpr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : nexpr_scope. - Notation "'c.ShiftR(' '$r' n ',' x ',' y ');' f" := - (expr_let n := (shiftr _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. - Notation "'c.ShiftL(' '$r' n ',' x ',' y ');' f" := - (expr_let n := (shiftl _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. - Notation "'c.Lower128(' '$r' n ',' x ');' f" := - (expr_let n := (land _ _ 340282366920938463463374607431768211455 @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$r' n ',' x ');' ']' '//' f") : nexpr_scope. + (Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope. + Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast _ @@ (Z.shiftr y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" := + (expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope. + Notation "'c.Lower128(' '$' n ',' x ');' f" := + (expr_let n := Z.cast _ @@ (Z.land 340282366920938463463374607431768211455 @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$' n ',' x ');' ']' '//' f") : expr_scope. Notation "'Lower128'" - := ((land uint256 uint128 340282366920938463463374607431768211455)) + := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455)) (at level 10, only printing, format "Lower128") - : nexpr_scope. + : expr_scope. Notation "( v << count )" - := ((shiftl _ _ count @@ v)%nexpr) + := (Z.cast _ @@ (Z.shiftl count @@ v)%expr) (format "( v << count )") - : nexpr_scope. + : expr_scope. Notation "( x >> count )" - := ((shiftr _ _ count @@ x)%nexpr) + := (Z.cast _ @@ (Z.shiftr count @@ x)%expr) (format "( x >> count )") - : nexpr_scope. + : expr_scope. + Notation "x * y" + := (Z.cast uint256 @@ (Z.mul @@ (x, y))) + : expr_scope. End Montgomery256PrintingNotations. Import Montgomery256PrintingNotations. -Local Open Scope nexpr_scope. +Local Open Scope expr_scope. Print Montgomery256.montred256. (* -c.ShiftR($r3, $r1_lo, 128); -c.Lower128($r4, $r1_lo); -c.Mul128x128($r5, $r3, Lower128{RegPinv}); -c.Lower128($r7, $r5); -c.Mul128x128($r8, $r4, RegPinv >> 128); -c.Lower128($r10, $r8); -c.ShiftL($r11, $r10, 128); -c.ShiftL($r12, $r7, 128); -c.Mul128x128($r17, $r4, Lower128{RegPinv}); -c.Add128($r18, $r11, $r12); -c.Add256($r19, $r17, $r18_lo); -c.ShiftR($r43, $r19_lo, 128); -c.Lower128($r44, $r19_lo); -c.Mul128x128($r45, $r43, Lower128{RegMod}); -c.ShiftR($r46, $r45, 128); -c.Lower128($r47, $r45); -c.Mul128x128($r48, $r44, RegMod << 128); -c.ShiftR($r49, $r48, 128); -c.Lower128($r50, $r48); -c.ShiftL($r51, $r50, 128); -c.ShiftL($r52, $r47, 128); -c.Mul128x128($r57, $r44, Lower128{RegMod}); -c.Add128($r58, $r51, $r52); -c.Add256($r59, $r57, $r58_lo); -c.Add64($r60, $r59_hi, $r58_hi); -c.Mul128x128($r67, $r43, RegMod << 128); -c.Add256($r69, $r46, $r67); -c.Add256($r70, $r49, $r69_lo); -c.Add256($r80, $r60, $r70_lo); -c.Add256($r83, $r1_lo, $r59_lo); -c.Addc($r84, $r1_hi, $r80_lo); -c.Selc($r85,RegZero, RegMod); -c.Sub($r86, $r84_lo, $r85); -c.AddM($ret, $r86, RegZero, RegMod); - : expr uint256 +c.ShiftR($v0, $v_lo, 128); +c.Lower128($v1, $v_lo); +c.Lower128($v2, Lower128{RegPinv} * $v0); +c.Lower128($v3, RegPinv >> 128 * $v1); +c.Add256($v4, ($v3 << 128), ($v2 << 128)); +c.Add256($v5, Lower128{RegPinv} * $v1, $v4_lo); +c.ShiftR($v6, $v5_lo, 128); +c.Lower128($v7, $v5_lo); +c.ShiftR($v8, Lower128{RegMod} * $v6, 128); +c.Lower128($v9, Lower128{RegMod} * $v6); +c.ShiftR($v10, RegMod << 128 * $v7, 128); +c.Lower128($v11, RegMod << 128 * $v7); +c.Add256($v12, ($v11 << 128), ($v9 << 128)); +c.Add256($v13, Lower128{RegMod} * $v7, $v12_lo); +c.Add64($v14, $v13_hi, $v12_hi); +c.Add256($v15, $v8, RegMod << 128 * $v6); +c.Add256($v16, $v10, $v15_lo); +c.Add256($v17, $v14, $v16_lo); +c.Add256($v18, $v_lo, $v13_lo); +c.Addc($v19, $v_hi, $v17_lo); +c.Selc($v20,RegZero, RegMod); +c.Sub($v21, $v19_lo, $v20); +c.AddM($ret, $v21, RegZero, RegMod); + : Expr + (type.type_primitive type.Z * type.type_primitive type.Z -> + type.type_primitive type.Z) *) |