From 0499ec53ad54f6e93fdf8fdc55b22d1f2ec53263 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 27 Mar 2018 17:28:57 -0400 Subject: Allow passing in optional bounds to the pipeline --- src/Experiments/SimplyTypedArithmetic.v | 165 +++++++++++++++++++------------- 1 file changed, 97 insertions(+), 68 deletions(-) (limited to 'src') diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index f874a3396..06f856e43 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -1090,8 +1090,10 @@ Qed. Module Ring. Local Notation is_bounded_by0 r v := ((lower r <=? v) && (v <=? upper r)). + Local Notation is_bounded_by0o r + := (match r with Some r' => fun v => is_bounded_by0 r' v | None => fun _ => true end). Local Notation is_bounded_by bounds ls - := (fold_andb_map (fun r v => is_bounded_by0 r v) bounds ls). + := (fold_andb_map (fun r v => is_bounded_by0o r v) bounds ls). Local Notation is_bounded_by2 bounds ls := (let '(a, b) := ls in andb (is_bounded_by bounds a) (is_bounded_by bounds b)). @@ -1107,9 +1109,9 @@ Module Ring. (n : nat) (s : Z) (c : list (Z * Z)) - (tight_bounds : list zrange) + (tight_bounds : list (option zrange)) (length_tight_bounds : length tight_bounds = n) - (loose_bounds : list zrange) + (loose_bounds : list (option zrange)) (length_loose_bounds : length loose_bounds = n). Local Notation weight := (weight limbwidth_num limbwidth_den). Local Notation eval := (Positional.eval weight n). @@ -3983,15 +3985,33 @@ Module Compilers. | type.list A => fun ls1 ls2 => match ls1 with - | Datatypes.None => false + | Datatypes.None => true | Datatypes.Some ls1 => fold_andb_map (@is_bounded_by A) ls1 ls2 end end. - Lemma is_tighter_than_Some_is_bounded_by {t} r1 r2 val - (Htight : @is_tighter_than t r1 (Some r2) = true) + Lemma is_bounded_by_Some {t} r val + : is_bounded_by (@Some t r) val = type.is_bounded_by r val. + Proof. + induction t; + repeat first [ reflexivity + | progress cbn in * + | progress destruct_head'_prod + | progress destruct_head' type.primitive + | match goal with H : _ |- _ => rewrite H end ]. + { lazymatch goal with + | [ r : list (type.interp t), val : list (Compilers.type.interp t) |- _ ] + => revert r val IHt + end; intros r val; revert r val. + induction r, val; cbn; auto with nocore; try congruence; []. + intro H'; rewrite H', IHr by auto. + reflexivity. } + Qed. + + Lemma is_tighter_than_is_bounded_by {t} r1 r2 val + (Htight : @is_tighter_than t r1 r2 = true) (Hbounds : is_bounded_by r1 val = true) - : type.is_bounded_by r2 val = true. + : is_bounded_by r2 val = true. Proof. induction t; repeat first [ progress destruct_head'_prod @@ -4007,12 +4027,21 @@ Module Compilers. | Z.ltb_to_lt; omega | rewrite @fold_andb_map_map in * ]. { lazymatch goal with - | [ r1 : list (interp t), r2 : list (type.interp t), val : list (Compilers.type.interp t) |- _ ] + | [ r1 : list (interp t), r2 : list (interp t), val : list (Compilers.type.interp t) |- _ ] => revert r1 r2 val Htight Hbounds IHt end; intros r1 r2 val; revert r1 r2 val. induction r1, r2, val; cbn; auto with nocore; try congruence; []. rewrite !Bool.andb_true_iff; intros; destruct_head'_and; split; eauto with nocore. } Qed. + + Lemma is_tighter_than_Some_is_bounded_by {t} r1 r2 val + (Htight : @is_tighter_than t r1 (Some r2) = true) + (Hbounds : is_bounded_by r1 val = true) + : type.is_bounded_by r2 val = true. + Proof. + rewrite <- is_bounded_by_Some. + eapply is_tighter_than_is_bounded_by; eassumption. + Qed. End option. End type. @@ -5024,13 +5053,13 @@ Module Compilers. := partial.expr.reify (@partial_evaluate' t e). Definition partial_evaluate_with_bounds1' {s d} (e : @expr (partial.value var) (s -> d)) - (b : ZRange.type.interp s) + (b : ZRange.type.option.interp s) : partial.value var (s -> d) := fun x : partial.value var s - => partial_evaluate' e (partial.bounds.extend_with_bounds b x). + => partial_evaluate' e (partial.bounds.extend_with_obounds b x). Definition partial_evaluate_with_bounds1 {s d} (e : @expr (partial.value var) (s -> d)) - (b : ZRange.type.interp s) + (b : ZRange.type.option.interp s) := partial.expr.reify (@partial_evaluate_with_bounds1' s d e b). End partial_evaluate. @@ -5087,36 +5116,36 @@ Module Compilers. End RelaxZRange. Definition PartialEvaluateWithBounds1 - {s d} (e : Expr (s -> d)) (b : ZRange.type.interp s) + {s d} (e : Expr (s -> d)) (b : ZRange.type.option.interp s) : Expr (s -> d) := fun var => @partial_evaluate_with_bounds1 true var s d (e _) b. Definition CheckPartialEvaluateWithBounds1 (relax_zrange : zrange -> option zrange) {s d} (E : Expr (s -> d)) - (b_in : ZRange.type.interp s) - (b_out : ZRange.type.interp d) + (b_in : ZRange.type.option.interp s) + (b_out : ZRange.type.option.interp d) : Expr (s -> d) + ZRange.type.option.interp d - := let b_computed := partial.bounds.expr.Extract E (ZRange.type.option.Some b_in) in - if ZRange.type.option.is_tighter_than b_computed (ZRange.type.option.Some b_out) + := let b_computed := partial.bounds.expr.Extract E b_in in + if ZRange.type.option.is_tighter_than b_computed b_out then @inl (Expr (s -> d)) _ (RelaxZRange.expr.Relax relax_zrange E) else @inr _ (ZRange.type.option.interp d) b_computed. Definition CheckPartialEvaluateWithBounds0 (relax_zrange : zrange -> option zrange) {t} (E : Expr t) - (b_out : ZRange.type.interp t) + (b_out : ZRange.type.option.interp t) : Expr t + ZRange.type.option.interp t := let b_computed := partial.bounds.expr.Extract E in - if ZRange.type.option.is_tighter_than b_computed (ZRange.type.option.Some b_out) + if ZRange.type.option.is_tighter_than b_computed b_out then @inl (Expr t) _ (RelaxZRange.expr.Relax relax_zrange E) else @inr _ (ZRange.type.option.interp t) b_computed. Definition CheckedPartialEvaluateWithBounds1 (relax_zrange : zrange -> option zrange) {s d} (e : Expr (s -> d)) - (b_in : ZRange.type.interp s) - (b_out : ZRange.type.interp d) + (b_in : ZRange.type.option.interp s) + (b_out : ZRange.type.option.interp d) : Expr (s -> d) + ZRange.type.option.interp d := dlet_nd E := PartialEvaluateWithBounds1 e b_in in CheckPartialEvaluateWithBounds1 relax_zrange E b_in b_out. @@ -5124,7 +5153,7 @@ Module Compilers. Definition CheckedPartialEvaluateWithBounds0 (relax_zrange : zrange -> option zrange) {t} (e : Expr t) - (b_out : ZRange.type.interp t) + (b_out : ZRange.type.option.interp t) : Expr t + ZRange.type.option.interp t := dlet_nd E := PartialEvaluate true e in CheckPartialEvaluateWithBounds0 relax_zrange E b_out. @@ -5138,21 +5167,21 @@ Module Compilers. -> relax_zrange r = Some r' -> is_tighter_than_bool z r' = true) {s d} (e : Expr (s -> d)) - (b_in : ZRange.type.interp s) - (b_out : ZRange.type.interp d) + (b_in : ZRange.type.option.interp s) + (b_out : ZRange.type.option.interp d) E (HE : PartialEvaluateWithBounds1 e b_in = E) rv (Hrv : CheckPartialEvaluateWithBounds1 relax_zrange E b_in b_out = inl rv) : forall arg - (Harg : ZRange.type.is_bounded_by b_in arg = true), + (Harg : ZRange.type.option.is_bounded_by b_in arg = true), Interp rv arg = Interp e arg - /\ ZRange.type.is_bounded_by b_out (Interp rv arg) = true. + /\ ZRange.type.option.is_bounded_by b_out (Interp rv arg) = true. Proof. cbv [CheckedPartialEvaluateWithBounds1 CheckPartialEvaluateWithBounds1 Let_In] in *; break_innermost_match_hyps; inversion_sum; subst. intros arg Harg. split. { exact admit. (* correctness of interp *) } - { eapply ZRange.type.option.is_tighter_than_Some_is_bounded_by; [ eassumption | ]. + { eapply ZRange.type.option.is_tighter_than_is_bounded_by; [ eassumption | ]. cbv [expr.Interp]. revert Harg. exact admit. (* boundedness *) } @@ -5164,17 +5193,17 @@ Module Compilers. -> relax_zrange r = Some r' -> is_tighter_than_bool z r' = true) {t} (e : Expr t) - (b_out : ZRange.type.interp t) + (b_out : ZRange.type.option.interp t) E (HE : PartialEvaluate true e = E) rv (Hrv : CheckPartialEvaluateWithBounds0 relax_zrange E b_out = inl rv) : Interp rv = Interp e - /\ ZRange.type.is_bounded_by b_out (Interp rv) = true. + /\ ZRange.type.option.is_bounded_by b_out (Interp rv) = true. Proof. cbv [CheckedPartialEvaluateWithBounds0 CheckPartialEvaluateWithBounds0 Let_In] in *; break_innermost_match_hyps; inversion_sum; subst. split. { exact admit. (* correctness of interp *) } - { eapply ZRange.type.option.is_tighter_than_Some_is_bounded_by; [ eassumption | ]. + { eapply ZRange.type.option.is_tighter_than_is_bounded_by; [ eassumption | ]. cbv [expr.Interp]. exact admit. (* boundedness *) } Qed. @@ -5610,7 +5639,7 @@ Module test2. expr_let x1 := (Var x0 * Var x0) in (Var x1, Var x1))%expr) => idtac end. - pose (PartialEvaluateWithBounds1 E' r[0~>10]%zrange) as E''. + pose (PartialEvaluateWithBounds1 E' (Some r[0~>10]%zrange)) as E''. lazy in E''. lazymatch (eval cbv delta [E''] in E'') with | (fun var : type -> Type => @@ -5646,7 +5675,7 @@ Module test3. Var x3 * Var x3)%expr) => idtac end. - pose (PartialEvaluateWithBounds1 E' r[0~>10]%zrange) as E'''. + pose (PartialEvaluateWithBounds1 E' (Some r[0~>10]%zrange)) as E'''. lazy in E'''. lazymatch (eval cbv delta [E'''] in E''') with | (fun var : type -> Type => @@ -5674,7 +5703,7 @@ Module test4. pose (PartialEvaluate false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'. lazy in E'. clear E. - pose (PartialEvaluateWithBounds1 E' ([r[0~>10]%zrange],[r[0~>10]%zrange])) as E''. + pose (PartialEvaluateWithBounds1 E' (Some [Some r[0~>10]%zrange],Some [Some r[0~>10]%zrange])) as E''. lazy in E''. lazymatch (eval cbv delta [E''] in E'') with | (fun var : type -> Type => @@ -6034,7 +6063,7 @@ Module Pipeline. := let E := CheckPartialEvaluateWithBounds1 relax_zrange E arg_bounds out_bounds in let E := match E with | inl v => Success v - | inr b => Error (Computed_bounds_are_not_tight_enough b (ZRange.type.option.Some out_bounds)) + | inr b => Error (Computed_bounds_are_not_tight_enough b out_bounds) end in E. @@ -6068,8 +6097,8 @@ Module Pipeline. (HE : BoundsPipelineNoCheck (*with_dead_code_elimination*) with_subst01 e arg_bounds = Success E) (Hrv : CheckBoundsPipeline relax_zrange E arg_bounds out_bounds = Success rv) : forall arg - (Harg : ZRange.type.is_bounded_by arg_bounds arg = true), - ZRange.type.is_bounded_by out_bounds (Interp rv arg) = true + (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), + ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true /\ Interp rv arg = Interp e arg. Proof. cbv [BoundsPipeline BoundsPipelineNoCheck CheckBoundsPipeline Let_In] in *; @@ -6094,8 +6123,8 @@ Module Pipeline. (InterpE : type.interp s -> type.interp d) (rv : Expr (s -> d)) := forall arg - (Harg : ZRange.type.is_bounded_by arg_bounds arg = true), - ZRange.type.is_bounded_by out_bounds (Interp rv arg) = true + (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), + ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true /\ Interp rv arg = InterpE arg. Lemma BoundsPipeline_correct_trans @@ -6111,7 +6140,7 @@ Module Pipeline. (InterpE : type.interp s -> type.interp d) (InterpE_correct : forall arg - (Harg : ZRange.type.is_bounded_by arg_bounds arg = true), + (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), Interp e arg = InterpE arg) rv E (HE : BoundsPipelineNoCheck (*with_dead_code_elimination*) with_subst01 e arg_bounds = Success E) @@ -6151,8 +6180,8 @@ Module Pipeline. rv (Hrv : BoundsPipeline_full (*with_dead_code_elimination*) with_subst01 relax_zrange E arg_bounds out_bounds = Success rv) : forall arg - (Harg : ZRange.type.is_bounded_by arg_bounds arg = true), - ZRange.type.is_bounded_by out_bounds (Interp rv arg) = true + (Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true), + ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true /\ Interp rv arg = for_reification.Interp E arg. Proof. cbv [BoundsPipeline_full BoundsPipeline] in *. @@ -6197,7 +6226,7 @@ Module Pipeline. := let E := CheckPartialEvaluateWithBounds0 relax_zrange E bounds in let E := match E with | inl v => Success v - | inr b => Error (Computed_bounds_are_not_tight_enough b (ZRange.type.option.Some bounds)) + | inr b => Error (Computed_bounds_are_not_tight_enough b bounds) end in E. @@ -6228,7 +6257,7 @@ Module Pipeline. E (HE : BoundsPipelineConstNoCheck (*with_dead_code_elimination*) with_subst01 e = Success E) (Hrv : CheckBoundsPipelineConst relax_zrange E bounds = Success rv) - : ZRange.type.is_bounded_by bounds (Interp rv) = true + : ZRange.type.option.is_bounded_by bounds (Interp rv) = true /\ Interp rv = Interp e. Proof. cbv [BoundsPipelineConst CheckBoundsPipelineConst BoundsPipelineConstNoCheck Let_In] in *; @@ -6251,7 +6280,7 @@ Module Pipeline. out_bounds (InterpE : type.interp t) (rv : Expr t) - := ZRange.type.is_bounded_by out_bounds (Interp rv) = true + := ZRange.type.option.is_bounded_by out_bounds (Interp rv) = true /\ Interp rv = InterpE. Lemma BoundsPipelineConst_correct_trans @@ -6302,7 +6331,7 @@ Module Pipeline. out_bounds rv (Hrv : BoundsPipelineConst_full (*with_dead_code_elimination*) with_subst01 relax_zrange E out_bounds = Success rv) - : ZRange.type.is_bounded_by out_bounds (Interp rv) = true + : ZRange.type.option.is_bounded_by out_bounds (Interp rv) = true /\ Interp rv = for_reification.Interp E. Proof. cbv [BoundsPipelineConst_full BoundsPipelineConst] in *. @@ -6370,17 +6399,17 @@ Section rcarry_mul. Let idxs := (seq 0 n ++ [0; 1])%list%nat. Let coef := 2. Let upperbound_tight := (2^Qceiling limbwidth + 2^(Qceiling limbwidth - 3))%Z. - Let prime_bound : ZRange.type.interp (type.Z) - := r[0~>(s - Associational.eval c - 1)]%zrange. + Let prime_bound : ZRange.type.option.interp (type.Z) + := Some r[0~>(s - Associational.eval c - 1)]%zrange. Definition relax_zrange_of_machine_wordsize := relax_zrange_gen [machine_wordsize; 2 * machine_wordsize]%Z. Let relax_zrange := relax_zrange_of_machine_wordsize. - Let tight_bounds : ZRange.type.interp (type.list type.Z) - := List.repeat r[0~>upperbound_tight]%zrange n. - Let loose_bounds : ZRange.type.interp (type.list type.Z) - := List.repeat r[0 ~> 3*upperbound_tight]%zrange n. + Let tight_bounds : list (ZRange.type.option.interp type.Z) + := List.repeat (Some r[0~>upperbound_tight]%zrange) n. + Let loose_bounds : list (ZRange.type.option.interp type.Z) + := List.repeat (Some r[0 ~> 3*upperbound_tight]%zrange) n. Definition check_args {T} (res : Pipeline.ErrorT T) : Pipeline.ErrorT T @@ -6442,59 +6471,59 @@ Section rcarry_mul. := BoundsPipeline (carry_mul_gen @ GallinaReify.Reify (Qnum limbwidth) @ GallinaReify.Reify (Z.pos (Qden limbwidth)) @ GallinaReify.Reify s @ GallinaReify.Reify c @ GallinaReify.Reify n @ GallinaReify.Reify (length c) @ GallinaReify.Reify idxs @ GallinaReify.Reify (length idxs)) - (loose_bounds, loose_bounds) - tight_bounds. + (Some loose_bounds, Some loose_bounds) + (Some tight_bounds). Definition rcarry_mul_correct := BoundsPipeline_correct - (loose_bounds, loose_bounds) - tight_bounds + (Some loose_bounds, Some loose_bounds) + (Some tight_bounds) (carry_mulmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) idxs (List.length idxs)). Definition rcarry_correct := BoundsPipeline_correct - loose_bounds - tight_bounds + (Some loose_bounds) + (Some tight_bounds) (carrymod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) idxs (List.length idxs)). Definition rrelax_correct := BoundsPipeline_correct - tight_bounds - loose_bounds + (Some tight_bounds) + (Some loose_bounds) (expanding_id n). Definition radd_correct := BoundsPipeline_correct - (tight_bounds, tight_bounds) - loose_bounds + (Some tight_bounds, Some tight_bounds) + (Some loose_bounds) (addmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) n). Definition rsub_correct := BoundsPipeline_correct - (tight_bounds, tight_bounds) - loose_bounds + (Some tight_bounds, Some tight_bounds) + (Some loose_bounds) (submod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) coef). Definition ropp_correct := BoundsPipeline_correct - tight_bounds - loose_bounds + (Some tight_bounds) + (Some loose_bounds) (oppmod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c) coef). Definition rencode_correct := BoundsPipeline_correct prime_bound - tight_bounds + (Some tight_bounds) (encodemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). Definition rzero_correct := BoundsPipelineConst_correct - tight_bounds + (Some tight_bounds) (zeromod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). Definition rone_correct := BoundsPipelineConst_correct - tight_bounds + (Some tight_bounds) (onemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)). (* we need to strip off [Hrv : ... = Pipeline.Success rv] and related arguments *) @@ -7493,7 +7522,7 @@ Module MontgomeryReduction. (machine_wordsize : Z). Let n : nat := Z.to_nat (Qceiling ((Z.log2_up N) / machine_wordsize)). - Let bound := r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. + Let bound := Some r[0 ~> (2^machine_wordsize - 1)%Z]%zrange. Definition relax_zrange_of_machine_wordsize := relax_zrange_gen [1; machine_wordsize / 2; machine_wordsize; 2 * machine_wordsize; 4 * machine_wordsize]%Z. -- cgit v1.2.3