aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-03-18 21:54:26 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-03-19 14:17:26 -0400
commit2ee5a1b54d1fe45f621e0f77f3446e348e4c1d19 (patch)
tree01764af2a901e0bf0a3260d79d41492c39389f5a /src
parent9a35ebe478cb3e621a7a4eabf4d88d007cc7128e (diff)
Add support for Z*Z casts, get montred working
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v835
1 files changed, 484 insertions, 351 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 3b935f3df..1b58ede58 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -2021,6 +2021,7 @@ Module Compilers.
| Z_zselect : ident (Z * Z * Z) Z
| Z_add_modulo : ident (Z * Z * Z) Z
| Z_cast (range : zrange) : ident Z Z
+ | Z_cast2 (range : zrange * zrange) : ident (Z * Z) (Z * Z)
.
Notation curry0 f
@@ -2043,6 +2044,11 @@ Module Compilers.
Section gen.
Context (cast_outside_of_range : zrange -> BinInt.Z -> BinInt.Z).
+ Definition cast (r : zrange) (x : BinInt.Z)
+ := if (lower r <=? x) && (x <=? upper r)
+ then x
+ else cast_outside_of_range r x.
+
(** Interpret identifiers where the behavior of [Z_cast]
on a value that does not fit in the range is given by
a context variable. (This allows us to treat [Z_cast]
@@ -2089,9 +2095,8 @@ Module Compilers.
| Z_sub_get_borrow_concrete s => curry2 (Z.sub_get_borrow s)
| Z_zselect => curry3 Z.zselect
| Z_add_modulo => curry3 Z.add_modulo
- | Z_cast r => fun x => if (lower r <=? x) && (x <=? upper r)
- then x
- else cast_outside_of_range r x
+ | Z_cast r => cast r
+ | Z_cast2 (r1, r2) => fun '(x1, x2) => (cast r1 x1, cast r2 x2)
end.
End gen.
@@ -2229,6 +2234,7 @@ Module Compilers.
Notation zselect := Z_zselect.
Notation add_modulo := Z_add_modulo.
Notation cast := Z_cast.
+ Notation cast2 := Z_cast2.
End Z.
Module Nat.
@@ -2637,6 +2643,16 @@ Module Compilers.
| None => None
end.
+ Definition invert_Z_cast2 (e : @expr var (type.Z * type.Z)) : option ((zrange * zrange) * @expr var (type.Z * type.Z))
+ := match invert_AppIdent e with
+ | Some (existT s (idc, args))
+ => match idc in ident s t return expr s -> option ((zrange * zrange) * expr (type.Z * type.Z)) with
+ | ident.Z_cast2 r => fun v => Some (r, v)
+ | _ => fun _ => None
+ end args
+ | None => None
+ end.
+
Local Notation list_expr
:= (fun t => match t return Type with
| type.list T => list (expr T)
@@ -3156,6 +3172,7 @@ Module Compilers.
| ident.Z_zselect as idc
| ident.Z_add_modulo as idc
| ident.Z_cast _ as idc
+ | ident.Z_cast2 _ as idc
=> cps_of (Uncurried.expr.default.ident.interp idc)
| ident.Z_mul_split_concrete s
=> cps_of (curry2 (Z.mul_split s))
@@ -3371,6 +3388,7 @@ Module Compilers.
(ident.snd @@ (Var xyk))
@ ((idc : default.ident _ (type.Z * type.Z))
@@ (ident.fst @@ (Var xyk)))
+ | ident.Z_cast2 _ as idc
| ident.Z_mul_split_concrete _ as idc
| ident.Z_add_get_carry_concrete _ as idc
| ident.Z_sub_get_borrow_concrete _ as idc
@@ -3895,14 +3913,14 @@ Module Compilers.
| type.type_primitive x => primitive.option.interp x
| type.prod A B => interp A * interp B
| type.arrow s d => interp s -> interp d
- | type.list A => list (interp A)
+ | type.list A => option (list (interp A))
end.
Fixpoint None {t : type} : interp t
:= match t with
| type.type_primitive x => @primitive.option.None x
| type.prod A B => (@None A, @None B)
| type.arrow s d => fun _ => @None d
- | type.list A => @nil (interp A)
+ | type.list A => Datatypes.None
end.
Fixpoint Some {t : type} : type.interp t -> interp t
:= match t with
@@ -3911,7 +3929,7 @@ Module Compilers.
=> fun x : type.interp A * type.interp B
=> (@Some A (fst x), @Some B (snd x))
| type.arrow s d => fun _ _ => @None d
- | type.list A => List.map (@Some A)
+ | type.list A => fun ls => Datatypes.Some (List.map (@Some A) ls)
end.
Fixpoint is_tighter_than {t} : interp t -> interp t -> bool
:= match t with
@@ -3922,7 +3940,13 @@ Module Compilers.
=> @is_tighter_than A ra ra' && @is_tighter_than B rb rb'
| type.arrow s d => fun _ _ => false
| type.list A
- => fold_andb_map (@is_tighter_than A)
+ => fun ls1 ls2
+ => match ls1, ls2 with
+ | Datatypes.None, Datatypes.None => true
+ | Datatypes.Some _, Datatypes.None => true
+ | Datatypes.None, Datatypes.Some _ => false
+ | Datatypes.Some ls1, Datatypes.Some ls2 => fold_andb_map (@is_tighter_than A) ls1 ls2
+ end
end.
Fixpoint is_bounded_by {t} : interp t -> Compilers.type.interp t -> bool
:= match t return interp t -> Compilers.type.interp t -> bool with
@@ -3933,7 +3957,11 @@ Module Compilers.
=> @is_bounded_by A ra ra' && @is_bounded_by B rb rb'
| type.arrow s d => fun _ _ => false
| type.list A
- => fold_andb_map (@is_bounded_by A)
+ => fun ls1 ls2
+ => match ls1 with
+ | Datatypes.None => false
+ | Datatypes.Some ls1 => fold_andb_map (@is_bounded_by A) ls1 ls2
+ end
end.
Lemma is_tighter_than_Some_is_bounded_by {t} r1 r2 val
@@ -3954,7 +3982,10 @@ Module Compilers.
| apply conj
| Z.ltb_to_lt; omega
| rewrite @fold_andb_map_map in * ].
- { revert r1 r2 val Htight Hbounds IHt.
+ { lazymatch goal with
+ | [ r1 : list (interp t), r2 : list (type.interp t), val : list (Compilers.type.interp t) |- _ ]
+ => revert r1 r2 val Htight Hbounds IHt
+ end; intros r1 r2 val; revert r1 r2 val.
induction r1, r2, val; cbn; auto with nocore; try congruence; [].
rewrite !Bool.andb_true_iff; intros; destruct_head'_and; split; eauto with nocore. }
Qed.
@@ -4000,12 +4031,18 @@ Module Compilers.
| ident.Z_sub_get_borrow
| ident.Z_modulo
=> fun _ => type.option.None
- | ident.nil t => curry0 (@nil (type.option.interp t))
- | ident.cons t => curry2 (@Datatypes.cons (type.option.interp t))
+ | ident.nil t => curry0 (Some (@nil (type.option.interp t)))
+ | ident.cons t => curry2 (fun a => option_map (@Datatypes.cons (type.option.interp t) a))
| ident.fst A B => @Datatypes.fst (type.option.interp A) (type.option.interp B)
| ident.snd A B => @Datatypes.snd (type.option.interp A) (type.option.interp B)
| ident.List_nth_default_concrete T d n
- => fun ls => @nth_default (type.option.interp T) type.option.None ls n
+ => fun ls
+ => match ls with
+ | Datatypes.Some ls
+ => @nth_default (type.option.interp T) type.option.None ls n
+ | Datatypes.None
+ => type.option.None
+ end
| ident.Z_shiftr _ as idc
| ident.Z_shiftl _ as idc
| ident.Z_opp as idc
@@ -4029,6 +4066,16 @@ Module Compilers.
| Some r => ZRange.intersection r range
| None => range
end
+ | ident.Z_cast2 (r1, r2)
+ => fun '((r1', r2') : option zrange * option zrange)
+ => (Some match r1' with
+ | Some r => ZRange.intersection r r1
+ | None => r1
+ end,
+ Some match r2' with
+ | Some r => ZRange.intersection r r2
+ | None => r2
+ end)
| ident.Z_mul_split_concrete split_at
=> fun '((x, y) : option zrange * option zrange)
=> match x, y with
@@ -4110,7 +4157,7 @@ Module Compilers.
End DefaultValue.
Module partial.
- Notation Zdata := (option zrange).
+ Notation data := ZRange.type.option.interp.
Section value.
Context (var : type -> Type).
Definition value_prestep (value : type -> Type) (t : type)
@@ -4125,12 +4172,10 @@ Module Compilers.
:= match t return Type with
| type.arrow _ _ as t
=> value_prestep value t
- | type.type_primitive type.Z as t
- => Zdata * @expr var t + value_prestep value t
| type.prod _ _ as t
| type.list _ as t
| type.type_primitive _ as t
- => @expr var t + value_prestep value t
+ => data t * @expr var t + value_prestep value t
end%type.
Fixpoint value (t : type)
:= value_step value t.
@@ -4145,6 +4190,38 @@ Module Compilers.
| type.arrow s d => fun _ => @value_default d
| type.list A => inr (@nil (value A))
end.
+
+ Fixpoint data_from_value {t} : value t -> data t
+ := match t return value t -> data t with
+ | type.arrow _ _ as t
+ => fun _ => ZRange.type.option.None
+ | type.prod A B as t
+ => fun v
+ => match v with
+ | inl (data, _) => data
+ | inr (a, b)
+ => (@data_from_value A a, @data_from_value B b)
+ end
+ | type.list A as t
+ => fun v
+ => match v with
+ | inl (data, _) => data
+ | inr ls
+ => Some (List.map (@data_from_value A) ls)
+ end
+ | type.type_primitive type.Z as t
+ => fun v
+ => match v with
+ | inl (data, _) => data
+ | inr v => Some r[v~>v]%zrange
+ end
+ | type.type_primitive _ as t
+ => fun v
+ => match v with
+ | inl (data, _) => data
+ | inr _ => ZRange.type.option.None
+ end
+ end.
End value.
Module expr.
@@ -4154,9 +4231,19 @@ Module Compilers.
: value var t -> @expr var t
:= match t return value var t -> expr t with
| type.prod A B as t
- => fun x : expr t + value var A * value var B
+ => fun x : (data A * data B) * expr t + value var A * value var B
=> match x with
- | inl v => v
+ | inl ((da, db), v)
+ => match A, B return data A -> data B -> expr (A * B) -> expr (A * B) with
+ | type.Z, type.Z
+ => fun da db v
+ => match da, db with
+ | Some r1, Some r2
+ => (ident.Z.cast2 (r1, r2)%core @@ v)%expr
+ | _, _ => v
+ end
+ | _, _ => fun _ _ v => v
+ end da db v
| inr (a, b) => (@reify A a, @reify B b)%expr
end
| type.arrow s d
@@ -4164,24 +4251,24 @@ Module Compilers.
=> Abs (fun x
=> @reify d (f (@reflect s (Var x))))
| type.list A as t
- => fun x : expr t + list (value var A)
+ => fun x : _ * expr t + list (value var A)
=> match x with
- | inl v => v
+ | inl (_, v) => v
| inr v => reify_list (List.map (@reify A) v)
end
| type.type_primitive type.Z as t
- => fun x : Zdata * expr t + type.interp t
+ => fun x : _ * expr t + type.interp t
=> match x with
| inl (Some r, v) => ident.Z.cast r @@ v
| inl (None, v) => v
| inr v => ident.primitive v @@ TT
end%core%expr
| type.type_primitive _ as t
- => fun x : expr t + type.interp t
+ => fun x : _ * expr t + type.interp t
=> match x with
- | inl v => v
+ | inl (_, v) => v
| inr v => ident.primitive v @@ TT
- end%expr
+ end%core%expr
end
with reflect {t : type}
: @expr var t -> value var t
@@ -4191,28 +4278,39 @@ Module Compilers.
=> @reflect d (App f (@reify s x))
| type.prod A B as t
=> fun v : expr t
- => let inr := @inr (expr t) (value_prestep (value var) t) in
- let inl := @inl (expr t) (value_prestep (value var) t) in
+ => let inr := @inr (data t * expr t) (value_prestep (value var) t) in
+ let inl := @inl (data t * expr t) (value_prestep (value var) t) in
match invert_Pair v with
| Some (a, b)
=> inr (@reflect A a, @reflect B b)
| None
- => inl v
+ => inl
+ (match A, B return expr (A * B) -> data (A * B) * expr (A * B) with
+ | type.Z, type.Z
+ => fun v
+ => match invert_Z_cast2 v with
+ | Some (r, v)
+ => (ZRange.type.option.Some (t:=type.Z*type.Z) r, v)
+ | None
+ => (ZRange.type.option.None, v)
+ end
+ | _, _ => fun v => (ZRange.type.option.None, v)
+ end v)
end
| type.list A as t
=> fun v : expr t
- => let inr := @inr (expr t) (value_prestep (value var) t) in
- let inl := @inl (expr t) (value_prestep (value var) t) in
+ => let inr := @inr (data t * expr t) (value_prestep (value var) t) in
+ let inl := @inl (data t * expr t) (value_prestep (value var) t) in
match reflect_list v with
| Some ls
=> inr (List.map (@reflect A) ls)
| None
- => inl v
+ => inl (None, v)
end
| type.type_primitive type.Z as t
=> fun v : expr t
- => let inr' := @inr (Zdata * expr t) (value_prestep (value var) t) in
- let inl' := @inl (Zdata * expr t) (value_prestep (value var) t) in
+ => let inr' := @inr (data t * expr t) (value_prestep (value var) t) in
+ let inl' := @inl (data t * expr t) (value_prestep (value var) t) in
match reflect_primitive v, invert_Z_cast v with
| Some v, _ => inr' v
| None, Some (r, v) => inl' (Some r, v)
@@ -4220,11 +4318,11 @@ Module Compilers.
end
| type.type_primitive _ as t
=> fun v : expr t
- => let inr := @inr (expr t) (value_prestep (value var) t) in
- let inl := @inl (expr t) (value_prestep (value var) t) in
+ => let inr := @inr (data t * expr t) (value_prestep (value var) t) in
+ let inl := @inl (data t * expr t) (value_prestep (value var) t) in
match reflect_primitive v with
| Some v => inr v
- | None => inl v
+ | None => inl (tt, v)
end
end.
End reify.
@@ -4239,6 +4337,8 @@ Module Compilers.
| TT => true
| AppIdent _ _ (ident.fst _ _) args => @is_var_like _ args
| AppIdent _ _ (ident.snd _ _) args => @is_var_like _ args
+ | AppIdent _ _ (ident.Z.cast _) args => @is_var_like _ args
+ | AppIdent _ _ (ident.Z.cast2 _) args => @is_var_like _ args
| Pair A B a b => @is_var_like A a && @is_var_like B b
| AppIdent _ _ _ _ => false
| App _ _ _ _
@@ -4252,7 +4352,7 @@ Module Compilers.
| type.arrow _ _
=> fun x f => f x
| type.list T as t
- => fun (x : expr t + list (value var T)) (f : expr t + list (value var T) -> value var tC)
+ => fun (x : data t * expr t + list (value var T)) (f : data t * expr t + list (value var T) -> value var tC)
=> match x with
| inr ls
=> list_rect
@@ -4267,7 +4367,7 @@ Module Compilers.
| inl e => f (inl e)
end
| type.prod A B as t
- => fun (x : expr t + value var A * value var B) (f : expr t + value var A * value var B -> value var tC)
+ => fun (x : data t * expr t + value var A * value var B) (f : data t * expr t + value var A * value var B -> value var tC)
=> match x with
| inr (a, b)
=> @interp_let_in
@@ -4276,11 +4376,12 @@ Module Compilers.
=> @interp_let_in
_ B b
(fun b => f (inr (a, b))))
- | inl e => partial.expr.reflect (expr_let y := e in partial.expr.reify (f (inl (Var y))))%expr
+ | inl (data, e) => partial.expr.reflect
+ (expr_let y := partial.expr.reify (t:=t) x in
+ partial.expr.reify (f (inl (data, Var y)%core)))%expr
end
- | type.type_primitive type.Z as t
- => fun (x : Zdata * expr t + type.interp t)
- (f : Zdata * expr t + type.interp t -> value var tC)
+ | type.type_primitive _ as t
+ => fun (x : data t * expr t + type.interp t) (f : data t * expr t + type.interp t -> value var tC)
=> match x with
| inl (data, e)
=> if is_var_like e
@@ -4288,74 +4389,73 @@ Module Compilers.
else partial.expr.reflect
(expr_let y := (partial.expr.reify (t:=t) x) in
partial.expr.reify (f (inl (data, Var y)%core)))%expr
- | inr v => f (inr v)
- end
- | type.type_primitive _ as t
- => fun (x : expr t + type.interp t) (f : expr t + type.interp t -> value var tC)
- => match x with
- | inl e
- => match invert_Var e with
- | Some _ => f x
- | None => partial.expr.reflect
- (expr_let y := (partial.expr.reify (t:=t) x) in
- partial.expr.reify (f (inl (Var y))))%expr
- end
| inr v => f (inr v) (* FIXME: do not substitute [S (big stuck term)] *)
end
end.
+ Let default_interp
+ {s d}
+ : ident s d -> value var s -> value var d
+ := match d return ident s d -> value var s -> value var d with
+ | type.arrow _ _
+ => fun idc args => expr.reflect (AppIdent idc (expr.reify args))
+ | _
+ => fun idc args
+ => inl (ZRange.ident.option.interp idc (data_from_value var args),
+ AppIdent idc (expr.reify args))
+ end.
+
(** do partial reduction on identifiers *)
Definition interp {s d} (idc : ident s d) : value var (s -> d)
:= match idc in ident s d return value var (s -> d) with
| ident.Let_In tx tC as idc
- => fun (xf : expr (tx * (tx -> tC)) + value var tx * value var (tx -> tC))
+ => fun (xf : data (tx * (tx -> tC)) * expr (tx * (tx -> tC)) + value var tx * value var (tx -> tC))
=> match xf with
| inr (x, f) => interp_let_in x f
| _ => expr.reflect (AppIdent idc (expr.reify (t:=tx * (tx -> tC)) xf))
end
| ident.nil t
=> fun _ => inr (@nil (value var t))
- | ident.primitive type.Z v
- => fun _ => inr v
| ident.primitive t v
=> fun _ => inr v
| ident.cons t as idc
- => fun (x_xs : expr (t * type.list t) + value var t * (expr (type.list t) + list (value var t)))
- => match x_xs return expr (type.list t) + list (value var t) with
+ => fun (x_xs : data (t * type.list t) * expr (t * type.list t) + value var t * (data (type.list t) * expr (type.list t) + list (value var t)))
+ => match x_xs return data (type.list t) * expr (type.list t) + list (value var t) with
| inr (x, inr xs) => inr (cons x xs)
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=t * type.list t) x_xs))
+ | _
+ => default_interp idc x_xs
end
| ident.fst A B as idc
- => fun x : expr (A * B) + value var A * value var B
+ => fun x : data (A * B) * expr (A * B) + value var A * value var B
=> match x with
| inr x => fst x
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=A*B) x))
+ | _ => default_interp idc x
end
| ident.snd A B as idc
- => fun x : expr (A * B) + value var A * value var B
+ => fun x : data (A * B) * expr (A * B) + value var A * value var B
=> match x with
| inr x => snd x
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=A*B) x))
+ | _ => default_interp idc x
end
| ident.bool_rect T as idc
- => fun (true_case_false_case_b : expr (T * T * type.bool) + (expr (T * T) + value var T * value var T) * (expr type.bool + bool))
+ => fun (true_case_false_case_b : data (T * T * type.bool) * expr (T * T * type.bool) + (data (T * T) * expr (T * T) + value var T * value var T) * (data type.bool * expr type.bool + bool))
=> match true_case_false_case_b with
| inr (inr (true_case, false_case), inr b)
=> @bool_rect (fun _ => value var T) true_case false_case b
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=T*T*type.bool) true_case_false_case_b))
+ | _ => default_interp idc true_case_false_case_b
end
| ident.nat_rect P as idc
- => fun (O_case_S_case_n : expr (P * (type.nat * P -> P) * type.nat) + (expr (P * (type.nat * P -> P)) + value var P * value var (type.nat * P -> P)) * (expr type.nat + nat))
+ => fun (O_case_S_case_n : _ * expr (P * (type.nat * P -> P) * type.nat) + (_ * expr (P * (type.nat * P -> P)) + value var P * value var (type.nat * P -> P)) * (_ * expr type.nat + nat))
=> match O_case_S_case_n with
| inr (inr (O_case, S_case), inr n)
=> @nat_rect (fun _ => value var P)
O_case
(fun n' rec => S_case (inr (inr n', rec)))
n
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=P * (type.nat * P -> P) * type.nat) O_case_S_case_n))
+ | _ => default_interp idc O_case_S_case_n
end
| ident.list_rect A P as idc
- => fun (nil_case_cons_case_ls : expr (P * (A * type.list A * P -> P) * type.list A) + (expr (P * (A * type.list A * P -> P)) + value var P * value var (A * type.list A * P -> P)) * (expr (type.list A) + list (value var A)))
+ => fun (nil_case_cons_case_ls : _ * expr (P * (A * type.list A * P -> P) * type.list A) + (_ * expr (P * (A * type.list A * P -> P)) + value var P * value var (A * type.list A * P -> P)) * (_ * expr (type.list A) + list (value var A)))
=> match nil_case_cons_case_ls with
| inr (inr (nil_case, cons_case), inr ls)
=> @list_rect
@@ -4364,60 +4464,60 @@ Module Compilers.
nil_case
(fun x xs rec => cons_case (inr (inr (x, inr xs), rec)))
ls
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=P * (A * type.list A * P -> P) * type.list A) nil_case_cons_case_ls))
+ | _ => default_interp idc nil_case_cons_case_ls
end
| ident.List.nth_default type.Z as idc
- => fun (default_ls_idx : expr (type.Z * type.list type.Z * type.nat) + (expr (type.Z * type.list type.Z) + (_ * expr type.Z + type.interp type.Z) * (expr (type.list type.Z) + list (value var type.Z))) * (expr type.nat + nat))
+ => fun (default_ls_idx : _ * expr (type.Z * type.list type.Z * type.nat) + (_ * expr (type.Z * type.list type.Z) + (_ * expr type.Z + type.interp type.Z) * (_ * expr (type.list type.Z) + list (value var type.Z))) * (_ * expr type.nat + nat))
=> match default_ls_idx with
| inr (inr (default, inr ls), inr idx)
=> List.nth_default default ls idx
| inr (inr (inr default, ls), inr idx)
- => expr.reflect (AppIdent (ident.List.nth_default_concrete default idx) (expr.reify (t:=type.list type.Z) ls))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=type.Z * type.list type.Z * type.nat) default_ls_idx))
+ => default_interp (ident.List.nth_default_concrete default idx) ls
+ | _ => default_interp idc default_ls_idx
end
| ident.List.nth_default (type.type_primitive A) as idc
- => fun (default_ls_idx : expr (A * type.list A * type.nat) + (expr (A * type.list A) + (expr A + type.interp A) * (expr (type.list A) + list (value var A))) * (expr type.nat + nat))
+ => fun (default_ls_idx : _ * expr (A * type.list A * type.nat) + (_ * expr (A * type.list A) + (_ * expr A + type.interp A) * (_ * expr (type.list A) + list (value var A))) * (_ * expr type.nat + nat))
=> match default_ls_idx with
| inr (inr (default, inr ls), inr idx)
=> List.nth_default default ls idx
| inr (inr (inr default, ls), inr idx)
- => expr.reflect (AppIdent (ident.List.nth_default_concrete default idx) (expr.reify (t:=type.list A) ls))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=A * type.list A * type.nat) default_ls_idx))
+ => default_interp (ident.List.nth_default_concrete default idx) ls
+ | _ => default_interp idc default_ls_idx
end
| ident.List.nth_default A as idc
- => fun (default_ls_idx : expr (A * type.list A * type.nat) + (expr (A * type.list A) + value var A * (expr (type.list A) + list (value var A))) * (expr type.nat + nat))
+ => fun (default_ls_idx : _ * expr (A * type.list A * type.nat) + (_ * expr (A * type.list A) + value var A * (_ * expr (type.list A) + list (value var A))) * (_ * expr type.nat + nat))
=> match default_ls_idx with
| inr (inr (default, inr ls), inr idx)
=> List.nth_default default ls idx
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=A * type.list A * type.nat) default_ls_idx))
+ | _ => default_interp idc default_ls_idx
end
| ident.List.nth_default_concrete A default idx as idc
- => fun (ls : expr (type.list A) + list (value var A))
+ => fun (ls : _ * expr (type.list A) + list (value var A))
=> match ls with
| inr ls
=> List.nth_default (expr.reflect (t:=A) (AppIdent (ident.primitive default) TT)) ls idx
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=type.list A) ls))
+ | _ => default_interp idc ls
end
| ident.Z_mul_split as idc
- => fun (x_y_z : (expr (type.Z * type.Z * type.Z) +
- (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
- => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
+ => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) +
+ (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
+ => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
| inr (inr (inr x, inr y), inr z) =>
let result := ident.interp idc (x, y, z) in
inr (inr (fst result), inr (snd result))
| inr (inr (inr x, y), z)
- => expr.reflect (AppIdent (ident.Z.mul_split_concrete x) (expr.reify (t:=type.Z*type.Z) (inr (y, z))))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z))
+ => default_interp (ident.Z.mul_split_concrete x) (inr (y, z))
+ | _ => default_interp idc x_y_z
end
| ident.Z_add_get_carry as idc
- => fun (x_y_z : (expr (type.Z * type.Z * type.Z) +
- (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
- => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
+ => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) +
+ (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
+ => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
| inr (inr (inr x, inr y), inr z) =>
let result := ident.interp idc (x, y, z) in
inr (inr (fst result), inr (snd result))
| inr (inr (inr x, y), z)
- => let default := expr.reflect (AppIdent (ident.Z.add_get_carry_concrete x) (expr.reify (t:=type.Z*type.Z) (inr (y, z)))) in
+ => let default := default_interp (ident.Z.add_get_carry_concrete x) (inr (y, z)) in
match (y, z) with
| (inr xx, inl e)
| (inl e, inr xx)
@@ -4426,45 +4526,45 @@ Module Compilers.
else default
| _ => default
end
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z))
+ | _ => default_interp idc x_y_z
end
| ident.Z_add_with_get_carry as idc
- => fun (x_y_z_a : (expr (_ * _ * _ * _) +
- (expr (_ * _ * _) +
- (expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) *
+ => fun (x_y_z_a : (_ * expr (_ * _ * _ * _) +
+ (_ * expr (_ * _ * _) +
+ (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) *
(_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type)
- => match x_y_z_a return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
+ => match x_y_z_a return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
| inr (inr (inr (inr x, inr y), inr z), inr a) =>
let result := ident.interp idc (x, y, z, a) in
inr (inr (fst result), inr (snd result))
| inr (inr (inr (inr x, y), z), a)
- => expr.reflect (AppIdent (ident.Z.add_with_get_carry_concrete x) (expr.reify (t:=type.Z*type.Z*type.Z) (inr (inr (y, z), a))))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_*_) x_y_z_a))
+ => default_interp (ident.Z.add_with_get_carry_concrete x) (inr (inr (y, z), a))
+ | _ => default_interp idc x_y_z_a
end
| ident.Z_sub_get_borrow as idc
- => fun (x_y_z : (expr (type.Z * type.Z * type.Z) +
- (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
- => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
+ => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) +
+ (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
+ => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
| inr (inr (inr x, inr y), inr z) =>
let result := ident.interp idc (x, y, z) in
inr (inr (fst result), inr (snd result))
| inr (inr (inr x, y), z)
- => expr.reflect (AppIdent (ident.Z.sub_get_borrow_concrete x) (expr.reify (t:=type.Z*type.Z) (inr (y, z))))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z))
+ => default_interp (ident.Z.sub_get_borrow_concrete x) (inr (y, z))
+ | _ => default_interp idc x_y_z
end
| ident.Z_mul_split_concrete _ as idc
| ident.Z.sub_get_borrow_concrete _ as idc
- => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
- => match x_y return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
+ => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
| inr (inr x, inr y) =>
let result := ident.interp idc (x, y) in
inr (inr (fst result), inr (snd result))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y))
+ | _ => default_interp idc x_y
end
| ident.Z.add_get_carry_concrete _ as idc
- => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
- => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in
- match x_y return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
+ => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => let default := default_interp idc x_y in
+ match x_y return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
| inr (inr x, inr y) =>
let result := ident.interp idc (x, y) in
inr (inr (fst result), inr (snd result))
@@ -4476,26 +4576,26 @@ Module Compilers.
| _ => default
end
| ident.Z.add_with_get_carry_concrete _ as idc
- => fun (x_y_z : (expr (type.Z * type.Z * type.Z) +
- (expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
- => match x_y_z return (expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
+ => fun (x_y_z : (_ * expr (type.Z * type.Z * type.Z) +
+ (_ * expr (type.Z * type.Z) + (_ * expr type.Z + Z) * (_ * expr type.Z + Z)) * (_ * expr type.Z + Z))%type)
+ => match x_y_z return (_ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) with
| inr (inr (inr x, inr y), inr z) =>
let result := ident.interp idc (x, y, z) in
inr (inr (fst result), inr (snd result))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z))
+ | _ => default_interp idc x_y_z
end
| ident.pred as idc
| ident.Nat_succ as idc
- => fun x : expr _ + type.interp _
- => match x return expr _ + type.interp _ with
+ => fun x : _ * expr _ + type.interp _
+ => match x return _ * expr _ + type.interp _ with
| inr x => inr (ident.interp idc x)
- | inl x => expr.reflect (AppIdent idc x)
+ | _ => default_interp idc x
end
| ident.Z_of_nat as idc
- => fun x : expr _ + type.interp _
+ => fun x : _ * expr _ + type.interp _
=> match x return _ * expr _ + type.interp _ with
| inr x => inr (ident.interp idc x)
- | inl x => expr.reflect (AppIdent idc x)
+ | _ => default_interp idc x
end
| ident.Z_opp as idc
=> fun x : _ * expr _ + type.interp _
@@ -4522,36 +4622,36 @@ Module Compilers.
| ident.Z_eqb as idc
| ident.Z_leb as idc
| ident.Z_pow as idc
- => fun (x_y : expr (_ * _) + (_ + type.interp _) * (_ + type.interp _))
+ => fun (x_y : data (_ * _) * expr (_ * _) + (_ + type.interp _) * (_ + type.interp _))
=> match x_y return _ + type.interp _ with
| inr (inr x, inr y) => inr (ident.interp idc (x, y))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y))
+ | _ => default_interp idc x_y
end
| ident.Z_div as idc
- => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
- => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in
+ => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => let default := default_interp idc x_y in
match x_y return _ * expr _ + type.interp _ with
| inr (inr x, inr y) => inr (ident.interp idc (x, y))
| inr (x, inr y)
=> if Z.eqb y (2^Z.log2 y)
- then expr.reflect (AppIdent (ident.Z.shiftr (Z.log2 y)) (expr.reify (t:=type.Z) x))
+ then default_interp (ident.Z.shiftr (Z.log2 y)) x
else default
| _ => default
end
| ident.Z_modulo as idc
- => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
- => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in
+ => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => let default := default_interp idc x_y in
match x_y return _ * expr _ + type.interp _ with
| inr (inr x, inr y) => inr (ident.interp idc (x, y))
| inr (x, inr y)
=> if Z.eqb y (2^Z.log2 y)
- then expr.reflect (AppIdent (ident.Z.land (y-1)) (expr.reify (t:=type.Z) x))
+ then default_interp (ident.Z.land (y-1)) x
else default
| _ => default
end
| ident.Z_mul as idc
- => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
- => let default := expr.reflect (AppIdent idc (expr.reify (t:=_*_) x_y)) in
+ => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => let default := default_interp idc x_y in
match x_y return _ * expr _ + type.interp _ with
| inr (inr x, inr y) => inr (ident.interp idc (x, y))
| inr (inr x, inl (data, e) as y)
@@ -4574,7 +4674,7 @@ Module Compilers.
| inl _ => default
end
| ident.Z_add as idc
- => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
=> let default0 := AppIdent idc (expr.reify (t:=_*_) x_y) in
let default := expr.reflect default0 in
match x_y return _ * expr _ + type.interp _ with
@@ -4608,7 +4708,7 @@ Module Compilers.
| inl _ => default
end
| ident.Z_sub as idc
- => fun (x_y : expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => fun (x_y : _ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
=> let default0 := AppIdent idc (expr.reify (t:=_*_) x_y) in
let default := expr.reflect default0 in
match x_y return _ * expr _ + type.interp _ with
@@ -4641,11 +4741,11 @@ Module Compilers.
end
| ident.Z_zselect as idc
| ident.Z_add_modulo as idc
- => fun (x_y_z : (expr (_ * _ * _) +
- (expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type)
+ => fun (x_y_z : (_ * expr (_ * _ * _) +
+ (_ * expr (_ * _) + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _)) * (_ * expr _ + type.interp _))%type)
=> match x_y_z return _ * expr _ + type.interp _ with
| inr (inr (inr x, inr y), inr z) => inr (ident.interp idc (x, y, z))
- | _ => expr.reflect (AppIdent idc (expr.reify (t:=_*_*_) x_y_z))
+ | _ => default_interp idc x_y_z
end
| ident.Z_cast r as idc
=> fun (x : _ * expr _ + type.interp _)
@@ -4654,6 +4754,24 @@ Module Compilers.
| inl (data, e)
=> inl (ZRange.ident.option.interp idc data, e)
end
+ | ident.Z_cast2 (r1, r2) as idc
+ => fun (x : _ * expr _ + (_ * expr _ + type.interp _) * (_ * expr _ + type.interp _))
+ => match x with
+ | inr (inr a, inr b)
+ => inr (inr (ident.interp (ident.Z.cast r1) a),
+ inr (ident.interp (ident.Z.cast r2) b))
+ | inr (inr a, inl (r2', b))
+ => inr (inr (ident.interp (ident.Z.cast r1) a),
+ inl (ZRange.ident.option.interp (ident.Z.cast r2) r2', b))
+ | inr (inl (r1', a), inr b)
+ => inr (inl (ZRange.ident.option.interp (ident.Z.cast r1) r1', a),
+ inr (ident.interp (ident.Z.cast r2) b))
+ | inr (inl (r1', a), inl (r2', b))
+ => inr (inl (ZRange.ident.option.interp (ident.Z.cast r1) r1', a),
+ inl (ZRange.ident.option.interp (ident.Z.cast r2) r2', b))
+ | inl (data, e)
+ => inl (ZRange.ident.option.interp idc data, e)
+ end
end.
End interp.
End ident.
@@ -4712,37 +4830,74 @@ Module Compilers.
| type.type_primitive t => fun _ => id
| type.prod A B
=> fun '((ra, rb) : ZRange.type.option.interp A * ZRange.type.option.interp B)
- (e : expr _ + partial.value var A * partial.value var B)
+ (e : _ * expr _ + partial.value var A * partial.value var B)
=> match e with
| inr (a, b)
=> inr (@extend_with_obounds A ra a,
@extend_with_obounds B rb b)
- | inl e
+ | inl ((dataa, datab), e)
=> if partial.ident.is_var_like e
then inr (@extend_with_obounds A ra (partial.expr.reflect (AppIdent ident.fst e)),
@extend_with_obounds B rb (partial.expr.reflect (AppIdent ident.snd e)))
- else inl e
+ else inl
+ (match A, B return ZRange.type.option.interp A -> ZRange.type.option.interp B -> data A -> data B -> expr (A * B) -> data (A * B) * expr (A * B) with
+ | type.Z, type.Z
+ => fun ra rb da db e
+ => let da'
+ := match ra with
+ | Some ra
+ => ZRange.ident.option.interp
+ (ident.Z.cast ra) da
+ | None => da
+ end in
+ let db'
+ := match rb with
+ | Some rb
+ => ZRange.ident.option.interp
+ (ident.Z.cast rb) db
+ | None => db
+ end in
+ ((da', db'), e)
+ | _, _
+ => fun _ _ da db e => ((da, db), e)
+ end ra rb dataa datab e)
end
| type.arrow s d => fun _ => id
| type.list A
- => fun (ls : Datatypes.list (ZRange.type.option.interp A))
- (e : expr _ + list (partial.value var A))
- => match e with
- | inl e
- => match A return (ZRange.type.option.interp A -> partial.value var A -> partial.value var A)
- -> Datatypes.list (ZRange.type.option.interp A)
- -> expr (type.list A)
- -> partial.value var (type.list A)
- with
- | type.type_primitive A
- => fun extend_with_obounds ls e
- => inr (extend_list_expr_with_obounds
- extend_with_obounds 0 ls e)
- | A'
- => fun _ _ e => inl e
- end (@extend_with_obounds A) ls e
- | inr e => inr (extend_concrete_list_with_obounds
- (@extend_with_obounds A) ls e)
+ => fun (ls : option (Datatypes.list (ZRange.type.option.interp A)))
+ (e : data _ * expr _ + list (partial.value var A))
+ => match ls with
+ | None => e
+ | Some ls
+ =>
+ match e with
+ | inl (data, e)
+ => match A return (ZRange.type.option.interp A -> partial.value var A -> partial.value var A)
+ -> Datatypes.list (ZRange.type.option.interp A)
+ -> option (Datatypes.list (ZRange.type.option.interp A))
+ -> expr (type.list A)
+ -> partial.value var (type.list A)
+ with
+ | type.type_primitive A
+ => fun extend_with_obounds ls data e
+ => match data with
+ | Some data
+ => inr
+ (extend_concrete_list_with_obounds
+ extend_with_obounds ls
+ (extend_list_expr_with_obounds
+ extend_with_obounds 0 data e))
+ | None
+ => inr (extend_list_expr_with_obounds
+ extend_with_obounds 0 ls e)
+ end
+ | A'
+ (* N.B. We clobber the existing bounds here, rather than fusing them *)
+ => fun _ ls data e => inl (Some ls, e)
+ end (@extend_with_obounds A) ls data e
+ | inr e => inr (extend_concrete_list_with_obounds
+ (@extend_with_obounds A) ls e)
+ end
end
end.
Definition extend_with_bounds {t}
@@ -4761,10 +4916,10 @@ Module Compilers.
| ident.Z_cast range => fun _ => Some range
| ident.primitive type.Z v
=> fun _ => Some r[v~>v]%zrange
- | ident.nil _ => fun _ => nil
+ | ident.nil _ => fun _ => Some nil
| ident.cons t
- => fun '((x, xs) : ZRange.type.option.interp t * list (ZRange.type.option.interp t))
- => cons x xs
+ => fun '((x, xs) : ZRange.type.option.interp t * option (list (ZRange.type.option.interp t)))
+ => option_map (cons x) xs
| _ => fun _ => ZRange.type.option.None
end.
End ident.
@@ -4798,7 +4953,10 @@ Module Compilers.
Section partial_reduce.
Context {var : type -> Type}.
- Fixpoint partial_reduce' {t} (e : @expr (partial.value var) t)
+ Definition partial_reduce'_step
+ (partial_reduce' : forall {t} (e : @expr (partial.value var) t),
+ partial.value var t)
+ {t} (e : @expr (partial.value var) t)
: partial.value var t
:= match e in expr.expr t return partial.value var t with
| Var t v => v
@@ -4808,6 +4966,9 @@ Module Compilers.
| App s d f x => @partial_reduce' _ f (@partial_reduce' _ x)
| Abs s d f => fun x => @partial_reduce' d (f x)
end.
+ Fixpoint partial_reduce' {t} (e : @expr (partial.value var) t)
+ : partial.value var t
+ := @partial_reduce'_step (@partial_reduce') t e.
Definition partial_reduce {t} (e : @expr (partial.value var) t) : @expr var t
:= partial.expr.reify (@partial_reduce' t e).
@@ -4840,6 +5001,12 @@ Module Compilers.
| Some r => AppIdent (ident.Z.cast r)
| None => id
end
+ | ident.Z_cast2 (r1, r2)
+ => match relax_zrange r1, relax_zrange r2 with
+ | Some r1, Some r2
+ => AppIdent (ident.Z.cast2 (r1, r2))
+ | Some _, None | None, Some _ | None, None => id
+ end
| idc => AppIdent idc
end.
End relax.
@@ -6293,7 +6460,7 @@ Ltac solve_rone := solve_rop rone_correct.
Module PrintingNotations.
Export ident.
- Global Set Printing Width 100000.
+ (*Global Set Printing Width 100000.*)
Open Scope zrange_scope.
Notation "'uint256'"
:= (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : zrange_scope.
@@ -6303,6 +6470,8 @@ Module PrintingNotations.
:= (r[0 ~> 18446744073709551615]) : zrange_scope.
Notation "'uint32'"
:= (r[0 ~> 4294967295]) : zrange_scope.
+ Notation "'bool'"
+ := (r[0 ~> 1]%zrange) : zrange_scope.
Notation "ls [[ n ]]"
:= ((List.nth_default_concrete _ n @@ ls)%expr)
(at level 30, format "ls [[ n ]]") : expr_scope.
@@ -6310,12 +6479,16 @@ Module PrintingNotations.
:= ((ident.Z.cast range @@ (List.nth_default_concrete _ n @@ ls))%expr)
(format "( range )( ls [[ n ]] )") : expr_scope.
(*Notation "( range )( v )" := (ident.Z.cast range @@ v)%expr : expr_scope.*)
+ Notation "x *₂₅₆ y"
+ := (ident.Z.cast uint256 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope.
Notation "x *₁₂₈ y"
:= (ident.Z.cast uint128 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope.
Notation "x *₆₄ y"
:= (ident.Z.cast uint64 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope.
Notation "x *₃₂ y"
:= (ident.Z.cast uint32 @@ (ident.Z.mul @@ (x, y)))%expr (at level 40) : expr_scope.
+ Notation "x +₂₅₆ y"
+ := (ident.Z.cast uint256 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope.
Notation "x +₁₂₈ y"
:= (ident.Z.cast uint128 @@ (ident.Z.add @@ (x, y)))%expr (at level 50) : expr_scope.
Notation "x +₆₄ y"
@@ -6341,8 +6514,17 @@ Module PrintingNotations.
:= ((ident.Z.cast out_t @@ (ident.Z.land mask @@ v))%expr)
(format "( ( out_t )( v ) & mask )")
: expr_scope.
- Notation "v ₁" := (ident.fst @@ v)%expr (at level 10, format "v ₁") : expr_scope.
- Notation "v ₂" := (ident.snd @@ v)%expr (at level 10, format "v ₂") : expr_scope.
+
+ Notation "x" := (ident.Z.cast _ @@ Var x)%expr (only printing, at level 9) : expr_scope.
+ Notation "x" := (ident.Z.cast2 _ @@ Var x)%expr (only printing, at level 9) : expr_scope.
+ Notation "v ₁" := (ident.fst @@ Var v)%expr (at level 10, format "v ₁") : expr_scope.
+ Notation "v ₂" := (ident.snd @@ Var v)%expr (at level 10, format "v ₂") : expr_scope.
+ Notation "v ₁" := (ident.Z.cast _ @@ (ident.fst @@ Var v))%expr (at level 10, format "v ₁") : expr_scope.
+ Notation "v ₂" := (ident.Z.cast _ @@ (ident.snd @@ Var v))%expr (at level 10, format "v ₂") : expr_scope.
+ Notation "v ₁" := (ident.Z.cast _ @@ (ident.fst @@ (ident.Z.cast2 _ @@ Var v)))%expr (at level 10, format "v ₁") : expr_scope.
+ Notation "v ₂" := (ident.Z.cast _ @@ (ident.snd @@ (ident.Z.cast2 _ @@ Var v)))%expr (at level 10, format "v ₂") : expr_scope.
+
+
(*Notation "ls [[ n ]]" := (List.nth_default_concrete _ n @@ ls)%expr : expr_scope.
Notation "( range )( v )" := (ident.Z.cast range @@ v)%expr : expr_scope.
Notation "x *₁₂₈ y"
@@ -6408,13 +6590,12 @@ Module PrintingNotations.
(* TODO: come up with a better notation for arithmetic with carries
that still distinguishes it from arithmetic without carries? *)
Local Notation "'TwoPow256'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 (only parsing).
- (*Notation "'ADD_256'" := (add_get_carry_concrete _ _ uint256 _ TwoPow256) : nexpr_scope.
- Notation "'ADD_128'" := (add_get_carry_concrete _ _ uint128 _ TwoPow256) : nexpr_scope.
- Notation "'ADDC_256'" := (add_with_get_carry_concrete _ _ _ uint256 _ TwoPow256) : nexpr_scope.
- Notation "'SUB_256'" := (sub_get_borrow_concrete _ _ uint256 _ TwoPow256) : nexpr_scope.
- Notation "'ADDM'" := (add_modulo _ _ _ uint256) : nexpr_scope.
- Notation "'SELC'" := (zselect _ _ _ uint256) : nexpr_scope.
- Notation "'MUL_256'" := (mul uint128 uint128 uint256) : nexpr_scope.*)
+ Notation "'ADD_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_get_carry_concrete TwoPow256 @@ (x, y)))%expr : expr_scope.
+ Notation "'ADD_128' ( x , y )" := (ident.Z.cast2 (uint128, bool)%core @@ (ident.Z.add_get_carry_concrete TwoPow256 @@ (x, y)))%expr : expr_scope.
+ Notation "'ADDC_256' ( x , y , z )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.add_with_get_carry_concrete TwoPow256 @@ (x, y, z)))%expr : expr_scope.
+ Notation "'SUB_256' ( x , y )" := (ident.Z.cast2 (uint256, bool)%core @@ (ident.Z.sub_get_borrow_concrete TwoPow256 @@ (x, y)))%expr : expr_scope.
+ Notation "'ADDM' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.add_modulo @@ (x, y, z)))%expr : expr_scope.
+ Notation "'SELC' ( x , y , z )" := (ident.Z.cast uint256 @@ (ident.Z.zselect @@ (x, y, z)))%expr : expr_scope.
End PrintingNotations.
(*
@@ -7003,6 +7184,9 @@ Module MontgomeryReduction.
= montred' N R N' w w_half n lo_hi)
As montred_gen_correct.
Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed.
+ Module Export ReifyHints.
+ Global Hint Extern 1 (_ = montred' _ _ _ _ _ _ _) => simple apply montred_gen_correct : reify_gen_cache.
+ End ReifyHints.
Section rmontred.
Context (N R N' : Z)
@@ -7012,7 +7196,7 @@ Module MontgomeryReduction.
Let bound := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange.
Definition relax_zrange_of_machine_wordsize
- := relax_zrange_gen [machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize; 4 * machine_wordsize]%Z.
+ := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize; 4 * machine_wordsize]%Z.
Local Arguments relax_zrange_of_machine_wordsize / .
Let rw := rweight machine_wordsize.
@@ -7033,7 +7217,7 @@ Module MontgomeryReduction.
else res.
Notation BoundsPipeline_correct in_bounds out_bounds op
- := (fun (rop : Expr (type.reify_type_of op%function)) rv Hrop
+ := (fun rv (rop : Expr (type.reify_type_of op%function)) E Hrop HE
=> @Pipeline.BoundsPipeline_correct_trans
true (* DCE *)
relax_zrange
@@ -7043,7 +7227,7 @@ Module MontgomeryReduction.
in_bounds
out_bounds
op
- Hrop rv)
+ Hrop rv E HE)
(only parsing).
Definition rmontred_correct
@@ -7052,9 +7236,9 @@ Module MontgomeryReduction.
bound
(montred' N R N' (Interp rw) (Interp rw_half) 2).
- Notation type_of_strip_2arrow := ((fun s s' (d : Prop) (_ : s -> s' -> d) => d) _ _ _).
+ Notation type_of_strip_5arrow := ((fun (d : Prop) (_ : forall A B C D E, d) => d) _).
Definition rmontred_correctT rv : Prop
- := exists rop, type_of_strip_2arrow (@rmontred_correct rop rv).
+ := type_of_strip_5arrow (@rmontred_correct rv).
End rmontred.
End MontgomeryReduction.
@@ -7071,217 +7255,166 @@ Module Montgomery256.
Derive montred256
SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256)
As montred256_correct.
- Proof.
- eexists; eapply MontgomeryReduction.rmontred_correct with (machine_wordsize:=machine_wordsize).
- Time do_inline_cache_reify ltac:(fun _ => idtac).
- cbv [Pipeline.BoundsPipeline].
- set (k := PartialReduce _).
- cbv [CheckedPartialReduceWithBounds1].
- Time set (k' := PartialReduceWithBounds1 _ _).
- Time lazy in k'.
- Print fold_left.
- Time lazy; reflexivity.
- Time lazy -[Let_In k'].
- | (* Doing [lazy] is twice as slow as doing [lazy -[Let_In]; lazy].
- This is because the bounds pipeline does [dlet E : Expr := (reduced
- thing) in let b := extract_bounds E in ...]. If we allow [lazy] to
- unfold [Let_In] before it fully reduces the function (function,
- because [Expr := forall var, @expr var]), then there is no sharing
- between the partial reduction in bounds extraction and the partial
- reduction in the return value. So we force [lazy] to fully reduce
- the argument first, and only then permit [lazy] to inline it. This
- is slightly slower than doing bounds analysis in a non-PHOAS
- representation; we spend about 3%-5% of the overall time doing
- bounds extraction, and fully reducing the bounds extraction
- expression before plugging in arguments costs a bit more. However,
- it's still reasonably fast, and the code is much simpler when
- [Interp] always succeeds rather than returning [option]. *)
- lazy -[Let_In]; lazy; reflexivity ].
-
- Time solve_rmontred_nocache machine_wordsize.
- eapply MontgomeryReduction.rmontred_correct.
- cbv [MontgomeryReduction.rmontred].
- cbv [Pipeline.BoundsPipeline].
- cbv [CheckedPartialReduceWithBounds1].
- set (k := PartialReduceWithBounds1 _ _).
- Timeout 10 Time lazy in k.
- Time solve_rmontred(). Time Qed.
+ Proof. Time solve_rmontred machine_wordsize. Time Qed.
Import PrintingNotations.
- Open Scope nexpr_scope.
+ Open Scope expr_scope.
Set Printing Width 100000.
+
Print montred256.
- (*
- expr_let 3 := (uint128)(fst @@ x_1 >> 128) in
- expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in
- expr_let 5 := MUL_256 @@ (x_3, (79228162514264337593543950337)) in
- expr_let 7 := ((uint128)x_5 & 340282366920938463463374607431768211455) in
- expr_let 8 := MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) in
- expr_let 10 := ((uint128)x_8 & 340282366920938463463374607431768211455) in
- expr_let 11 := (uint128)(x_10 << 128) in
- expr_let 12 := (uint128)(x_7 << 128) in
- expr_let 17 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in
- expr_let 18 := ADD_128 @@ (x_11, x_12) in
- expr_let 19 := ADD_256 @@ (x_17, fst @@ x_18) in
- expr_let 43 := (uint128)(fst @@ x_19 >> 128) in
- expr_let 44 := ((uint128)fst @@ x_19 & 340282366920938463463374607431768211455) in
- expr_let 45 := MUL_256 @@ (x_43, (79228162514264337593543950335)) in
- expr_let 46 := (uint128)(x_45 >> 128) in
- expr_let 47 := ((uint128)x_45 & 340282366920938463463374607431768211455) in
- expr_let 48 := MUL_256 @@ (x_44, (340282366841710300967557013911933812736)) in
- expr_let 49 := (uint128)(x_48 >> 128) in
- expr_let 50 := ((uint128)x_48 & 340282366920938463463374607431768211455) in
- expr_let 51 := (uint128)(x_50 << 128) in
- expr_let 52 := (uint128)(x_47 << 128) in
- expr_let 57 := MUL_256 @@ (x_44, (79228162514264337593543950335)) in
- expr_let 58 := ADD_128 @@ (x_51, x_52) in
- expr_let 59 := ADD_256 @@ (x_57, fst @@ x_58) in
- expr_let 60 := snd @@ x_59 +₁₂₈ snd @@ x_58 in
- expr_let 67 := MUL_256 @@ (x_43, (340282366841710300967557013911933812736)) in
- expr_let 69 := ADD_256 @@ (x_46, x_67) in
- expr_let 70 := ADD_256 @@ (x_49, fst @@ x_69) in
- expr_let 80 := ADD_256 @@ (x_60, fst @@ x_70) in
- expr_let 83 := ADD_256 @@ (fst @@ x_1, fst @@ x_59) in
- expr_let 84 := ADDC_256 @@ (snd @@ x_83, snd @@ x_1, fst @@ x_80) in
- expr_let 85 := SELC @@ (snd @@ x_84, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in
- expr_let 86 := fst @@ (SUB_256 @@ (fst @@ x_84, x_85)) in
- ADDM @@ (x_86, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951))
- : expr uint256
- *)
+ (*montred256 = fun var : type -> Type => λ v : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype,
+ expr_let v0 := (uint128)(v₁ >> 128) in
+ expr_let v1 := ((uint128)(v₁) & 340282366920938463463374607431768211455) in
+ expr_let v2 := ((uint128)(79228162514264337593543950337 *₂₅₆ v0) & 340282366920938463463374607431768211455) in
+ expr_let v3 := ((uint128)(340282366841710300986003757985643364352 *₂₅₆ v1) & 340282366920938463463374607431768211455) in
+ expr_let v4 := ADD_256 ((uint256)(v3 << 128), (uint256)(v2 << 128)) in
+ expr_let v5 := ADD_256 (79228162514264337593543950337 *₂₅₆ v1, v4₁) in
+ expr_let v6 := (uint128)(v5₁ >> 128) in
+ expr_let v7 := ((uint128)(v5₁) & 340282366920938463463374607431768211455) in
+ expr_let v8 := (uint128)(79228162514264337593543950335 *₂₅₆ v6 >> 128) in
+ expr_let v9 := ((uint128)(79228162514264337593543950335 *₂₅₆ v6) & 340282366920938463463374607431768211455) in
+ expr_let v10 := (uint128)(340282366841710300967557013911933812736 *₂₅₆ v7 >> 128) in
+ expr_let v11 := ((uint128)(340282366841710300967557013911933812736 *₂₅₆ v7) & 340282366920938463463374607431768211455) in
+ expr_let v12 := ADD_256 ((uint256)(v11 << 128), (uint256)(v9 << 128)) in
+ expr_let v13 := ADD_256 (79228162514264337593543950335 *₂₅₆ v7, v12₁) in
+ expr_let v14 := v13₂ +₁₂₈ v12₂ in
+ expr_let v15 := ADD_256 (v8, 340282366841710300967557013911933812736 *₂₅₆ v6) in
+ expr_let v16 := ADD_256 (v10, v15₁) in
+ expr_let v17 := ADD_256 (v14, v16₁) in
+ expr_let v18 := ADD_256 (v₁, v13₁) in
+ expr_let v19 := ADDC_256 (v18₂, v₂, v17₁) in
+ expr_let v20 := SELC (v19₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in
+ expr_let v21 := Z.cast uint256 @@ (fst @@ SUB_256 (v19₁, v20)) in
+ ADDM (v21, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951)
+ : Expr (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)
+*)
End Montgomery256.
(* Extra-specialized ad-hoc pretty-printing *)
Module Montgomery256PrintingNotations.
Export ident.
- Export BoundsAnalysis.ident.
- Export BoundsAnalysis.type.Notations.
- Export BoundsAnalysis.Indexed.expr.Notations.
- Export BoundsAnalysis.ident.Notations.
- Import BoundsAnalysis.type.
- Import BoundsAnalysis.Indexed.expr.
- Import BoundsAnalysis.ident.
- Open Scope btype_scope.
+ Open Scope expr_scope.
+ Open Scope ctype_scope.
Notation "'RegMod'" :=
- (BoundsAnalysis.Indexed.expr.AppIdent
- (primitive {| BoundsAnalysis.type.value := 115792089210356248762697446949407573530086143415290314195533631308867097853951; BoundsAnalysis.type.value_bounded := _ |})
- BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope.
+ (AppIdent
+ (primitive 115792089210356248762697446949407573530086143415290314195533631308867097853951)
+ TT) (only printing, at level 9) : expr_scope.
Notation "'RegPinv'" :=
- (BoundsAnalysis.Indexed.expr.AppIdent
- (primitive {| BoundsAnalysis.type.value := 115792089210356248768974548684794254293921932838497980611635986753331132366849; BoundsAnalysis.type.value_bounded := _ |})
- BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope.
+ (AppIdent
+ (primitive 115792089210356248768974548684794254293921932838497980611635986753331132366849)
+ TT) (only printing, at level 9) : expr_scope.
Notation "'RegZero'" :=
- (BoundsAnalysis.Indexed.expr.AppIdent
- (primitive {| BoundsAnalysis.type.value := 0; BoundsAnalysis.type.value_bounded := _ |})
- BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope.
- Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : nexpr_scope.
+ (AppIdent
+ (primitive 0)
+ TT) (only printing, at level 9) : expr_scope.
+ Notation "'$R'" := 115792089237316195423570985008687907853269984665640564039457584007913129639936 : expr_scope.
Notation "'Lower128{RegMod}'" :=
- (BoundsAnalysis.Indexed.expr.AppIdent
- (primitive {| BoundsAnalysis.type.value := 79228162514264337593543950335; BoundsAnalysis.type.value_bounded := _ |})
- BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope.
+ (AppIdent
+ (primitive 79228162514264337593543950335)
+ TT) (only printing, at level 9) : expr_scope.
Notation "'RegMod' '<<' '128'" :=
- (BoundsAnalysis.Indexed.expr.AppIdent
- (primitive {| BoundsAnalysis.type.value := 340282366841710300967557013911933812736; BoundsAnalysis.type.value_bounded := _ |})
- BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9, format "'RegMod' '<<' '128'") : nexpr_scope.
+ (AppIdent
+ (primitive 340282366841710300967557013911933812736)
+ TT) (only printing, at level 9, format "'RegMod' '<<' '128'") : expr_scope.
Notation "'Lower128{RegPinv}'" :=
- (BoundsAnalysis.Indexed.expr.AppIdent
- (primitive {| BoundsAnalysis.type.value := 79228162514264337593543950337; BoundsAnalysis.type.value_bounded := _ |})
- BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9) : nexpr_scope.
+ (AppIdent
+ (primitive 79228162514264337593543950337)
+ TT) (only printing, at level 9) : expr_scope.
Notation "'RegPinv' '>>' '128'" :=
- (BoundsAnalysis.Indexed.expr.AppIdent
- (primitive {| BoundsAnalysis.type.value := 340282366841710300986003757985643364352; BoundsAnalysis.type.value_bounded := _ |})
- BoundsAnalysis.Indexed.expr.TT) (only printing, at level 9, format "'RegPinv' '>>' '128'") : nexpr_scope.
+ (AppIdent
+ (primitive 340282366841710300986003757985643364352)
+ TT) (only printing, at level 9, format "'RegPinv' '>>' '128'") : expr_scope.
Notation "'uint256'"
- := (BoundsAnalysis.type.ZBounded 0 115792089237316195423570985008687907853269984665640564039457584007913129639935) : btype_scope.
+ := (r[0 ~> 115792089237316195423570985008687907853269984665640564039457584007913129639935]%zrange) : ctype_scope.
Notation "'uint128'"
- := (BoundsAnalysis.type.ZBounded 0 340282366920938463463374607431768211455) : btype_scope.
- Notation "$r n" := (BoundsAnalysis.Indexed.expr.Var _ n) (at level 10, format "$r n") : nexpr_scope.
- Notation "$r n '_lo'" := (fst @@ (BoundsAnalysis.Indexed.expr.Var (BoundsAnalysis.type.prod _ _) n))%nexpr (at level 10, format "$r n _lo") : nexpr_scope.
- Notation "$r n '_hi'" := (snd @@ (BoundsAnalysis.Indexed.expr.Var (BoundsAnalysis.type.prod _ _) n))%nexpr (at level 10, format "$r n _hi") : nexpr_scope.
- Notation "'c.Mul128x128(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := mul _ _ uint256 @@ (x, y) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
- Notation "'c.Mul128x128(' '$r' n ',' x ',' y ')' '<<' count ';' f" :=
- (expr_let n := shiftl _ _ count @@ (mul _ _ uint256 @@ (x, y)) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : nexpr_scope.
- Notation "'c.Add256(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := add_get_carry_concrete _ _ uint256 _ $R @@ (x, y) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add256(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
- Notation "'c.Add128(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := add_get_carry_concrete _ _ uint128 _ $R @@ (x, y) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
- Notation "'c.Add64(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := add _ _ uint128 @@ (x, y) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
- Notation "'c.Addc(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := add_with_get_carry_concrete _ _ _ uint256 _ $R @@ (_, x, y) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
- Notation "'c.Selc(' '$r' n ',' y ',' z ');' f" :=
- (expr_let n := zselect _ _ _ uint256 @@ (_, y, z) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$r' n ',' y ',' z ');' ']' '//' f") : nexpr_scope.
- Notation "'c.Sub(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := fst @@ (sub_get_borrow_concrete _ _ uint256 _ $R @@ (x, y)) in
- f)%nexpr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$r' n ',' x ',' y ');' '//' f") : nexpr_scope.
+ := (r[0 ~> 340282366920938463463374607431768211455]%zrange) : ctype_scope.
+ Notation "$ n" := (Var n) (at level 10, format "$ n") : expr_scope.
+ Notation "$ n" := (Z.cast _ @@ Var n) (at level 10, format "$ n") : expr_scope.
+ Notation "$ n '_lo'" := (fst @@ (Var n))%expr (at level 10, format "$ n _lo") : expr_scope.
+ Notation "$ n '_hi'" := (snd @@ (Var n))%expr (at level 10, format "$ n _hi") : expr_scope.
+ Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Var n)))%expr (at level 10, format "$ n _lo") : expr_scope.
+ Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Var n)))%expr (at level 10, format "$ n _hi") : expr_scope.
+ Notation "$ n '_lo'" := (Z.cast _ @@ (fst @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _lo") : expr_scope.
+ Notation "$ n '_hi'" := (Z.cast _ @@ (snd @@ (Z.cast2 _ @@ Var n)))%expr (at level 10, format "$ n _hi") : expr_scope.
+ Notation "'c.Mul128x128(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast uint256 @@ (Z.mul @@ (x, y)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' f" :=
+ (expr_let n := Z.cast _ @@ (Z.shiftl count @@ (Z.cast uint256 @@ (Z.mul @@ (x, y)))) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : expr_scope.
+ Notation "'c.Add256(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add256(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.Add128(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast2 (uint128, _)%core @@ (Z.add_get_carry_concrete $R @@ (x, y)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add128(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.Add64(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast uint128 @@ (Z.add @@ (x, y)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.Addc(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast2 (uint256, _)%core @@ (Z.add_with_get_carry_concrete $R @@ (_, x, y)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.Selc(' '$' n ',' y ',' z ');' f" :=
+ (expr_let n := Z.cast uint256 @@ (Z.zselect @@ (_, y, z)) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$' n ',' y ',' z ');' ']' '//' f") : expr_scope.
+ Notation "'c.Sub(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast uint256 @@ (fst @@ (Z.cast2 (uint256, _)%core @@ (Z.sub_get_borrow_concrete $R @@ (x, y)))) in
+ f)%expr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$' n ',' x ',' y ');' '//' f") : expr_scope.
Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" :=
- (add_modulo _ _ _ uint256 @@ (x, y, z))%nexpr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : nexpr_scope.
- Notation "'c.ShiftR(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := (shiftr _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
- Notation "'c.ShiftL(' '$r' n ',' x ',' y ');' f" :=
- (expr_let n := (shiftl _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope.
- Notation "'c.Lower128(' '$r' n ',' x ');' f" :=
- (expr_let n := (land _ _ 340282366920938463463374607431768211455 @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$r' n ',' x ');' ']' '//' f") : nexpr_scope.
+ (Z.cast uint256 @@ (Z.add_modulo @@ (x, y, z)))%expr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : expr_scope.
+ Notation "'c.ShiftR(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast _ @@ (Z.shiftr y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.ShiftL(' '$' n ',' x ',' y ');' f" :=
+ (expr_let n := Z.cast _ @@ (Z.shiftl y @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$' n ',' x ',' y ');' ']' '//' f") : expr_scope.
+ Notation "'c.Lower128(' '$' n ',' x ');' f" :=
+ (expr_let n := Z.cast _ @@ (Z.land 340282366920938463463374607431768211455 @@ x) in f)%expr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$' n ',' x ');' ']' '//' f") : expr_scope.
Notation "'Lower128'"
- := ((land uint256 uint128 340282366920938463463374607431768211455))
+ := (Z.cast uint128 @@ (Z.land 340282366920938463463374607431768211455))
(at level 10, only printing, format "Lower128")
- : nexpr_scope.
+ : expr_scope.
Notation "( v << count )"
- := ((shiftl _ _ count @@ v)%nexpr)
+ := (Z.cast _ @@ (Z.shiftl count @@ v)%expr)
(format "( v << count )")
- : nexpr_scope.
+ : expr_scope.
Notation "( x >> count )"
- := ((shiftr _ _ count @@ x)%nexpr)
+ := (Z.cast _ @@ (Z.shiftr count @@ x)%expr)
(format "( x >> count )")
- : nexpr_scope.
+ : expr_scope.
+ Notation "x * y"
+ := (Z.cast uint256 @@ (Z.mul @@ (x, y)))
+ : expr_scope.
End Montgomery256PrintingNotations.
Import Montgomery256PrintingNotations.
-Local Open Scope nexpr_scope.
+Local Open Scope expr_scope.
Print Montgomery256.montred256.
(*
-c.ShiftR($r3, $r1_lo, 128);
-c.Lower128($r4, $r1_lo);
-c.Mul128x128($r5, $r3, Lower128{RegPinv});
-c.Lower128($r7, $r5);
-c.Mul128x128($r8, $r4, RegPinv >> 128);
-c.Lower128($r10, $r8);
-c.ShiftL($r11, $r10, 128);
-c.ShiftL($r12, $r7, 128);
-c.Mul128x128($r17, $r4, Lower128{RegPinv});
-c.Add128($r18, $r11, $r12);
-c.Add256($r19, $r17, $r18_lo);
-c.ShiftR($r43, $r19_lo, 128);
-c.Lower128($r44, $r19_lo);
-c.Mul128x128($r45, $r43, Lower128{RegMod});
-c.ShiftR($r46, $r45, 128);
-c.Lower128($r47, $r45);
-c.Mul128x128($r48, $r44, RegMod << 128);
-c.ShiftR($r49, $r48, 128);
-c.Lower128($r50, $r48);
-c.ShiftL($r51, $r50, 128);
-c.ShiftL($r52, $r47, 128);
-c.Mul128x128($r57, $r44, Lower128{RegMod});
-c.Add128($r58, $r51, $r52);
-c.Add256($r59, $r57, $r58_lo);
-c.Add64($r60, $r59_hi, $r58_hi);
-c.Mul128x128($r67, $r43, RegMod << 128);
-c.Add256($r69, $r46, $r67);
-c.Add256($r70, $r49, $r69_lo);
-c.Add256($r80, $r60, $r70_lo);
-c.Add256($r83, $r1_lo, $r59_lo);
-c.Addc($r84, $r1_hi, $r80_lo);
-c.Selc($r85,RegZero, RegMod);
-c.Sub($r86, $r84_lo, $r85);
-c.AddM($ret, $r86, RegZero, RegMod);
- : expr uint256
+c.ShiftR($v0, $v_lo, 128);
+c.Lower128($v1, $v_lo);
+c.Lower128($v2, Lower128{RegPinv} * $v0);
+c.Lower128($v3, RegPinv >> 128 * $v1);
+c.Add256($v4, ($v3 << 128), ($v2 << 128));
+c.Add256($v5, Lower128{RegPinv} * $v1, $v4_lo);
+c.ShiftR($v6, $v5_lo, 128);
+c.Lower128($v7, $v5_lo);
+c.ShiftR($v8, Lower128{RegMod} * $v6, 128);
+c.Lower128($v9, Lower128{RegMod} * $v6);
+c.ShiftR($v10, RegMod << 128 * $v7, 128);
+c.Lower128($v11, RegMod << 128 * $v7);
+c.Add256($v12, ($v11 << 128), ($v9 << 128));
+c.Add256($v13, Lower128{RegMod} * $v7, $v12_lo);
+c.Add64($v14, $v13_hi, $v12_hi);
+c.Add256($v15, $v8, RegMod << 128 * $v6);
+c.Add256($v16, $v10, $v15_lo);
+c.Add256($v17, $v14, $v16_lo);
+c.Add256($v18, $v_lo, $v13_lo);
+c.Addc($v19, $v_hi, $v17_lo);
+c.Selc($v20,RegZero, RegMod);
+c.Sub($v21, $v19_lo, $v20);
+c.AddM($ret, $v21, RegZero, RegMod);
+ : Expr
+ (type.type_primitive type.Z * type.type_primitive type.Z ->
+ type.type_primitive type.Z)
*)