aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ModularArithmetic/ModularBaseSystemListZOperations.v5
-rw-r--r--src/Reflection/Z/Interpretations.v88
-rw-r--r--src/Reflection/Z/Interpretations/Relations.v9
-rw-r--r--src/Reflection/Z/Reify.v10
-rw-r--r--src/Reflection/Z/Syntax.v74
5 files changed, 170 insertions, 16 deletions
diff --git a/src/ModularArithmetic/ModularBaseSystemListZOperations.v b/src/ModularArithmetic/ModularBaseSystemListZOperations.v
index 1d863abbd..09a252a06 100644
--- a/src/ModularArithmetic/ModularBaseSystemListZOperations.v
+++ b/src/ModularArithmetic/ModularBaseSystemListZOperations.v
@@ -2,6 +2,7 @@
(** We separate these out so that we can depend on them in other files
without waiting for ModularBaseSystemList to build. *)
Require Import Coq.ZArith.ZArith.
+Require Import Crypto.Util.Tuple.
Definition cmovl (x y r1 r2 : Z) := if Z.leb x y then r1 else r2.
Definition cmovne (x y r1 r2 : Z) := if Z.eqb x y then r1 else r2.
@@ -10,3 +11,7 @@ Definition cmovne (x y r1 r2 : Z) := if Z.eqb x y then r1 else r2.
neg 1 = 2^64 - 1 (on 64-bit; 2^32-1 on 32-bit, etc.)
neg 0 = 0 *)
Definition neg (int_width : Z) (b : Z) := if Z.eqb b 1 then Z.ones int_width else 0%Z.
+
+(** TODO(jadep): Fill in this stub *)
+Axiom conditional_subtract_modulus
+ : forall (limb_count : nat) (int_width : Z) (modulus value : Tuple.tuple Z limb_count), Tuple.tuple Z limb_count.
diff --git a/src/Reflection/Z/Interpretations.v b/src/Reflection/Z/Interpretations.v
index 8371b3bef..1f0b95e7f 100644
--- a/src/Reflection/Z/Interpretations.v
+++ b/src/Reflection/Z/Interpretations.v
@@ -160,11 +160,14 @@ Module Word64.
Definition land : word64 -> word64 -> word64 := @wand _.
Definition lor : word64 -> word64 -> word64 := @wor _.
Definition neg : word64 -> word64 -> word64 (* TODO: FIXME? *)
- := fun x y => NToWord _ (Z.to_N (ModularBaseSystemListZOperations.neg (Z.of_N (wordToN x)) (Z.of_N (wordToN y)))).
+ := fun x y => ZToWord64 (ModularBaseSystemListZOperations.neg (word64ToZ x) (word64ToZ y)).
Definition cmovne : word64 -> word64 -> word64 -> word64 -> word64 (* TODO: FIXME? *)
- := fun x y z w => NToWord _ (Z.to_N (ModularBaseSystemListZOperations.cmovne (Z.of_N (wordToN x)) (Z.of_N (wordToN x)) (Z.of_N (wordToN z)) (Z.of_N (wordToN w)))).
+ := fun x y z w => ZToWord64 (ModularBaseSystemListZOperations.cmovne (word64ToZ x) (word64ToZ x) (word64ToZ z) (word64ToZ w)).
Definition cmovle : word64 -> word64 -> word64 -> word64 -> word64 (* TODO: FIXME? *)
- := fun x y z w => NToWord _ (Z.to_N (ModularBaseSystemListZOperations.cmovl (Z.of_N (wordToN x)) (Z.of_N (wordToN x)) (Z.of_N (wordToN z)) (Z.of_N (wordToN w)))).
+ := fun x y z w => ZToWord64 (ModularBaseSystemListZOperations.cmovl (word64ToZ x) (word64ToZ x) (word64ToZ z) (word64ToZ w)).
+ Definition conditional_subtract (pred_limb_count : nat) : word64 -> Tuple.tuple word64 (S pred_limb_count) -> Tuple.tuple word64 (S pred_limb_count) -> Tuple.tuple word64 (S pred_limb_count)
+ := fun x y z => Tuple.map ZToWord64 (@ModularBaseSystemListZOperations.conditional_subtract_modulus
+ (S pred_limb_count) (word64ToZ x) (Tuple.map word64ToZ y) (Tuple.map word64ToZ z)).
Infix "+" := add : word64_scope.
Infix "-" := sub : word64_scope.
Infix "*" := mul : word64_scope.
@@ -214,6 +217,9 @@ Module Word64.
| Neg => fun xy => neg (fst xy) (snd xy)
| Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w
| Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle x y z w
+ | ConditionalSubtract pred_n
+ => fun xyz => let '(x, y, z) := eta3 xyz in
+ flat_interp_untuple' (T:=Tbase TZ) (@conditional_subtract pred_n x (flat_interp_tuple y) (flat_interp_tuple z))
end%word64.
Definition of_Z ty : Z.interp_base_type ty -> interp_base_type ty
@@ -315,6 +321,39 @@ Module ZBounds.
Definition cmovle' (r1 r2 : bounds) : bounds
:= let (lr1, ur1) := r1 in let (lr2, ur2) := r2 in {| lower := Z.min lr1 lr2 ; upper := Z.max ur1 ur2 |}.
Definition cmovle (x y r1 r2 : t) : t := t_map2 cmovle' r1 r2.
+ (** TODO(jadep): Check that this is correct; it computes the bounds,
+ conditional on the assumption that the entire calculation is
+ valid. Currently, it says that each limb is upper-bounded by
+ either the original value less the modulus, or by the smaller of
+ the original value and the modulus (in the case that the
+ subtraction is negative). Feel free to substitute any other
+ bounds you'd like here. *)
+ Definition conditional_subtract' (pred_n : nat) (int_width : bounds)
+ (modulus value : Tuple.tuple bounds (S pred_n))
+ : Tuple.tuple bounds (S pred_n)
+ := Tuple.map2
+ (fun modulus_bounds value_bounds : bounds
+ => let (ml, mu) := modulus_bounds in
+ let (vl, vu) := value_bounds in
+ {| lower := 0 ; upper := Z.max (Z.min vu mu) (vu - ml) |})
+ modulus value.
+ (** TODO(jadep): Fill me in. This should check that the modulus and
+ value fit within int_width, that the modulus is of the right
+ form, and that the value is small enough. If not, it should
+ [None]; otherwise, it should delegate to
+ [conditional_subtract']. *)
+ Axiom conditional_subtract_o
+ : forall (pred_n : nat) (int_width : bounds)
+ (modulus value : Tuple.tuple bounds (S pred_n)), option (Tuple.tuple bounds (S pred_n)).
+ Definition conditional_subtract (pred_n : nat) (int_width : t)
+ (modulus value : Tuple.tuple t (S pred_n))
+ : Tuple.tuple t (S pred_n)
+ := Tuple.push_option
+ match int_width, Tuple.lift_option modulus, Tuple.lift_option value with
+ | Some int_width, Some modulus, Some value
+ => conditional_subtract_o pred_n int_width modulus value
+ | _, _, _ => None
+ end.
Module Export Notations.
Delimit Scope bounds_scope with bounds.
@@ -341,6 +380,9 @@ Module ZBounds.
| Neg => fun xy => neg (fst xy) (snd xy)
| Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w
| Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle x y z w
+ | ConditionalSubtract pred_n
+ => fun xyz => let '(x, y, z) := eta3 xyz in
+ flat_interp_untuple' (T:=Tbase TZ) (@conditional_subtract pred_n x (flat_interp_tuple y) (flat_interp_tuple z))
end%bounds.
Definition of_word64 ty : Word64.interp_base_type ty -> interp_base_type ty
@@ -559,6 +601,38 @@ Module BoundedWord64.
Definition cmovle : t -> t -> t -> t -> t.
Proof. build_4op Word64.cmovle ZBounds.cmovle; abstract t_start. Defined.
+ Definition conditional_subtract (pred_n : nat) (int_width : t)
+ (modulus val : Tuple.tuple t (S pred_n))
+ : Tuple.tuple t (S pred_n).
+ Proof.
+ refine (match int_width, Tuple.lift_option modulus, Tuple.lift_option val with
+ | Some int_width, Some modulus, Some val
+ => let boundsv := Tuple.push_option
+ (ZBounds.conditional_subtract_o
+ pred_n (BoundedWordToBounds int_width)
+ (Tuple.map BoundedWordToBounds modulus)
+ (Tuple.map BoundedWordToBounds val)) in
+ let wordv := (Word64.conditional_subtract
+ pred_n (value int_width)
+ (Tuple.map value modulus)
+ (Tuple.map value val)) in
+ let ret_val := Tuple.map2
+ (fun bs val
+ => option_map (fun bs' : ZBounds.bounds
+ => let (l, u) := bs' in
+ (l, val, u)) bs)
+ boundsv wordv in
+ _
+ | _, _, _ => Tuple.push_option None
+ end).
+ (** TODO(jadep): Use the bounds lemma here to prove that if each
+ component of [ret_val] is [Some (l, v, u)], then we can fill
+ in [pf] and return the tuple of [{| lower := l ; value := v ;
+ upper := u ; in_bounds := pf |}]. *)
+ admit.
+ Defined.
+
+
Local Notation binop_correct op opW opB :=
(forall x y v, op (Some x) (Some y) = Some v -> value v = opW (value x) (value y)
/\ BoundedWordToBounds v = opB (BoundedWordToBounds x) (BoundedWordToBounds y))
@@ -618,6 +692,9 @@ Module BoundedWord64.
Proof. invert_t. Qed.
Definition invert_cmovle : op4_correct cmovle Word64.cmovle (fun _ _ => ZBounds.cmovle').
Proof. invert_t. Qed.
+ (** TODO(jadep): Fill me in *)
+ Definition invert_conditional_subtract : False.
+ Proof. Admitted.
Module Export Notations.
Delimit Scope bounded_word_scope with bounded_word.
@@ -642,7 +719,10 @@ Module BoundedWord64.
| Neg => fun xy => neg (fst xy) (snd xy)
| Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w
| Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovle x y z w
- end%bounded_word.
+ | ConditionalSubtract pred_n
+ => fun xyz => let '(x, y, z) := eta3 xyz in
+ flat_interp_untuple' (T:=Tbase TZ) (@conditional_subtract pred_n x (flat_interp_tuple y) (flat_interp_tuple z))
+ end%bounded_word.
End BoundedWord64.
Module ZBoundsTuple.
diff --git a/src/Reflection/Z/Interpretations/Relations.v b/src/Reflection/Z/Interpretations/Relations.v
index 2ec29a05e..9c576dcf9 100644
--- a/src/Reflection/Z/Interpretations/Relations.v
+++ b/src/Reflection/Z/Interpretations/Relations.v
@@ -122,6 +122,9 @@ Local Ltac related_word64_op_t_step :=
end ].
Local Ltac related_word64_op_t := repeat related_word64_op_t_step.
+Axiom proof_admitted : False.
+Tactic Notation "admit" := abstract case proof_admitted.
+
Lemma related_word64_op : related_op related_word64 (@BoundedWord64.interp_op) (@Word64.interp_op).
Proof.
let op := fresh in intros ?? op; destruct op; simpl.
@@ -135,11 +138,9 @@ Proof.
{ related_word64_op_t. }
{ related_word64_op_t. }
{ related_word64_op_t. }
+ { related_word64_op_t; admit. (** TODO(jadep): Fill me in *) }
Qed.
-Axiom proof_admitted : False.
-Tactic Notation "admit" := abstract case proof_admitted.
-
Lemma related_Z_op : related_op related_Z (@BoundedWord64.interp_op) (@Z.interp_op).
Proof.
let op := fresh in intros ?? op; destruct op; simpl.
@@ -153,6 +154,7 @@ Proof.
{ related_word64_op_t; admit. }
{ related_word64_op_t; admit. }
{ related_word64_op_t; admit. }
+ { related_word64_op_t; admit. (** TODO(jadep or jgross): Fill me in *) }
Qed.
Local Arguments ZBounds.SmartBuildBounds _ _ / .
@@ -169,6 +171,7 @@ Proof.
{ related_word64_op_t. }
{ admit; related_word64_op_t. }
{ admit; related_word64_op_t. }
+ { admit; related_word64_op_t. (** TODO(jadep or jgross): Fill me in *) }
Qed.
Create HintDb interp_related discriminated.
diff --git a/src/Reflection/Z/Reify.v b/src/Reflection/Z/Reify.v
index 451e522d0..6734d7d01 100644
--- a/src/Reflection/Z/Reify.v
+++ b/src/Reflection/Z/Reify.v
@@ -21,6 +21,16 @@ Ltac base_reify_op op op_head extra ::=
| @ModularBaseSystemListZOperations.neg => constr:(reify_op op op_head 2 Neg)
| @ModularBaseSystemListZOperations.cmovne => constr:(reify_op op op_head 4 Cmovne)
| @ModularBaseSystemListZOperations.cmovl => constr:(reify_op op op_head 4 Cmovle)
+ | @ModularBaseSystemListZOperations.conditional_subtract_modulus
+ => lazymatch extra with
+ | @ModularBaseSystemListZOperations.conditional_subtract_modulus ?limb_count _ _ _
+ => lazymatch (eval compute in limb_count) with
+ | 0 => fail 1 "Cannot handle empty limb-list in reifying conditional_subtract_modulus"
+ | S ?pred_limb_count => constr:(reify_op op op_head 3 (ConditionalSubtract pred_limb_count))
+ | ?climb_count => fail 1 "Cannot handle non-ground length of limb-list in reifying conditional_subtract_modulus" "(" limb_count "which computes to" climb_count ")"
+ end
+ | _ => fail 100 "Anomaly: In Reflection.Z.base_reify_op: head is conditional_subtract_modulus but body is wrong:" extra
+ end
end.
Ltac base_reify_type T ::=
lazymatch T with
diff --git a/src/Reflection/Z/Syntax.v b/src/Reflection/Z/Syntax.v
index 4b23cc274..d1274a85f 100644
--- a/src/Reflection/Z/Syntax.v
+++ b/src/Reflection/Z/Syntax.v
@@ -4,6 +4,8 @@ Require Import Crypto.Reflection.Syntax.
Require Import Crypto.ModularArithmetic.ModularBaseSystemListZOperations.
Require Import Crypto.Util.Equality.
Require Import Crypto.Util.ZUtil.
+Require Import Crypto.Util.HProp.
+Require Import Crypto.Util.Decidable.
Require Import Crypto.Util.PartiallyReifiedProp.
Export Syntax.Notations.
@@ -16,11 +18,19 @@ Definition interp_base_type (v : base_type) : Type :=
| TZ => Z
end.
+Global Instance dec_eq_base_type : DecidableRel (@eq base_type)
+ := base_type_eq_dec.
+Global Instance dec_eq_flat_type : DecidableRel (@eq (flat_type base_type)) := _.
+Global Instance dec_eq_type : DecidableRel (@eq (type base_type)) := _.
+
Local Notation tZ := (Tbase TZ).
Local Notation eta x := (fst x, snd x).
Local Notation eta3 x := (eta (fst x), snd x).
Local Notation eta4 x := (eta3 (fst x), snd x).
+Axiom proof_admitted : False.
+Local Notation admit := (match proof_admitted with end).
+
Inductive op : flat_type base_type -> flat_type base_type -> Type :=
| Add : op (tZ * tZ) tZ
| Sub : op (tZ * tZ) tZ
@@ -31,7 +41,12 @@ Inductive op : flat_type base_type -> flat_type base_type -> Type :=
| Lor : op (tZ * tZ) tZ
| Neg : op (tZ * tZ) tZ
| Cmovne : op (tZ * tZ * tZ * tZ) tZ
-| Cmovle : op (tZ * tZ * tZ * tZ) tZ.
+| Cmovle : op (tZ * tZ * tZ * tZ) tZ
+| ConditionalSubtract (pred_limb_count : nat)
+ : op (tZ (* int_width *)
+ * Syntax.tuple tZ (S pred_limb_count) (* modulus *)
+ * Syntax.tuple tZ (S pred_limb_count) (* value *))
+ (Syntax.tuple tZ (S pred_limb_count)).
Definition interp_op src dst (f : op src dst) : interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst
:= match f in op src dst return interp_flat_type interp_base_type src -> interp_flat_type interp_base_type dst with
@@ -45,6 +60,9 @@ Definition interp_op src dst (f : op src dst) : interp_flat_type interp_base_typ
| Neg => fun xy => ModularBaseSystemListZOperations.neg (fst xy) (snd xy)
| Cmovne => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovne x y z w
| Cmovle => fun xyzw => let '(x, y, z, w) := eta4 xyzw in cmovl x y z w
+ | ConditionalSubtract pred_n
+ => fun xyz => let '(x, y, z) := eta3 xyz in
+ flat_interp_untuple' (T:=tZ) (@ModularBaseSystemListZOperations.conditional_subtract_modulus (S pred_n) x (flat_interp_tuple y) (flat_interp_tuple z))
end%Z.
Definition base_type_eq_semidec_transparent (t1 t2 : base_type)
@@ -57,7 +75,7 @@ Proof.
unfold base_type_eq_semidec_transparent; congruence.
Qed.
-Definition op_beq t1 tR (f g : op t1 tR) : reified_Prop
+Definition op_beq_hetero {t1 tR t1' tR'} (f : op t1 tR) (g : op t1' tR') : reified_Prop
:= match f, g return bool with
| Add, Add => true
| Add, _ => false
@@ -79,15 +97,53 @@ Definition op_beq t1 tR (f g : op t1 tR) : reified_Prop
| Cmovne, _ => false
| Cmovle, Cmovle => true
| Cmovle, _ => false
+ | ConditionalSubtract n, ConditionalSubtract m => NatUtil.nat_beq n m
+ | ConditionalSubtract _, _ => false
end.
+Definition op_beq t1 tR (f g : op t1 tR) : reified_Prop
+ := Eval cbv [op_beq_hetero] in op_beq_hetero f g.
+
+Definition op_beq_hetero_type_eq {t1 tR t1' tR'} f g : to_prop (@op_beq_hetero t1 tR t1' tR' f g) -> t1 = t1' /\ tR = tR'.
+Proof.
+ destruct f, g; simpl; try solve [ repeat constructor | intros [] ].
+ unfold op_beq_hetero; simpl.
+ match goal with
+ | [ |- context[to_prop (reified_Prop_of_bool ?x)] ]
+ => destruct (Sumbool.sumbool_of_bool x) as [P|P]
+ end.
+ { apply NatUtil.internal_nat_dec_bl in P; subst; repeat constructor. }
+ { intro H'; exfalso; rewrite P in H'; exact H'. }
+Defined.
+
+Definition op_beq_hetero_type_eqs {t1 tR t1' tR'} f g : to_prop (@op_beq_hetero t1 tR t1' tR' f g) -> t1 = t1'
+ := fun H => let (p, q) := @op_beq_hetero_type_eq t1 tR t1' tR' f g H in p.
+Definition op_beq_hetero_type_eqd {t1 tR t1' tR'} f g : to_prop (@op_beq_hetero t1 tR t1' tR' f g) -> tR = tR'
+ := fun H => let (p, q) := @op_beq_hetero_type_eq t1 tR t1' tR' f g H in q.
+
+Definition op_beq_hetero_eq {t1 tR t1' tR'} f g
+ : forall pf : to_prop (@op_beq_hetero t1 tR t1' tR' f g),
+ eq_rect
+ _ (fun src => op src tR')
+ (eq_rect _ (fun dst => op t1 dst) f _ (op_beq_hetero_type_eqd f g pf))
+ _ (op_beq_hetero_type_eqs f g pf)
+ = g.
+Proof.
+ destruct f, g; simpl; try solve [ reflexivity | intros [] ].
+ { unfold op_beq_hetero, op_beq_hetero_type_eqs, op_beq_hetero_type_eqd; simpl.
+ intro pf; edestruct Sumbool.sumbool_of_bool.
+ { simpl; edestruct NatUtil.internal_nat_dec_bl; reflexivity. }
+ { match goal with
+ | [ |- context[False_ind _ ?pf] ]
+ => case pf
+ end. } }
+Qed.
+
Lemma op_beq_bl : forall t1 tR x y, to_prop (op_beq t1 tR x y) -> x = y.
Proof.
- intros ?? x; destruct x;
- intro y;
- refine match y with
- | Add => _
- | _ => _
- end;
- compute; try (reflexivity || trivial || (intros; exfalso; assumption)).
+ intros ?? f g H.
+ pose proof (op_beq_hetero_eq f g H) as H'; subst.
+ generalize dependent (op_beq_hetero_type_eqd f g H).
+ generalize dependent (op_beq_hetero_type_eqs f g H).
+ intros; eliminate_hprop_eq; simpl in *; assumption.
Qed.