diff options
Diffstat (limited to 'src/Experiments/NewPipeline/Toplevel2.v')
-rw-r--r-- | src/Experiments/NewPipeline/Toplevel2.v | 3395 |
1 files changed, 3395 insertions, 0 deletions
diff --git a/src/Experiments/NewPipeline/Toplevel2.v b/src/Experiments/NewPipeline/Toplevel2.v new file mode 100644 index 000000000..4cba170bd --- /dev/null +++ b/src/Experiments/NewPipeline/Toplevel2.v @@ -0,0 +1,3395 @@ +Require Import Coq.ZArith.ZArith Coq.micromega.Lia. +Require Import Coq.derive.Derive. +Require Import Coq.Bool.Bool. +Require Import Coq.Strings.String. +Require Import Coq.Lists.List. +Require Crypto.Util.Strings.String. +Require Import Crypto.Util.Strings.Decimal. +Require Import Crypto.Util.Strings.HexString. +Require Import QArith.QArith_base QArith.Qround Crypto.Util.QUtil. +Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop. +Require Import Crypto.Algebra.Ring. +Require Import Crypto.Algebra.SubsetoidRing. +Require Import Crypto.Util.ZRange. +Require Import Crypto.Util.ListUtil.FoldBool. +Require Import Crypto.Util.LetIn. +Require Import Crypto.Arithmetic.PrimeFieldTheorems. +Require Import Crypto.Util.ZUtil.Tactics.LtbToLt. +Require Import Crypto.Util.ZUtil.Tactics.PullPush.Modulo. +Require Import Crypto.Util.Tactics.SplitInContext. +Require Import Crypto.Util.Tactics.SubstEvars. +Require Import Crypto.Util.Tactics.DestructHead. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil Coq.Lists.List. +Require Import Crypto.Util.Equality. +Require Import Crypto.Util.Tactics.GetGoal. +Require Import Crypto.Arithmetic.BarrettReduction.Generalized. +Require Import Crypto.Util.Tactics.UniquePose. +Require Import Crypto.Util.ZUtil.Rshi. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Tactics.BreakMatch. +Require Import Crypto.Util.Tactics.SpecializeBy. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.ZUtil.Zselect. +Require Import Crypto.Util.ZUtil.AddModulo. +Require Import Crypto.Util.ZUtil.CC. +Require Import Crypto.Arithmetic.MontgomeryReduction.Definition. +Require Import Crypto.Arithmetic.MontgomeryReduction.Proofs. +Require Import Crypto.Util.ErrorT. +Require Import Crypto.Util.Strings.Show. +Require Import Crypto.Util.ZRange.Show. +Require Import Crypto.Experiments.NewPipeline.Arithmetic. +Require Crypto.Experiments.NewPipeline.Language. +Require Crypto.Experiments.NewPipeline.UnderLets. +Require Crypto.Experiments.NewPipeline.AbstractInterpretation. +Require Crypto.Experiments.NewPipeline.AbstractInterpretationProofs. +Require Crypto.Experiments.NewPipeline.Rewriter. +Require Crypto.Experiments.NewPipeline.MiscCompilerPasses. +Require Crypto.Experiments.NewPipeline.CStringification. +Require Export Crypto.Experiments.NewPipeline.Toplevel1. +Require Import Crypto.Util.Notations. +Import ListNotations. Local Open Scope Z_scope. + +Import Associational Positional. + +Import + Crypto.Experiments.NewPipeline.Language + Crypto.Experiments.NewPipeline.UnderLets + Crypto.Experiments.NewPipeline.AbstractInterpretation + Crypto.Experiments.NewPipeline.AbstractInterpretationProofs + Crypto.Experiments.NewPipeline.Rewriter + Crypto.Experiments.NewPipeline.MiscCompilerPasses + Crypto.Experiments.NewPipeline.CStringification. + +Import + Language.Compilers + UnderLets.Compilers + AbstractInterpretation.Compilers + AbstractInterpretationProofs.Compilers + Rewriter.Compilers + MiscCompilerPasses.Compilers + CStringification.Compilers. + +Import Compilers.defaults. +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion QArith_base.inject_Z : Z >-> Q. +Notation "x" := (expr.Var x) (only printing, at level 9) : expr_scope. + +Import UnsaturatedSolinas. + +Module X25519_64. + Definition n := 5%nat. + Definition s := 2^255. + Definition c := [(1, 19)]. + Definition machine_wordsize := 64. + Local Notation tight_bounds := (tight_bounds n s c). + Local Notation loose_bounds := (loose_bounds n s c). + Local Notation prime_bound := (prime_bound s c). + + Derive base_51_relax + SuchThat (rrelax_correctT n s c machine_wordsize base_51_relax) + As base_51_relax_correct. + Proof. Time solve_rrelax machine_wordsize. Time Qed. + Derive base_51_carry_mul + SuchThat (rcarry_mul_correctT n s c machine_wordsize base_51_carry_mul) + As base_51_carry_mul_correct. + Proof. Time solve_rcarry_mul machine_wordsize. Time Qed. + Derive base_51_carry + SuchThat (rcarry_correctT n s c machine_wordsize base_51_carry) + As base_51_carry_correct. + Proof. Time solve_rcarry machine_wordsize. Time Qed. + Derive base_51_add + SuchThat (radd_correctT n s c machine_wordsize base_51_add) + As base_51_add_correct. + Proof. Time solve_radd machine_wordsize. Time Qed. + Derive base_51_sub + SuchThat (rsub_correctT n s c machine_wordsize base_51_sub) + As base_51_sub_correct. + Proof. Time solve_rsub machine_wordsize. Time Qed. + Derive base_51_opp + SuchThat (ropp_correctT n s c machine_wordsize base_51_opp) + As base_51_opp_correct. + Proof. Time solve_ropp machine_wordsize. Time Qed. + Derive base_51_encode + SuchThat (rencode_correctT n s c machine_wordsize base_51_encode) + As base_51_encode_correct. + Proof. Time solve_rencode machine_wordsize. Time Qed. + Derive base_51_zero + SuchThat (rzero_correctT n s c machine_wordsize base_51_zero) + As base_51_zero_correct. + Proof. Time solve_rzero machine_wordsize. Time Qed. + Derive base_51_one + SuchThat (rone_correctT n s c machine_wordsize base_51_one) + As base_51_one_correct. + Proof. Time solve_rone machine_wordsize. Time Qed. + Lemma base_51_curve_good + : check_args n s c machine_wordsize (Success tt) = Success tt. + Proof. vm_compute; reflexivity. Qed. + + Definition base_51_good : GoodT n s c + := Good n s c machine_wordsize + base_51_curve_good + base_51_carry_mul_correct + base_51_carry_correct + base_51_relax_correct + base_51_add_correct + base_51_sub_correct + base_51_opp_correct + base_51_zero_correct + base_51_one_correct + base_51_encode_correct. + + Print Assumptions base_51_good. + Import PrintingNotations. + Set Printing Width 80. + Open Scope string_scope. + Print base_51_carry_mul. +(*base_51_carry_mul = +fun var : type -> Type => +(λ x x0 : var (type.base (base.type.list (base.type.type_base base.type.Z))), + expr_let x1 := (uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ + ((uint64)(x[[1]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19) +₁₂₈ + ((uint64)(x[[2]]) *₁₂₈ ((uint64)(x0[[3]]) *₆₄ 19) +₁₂₈ + ((uint64)(x[[3]]) *₁₂₈ ((uint64)(x0[[2]]) *₆₄ 19) +₁₂₈ + (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[1]]) *₆₄ 19)))) in + expr_let x2 := (uint64)(x1 >> 51) +₁₂₈ + ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ + ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ + ((uint64)(x[[2]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19) +₁₂₈ + ((uint64)(x[[3]]) *₁₂₈ ((uint64)(x0[[3]]) *₆₄ 19) +₁₂₈ + (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[2]]) *₆₄ 19))))) in + expr_let x3 := (uint64)(x2 >> 51) +₁₂₈ + ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[2]]) +₁₂₈ + ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ + ((uint64)(x[[2]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ + ((uint64)(x[[3]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19) +₁₂₈ + (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[3]]) *₆₄ 19))))) in + expr_let x4 := (uint64)(x3 >> 51) +₁₂₈ + ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[3]]) +₁₂₈ + ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[2]]) +₁₂₈ + ((uint64)(x[[2]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ + ((uint64)(x[[3]]) *₁₂₈ (uint64)(x0[[0]]) +₁₂₈ + (uint64)(x[[4]]) *₁₂₈ ((uint64)(x0[[4]]) *₆₄ 19))))) in + expr_let x5 := (uint64)(x4 >> 51) +₁₂₈ + ((uint64)(x[[0]]) *₁₂₈ (uint64)(x0[[4]]) +₁₂₈ + ((uint64)(x[[1]]) *₁₂₈ (uint64)(x0[[3]]) +₁₂₈ + ((uint64)(x[[2]]) *₁₂₈ (uint64)(x0[[2]]) +₁₂₈ + ((uint64)(x[[3]]) *₁₂₈ (uint64)(x0[[1]]) +₁₂₈ + (uint64)(x[[4]]) *₁₂₈ (uint64)(x0[[0]]))))) in + expr_let x6 := ((uint64)(x1) & 2251799813685247) +₆₄ (uint64)(x5 >> 51) *₆₄ 19 in + expr_let x7 := (uint64)(x6 >> 51) +₆₄ ((uint64)(x2) & 2251799813685247) in + expr_let x8 := ((uint64)(x6) & 2251799813685247) in + expr_let x9 := ((uint64)(x7) & 2251799813685247) in + expr_let x10 := (uint64)(x7 >> 51) +₆₄ ((uint64)(x3) & 2251799813685247) in + expr_let x11 := ((uint64)(x4) & 2251799813685247) in + expr_let x12 := ((uint64)(x5) & 2251799813685247) in + [x8; x9; x10; x11; x12])%expr + : Expr + (type.base (base.type.list (base.type.type_base base.type.Z)) -> + type.base (base.type.list (base.type.type_base base.type.Z)) -> + type.base (base.type.list (base.type.type_base base.type.Z)))%ptype +*) + Print base_51_sub. + (* +base_51_sub = +fun var : type -> Type => +(λ x x0 : var (type.base (base.type.list (base.type.type_base base.type.Z))), + expr_let x1 := (4503599627370458 +₆₄ (uint64)(x[[0]])) -₆₄ (uint64)(x0[[0]]) in + expr_let x2 := (4503599627370494 +₆₄ (uint64)(x[[1]])) -₆₄ (uint64)(x0[[1]]) in + expr_let x3 := (4503599627370494 +₆₄ (uint64)(x[[2]])) -₆₄ (uint64)(x0[[2]]) in + expr_let x4 := (4503599627370494 +₆₄ (uint64)(x[[3]])) -₆₄ (uint64)(x0[[3]]) in + expr_let x5 := (4503599627370494 +₆₄ (uint64)(x[[4]])) -₆₄ (uint64)(x0[[4]]) in + [x1; x2; x3; x4; x5])%expr + : Expr + (type.base (base.type.list (base.type.type_base base.type.Z)) -> + type.base (base.type.list (base.type.type_base base.type.Z)) -> + type.base (base.type.list (base.type.type_base base.type.Z)))%ptype +*) + + Compute ToString.C.ToFunctionString + "fecarry_mul" base_51_carry_mul + None (Some loose_bounds, (Some loose_bounds, tt)). + (* +void fecarry_mul(uint64_t[5] x1, uint64_t[5] x2, uint64_t[5] x3) { + uint128_t x4 = (((uint128_t)(x1[0]) * (x2[0])) + (((uint128_t)(x1[1]) * ((x2[4]) * 0x13)) + (((uint128_t)(x1[2]) * ((x2[3]) * 0x13)) + (((uint128_t)(x1[3]) * ((x2[2]) * 0x13)) + ((uint128_t)(x1[4]) * ((x2[1]) * 0x13)))))); + uint128_t x5 = ((uint64_t)(x4 >> 51) + (((uint128_t)(x1[0]) * (x2[1])) + (((uint128_t)(x1[1]) * (x2[0])) + (((uint128_t)(x1[2]) * ((x2[4]) * 0x13)) + (((uint128_t)(x1[3]) * ((x2[3]) * 0x13)) + ((uint128_t)(x1[4]) * ((x2[2]) * 0x13))))))); + uint128_t x6 = ((uint64_t)(x5 >> 51) + (((uint128_t)(x1[0]) * (x2[2])) + (((uint128_t)(x1[1]) * (x2[1])) + (((uint128_t)(x1[2]) * (x2[0])) + (((uint128_t)(x1[3]) * ((x2[4]) * 0x13)) + ((uint128_t)(x1[4]) * ((x2[3]) * 0x13))))))); + uint128_t x7 = ((uint64_t)(x6 >> 51) + (((uint128_t)(x1[0]) * (x2[3])) + (((uint128_t)(x1[1]) * (x2[2])) + (((uint128_t)(x1[2]) * (x2[1])) + (((uint128_t)(x1[3]) * (x2[0])) + ((uint128_t)(x1[4]) * ((x2[4]) * 0x13))))))); + uint128_t x8 = ((uint64_t)(x7 >> 51) + (((uint128_t)(x1[0]) * (x2[4])) + (((uint128_t)(x1[1]) * (x2[3])) + (((uint128_t)(x1[2]) * (x2[2])) + (((uint128_t)(x1[3]) * (x2[1])) + ((uint128_t)(x1[4]) * (x2[0]))))))); + uint64_t x9 = ((uint64_t)(x4 & 0x7ffffffffffffUL) + ((uint64_t)(x8 >> 51) * 0x13)); + uint64_t x10 = ((x9 >> 51) + (uint64_t)(x5 & 0x7ffffffffffffUL)); + x3[0] = (x9 & 0x7ffffffffffffUL); + x3[1] = (x10 & 0x7ffffffffffffUL); + x3[2] = ((x10 >> 51) + (uint64_t)(x6 & 0x7ffffffffffffUL)); + x3[3] = (uint64_t)(x7 & 0x7ffffffffffffUL); + x3[4] = (uint64_t)(x8 & 0x7ffffffffffffUL); +} +*) + Compute ToString.C.ToFunctionString + "fesub" base_51_sub + None (Some tight_bounds, (Some tight_bounds, tt)). +(* +void fesub(uint64_t[5] x1, uint64_t[5] x2, uint64_t[5] x3) { + x3[0] = ((0xfffffffffffdaUL + (x1[0])) - (x2[0])); + x3[1] = ((0xffffffffffffeUL + (x1[1])) - (x2[1])); + x3[2] = ((0xffffffffffffeUL + (x1[2])) - (x2[2])); + x3[3] = ((0xffffffffffffeUL + (x1[3])) - (x2[3])); + x3[4] = ((0xffffffffffffeUL + (x1[4])) - (x2[4])); +} +*) +End X25519_64. + +Module P192_64. + Definition s := 2^192. + Definition c := [(2^64, 1); (1,1)]. + Definition machine_wordsize := 64. + + Derive mulmod + SuchThat (SaturatedSolinas.rmulmod_correctT s c machine_wordsize mulmod) + As mulmod_correct. + Proof. Time solve_rmulmod machine_wordsize. Time Qed. + + Import PrintingNotations. + Open Scope expr_scope. + Set Printing Width 100000. + Set Printing Depth 100000. + + Local Notation "'mul64' '(' x ',' y ')'" := + (#(Z_cast2 (uint64, _)%core) @ (#(Z_mul_split_concrete 18446744073709551616) @ x @ y))%expr (at level 50) : expr_scope. + Local Notation "'add64' '(' x ',' y ')'" := + (#(Z_cast2 (uint64, bool)%core) @ (#(Z_add_get_carry_concrete 18446744073709551616) @ x @ y))%expr (at level 50) : expr_scope. + Local Notation "'adc64' '(' c ',' x ',' y ')'" := + (#(Z_cast2 (uint64, bool)%core) @ (#(Z_add_with_get_carry_concrete 18446744073709551616) @ c @ x @ y))%expr (at level 50) : expr_scope. + Local Notation "'adx64' '(' c ',' x ',' y ')'" := + (#(Z_cast bool) @ (#Z_add_with_carry @ c @ x @ y))%expr (at level 50) : expr_scope. + + Print mulmod. +(* +mulmod = fun var : type -> Type => λ x x0 : var (type.base (base.type.list (base.type.type_base base.type.Z))), + expr_let x1 := mul64 ((uint64)(x[[2]]), (uint64)(x0[[2]])) in + expr_let x2 := mul64 ((uint64)(x[[2]]), (uint64)(x0[[1]])) in + expr_let x3 := mul64 ((uint64)(x[[2]]), (uint64)(x0[[0]])) in + expr_let x4 := mul64 ((uint64)(x[[1]]), (uint64)(x0[[2]])) in + expr_let x5 := mul64 ((uint64)(x[[1]]), (uint64)(x0[[1]])) in + expr_let x6 := mul64 ((uint64)(x[[1]]), (uint64)(x0[[0]])) in + expr_let x7 := mul64 ((uint64)(x[[0]]), (uint64)(x0[[2]])) in + expr_let x8 := mul64 ((uint64)(x[[0]]), (uint64)(x0[[1]])) in + expr_let x9 := mul64 ((uint64)(x[[0]]), (uint64)(x0[[0]])) in + expr_let x10 := add64 (x1₂, x9₂) in + expr_let x11 := adc64 (x10₂, 0, x8₂) in + expr_let x12 := add64 (x1₁, x10₁) in + expr_let x13 := adc64 (x12₂, 0, x11₁) in + expr_let x14 := add64 (x2₂, x12₁) in + expr_let x15 := adc64 (x14₂, 0, x13₁) in + expr_let x16 := add64 (x4₂, x14₁) in + expr_let x17 := adc64 (x16₂, x1₂, x15₁) in + expr_let x18 := add64 (x2₁, x16₁) in + expr_let x19 := adc64 (x18₂, x1₁, x17₁) in + expr_let x20 := add64 (x1₂, x9₁) in + expr_let x21 := adc64 (x20₂, x3₂, x18₁) in + expr_let x22 := adc64 (x21₂, x2₂, x19₁) in + expr_let x23 := add64 (x2₁, x20₁) in + expr_let x24 := adc64 (x23₂, x4₁, x21₁) in + expr_let x25 := adc64 (x24₂, x4₂, x22₁) in + expr_let x26 := add64 (x3₂, x23₁) in + expr_let x27 := adc64 (x26₂, x5₂, x24₁) in + expr_let x28 := adc64 (x27₂, x3₁, x25₁) in + expr_let x29 := add64 (x4₁, x26₁) in + expr_let x30 := adc64 (x29₂, x7₂, x27₁) in + expr_let x31 := adc64 (x30₂, x5₁, x28₁) in + expr_let x32 := add64 (x5₂, x29₁) in + expr_let x33 := adc64 (x32₂, x6₁, x30₁) in + expr_let x34 := adc64 (x33₂, x6₂, x31₁) in + expr_let x35 := add64 (x7₂, x32₁) in + expr_let x36 := adc64 (x35₂, x8₁, x33₁) in + expr_let x37 := adc64 (x36₂, x7₁, x34₁) in + [x35₁; x36₁; x37₁] + : Expr (type.base (base.type.list (base.type.type_base base.type.Z)) -> type.base (base.type.list (base.type.type_base base.type.Z)) -> type.base (base.type.list (base.type.type_base base.type.Z)))%ptype +*) + +End P192_64. + +Module PreFancy. + Section with_wordmax. + Context (log2wordmax : Z) (log2wordmax_pos : 1 < log2wordmax) (log2wordmax_even : log2wordmax mod 2 = 0). + Let wordmax := 2 ^ log2wordmax. + Lemma wordmax_gt_2 : 2 < wordmax. + Proof. + apply Z.le_lt_trans with (m:=2 ^ 1); [ reflexivity | ]. + apply Z.pow_lt_mono_r; omega. + Qed. + + Lemma wordmax_even : wordmax mod 2 = 0. + Proof. + replace 2 with (2 ^ 1) by reflexivity. + subst wordmax. apply Z.mod_same_pow; omega. + Qed. + + Let half_bits := log2wordmax / 2. + + Lemma half_bits_nonneg : 0 <= half_bits. + Proof. subst half_bits; Z.zero_bounds. Qed. + + Let wordmax_half_bits := 2 ^ half_bits. + + Lemma wordmax_half_bits_pos : 0 < wordmax_half_bits. + Proof. subst wordmax_half_bits half_bits. Z.zero_bounds. Qed. + + Lemma half_bits_squared : (wordmax_half_bits - 1) * (wordmax_half_bits - 1) <= wordmax - 1. + Proof. + pose proof wordmax_half_bits_pos. + subst wordmax_half_bits. + transitivity (2 ^ (half_bits + half_bits) - 2 * 2 ^ half_bits + 1). + { rewrite Z.pow_add_r by (subst half_bits; Z.zero_bounds). + autorewrite with push_Zmul; omega. } + { transitivity (wordmax - 2 * 2 ^ half_bits + 1); [ | lia]. + subst wordmax. + apply Z.add_le_mono_r. + apply Z.sub_le_mono_r. + apply Z.pow_le_mono_r; [ omega | ]. + rewrite Z.add_diag; subst half_bits. + apply BinInt.Z.mul_div_le; omega. } + Qed. + + Lemma wordmax_half_bits_le_wordmax : wordmax_half_bits <= wordmax. + Proof. + subst wordmax half_bits wordmax_half_bits. + apply Z.pow_le_mono_r; [lia|]. + apply Z.div_le_upper_bound; lia. + Qed. + + Lemma ones_half_bits : wordmax_half_bits - 1 = Z.ones half_bits. + Proof. + subst wordmax_half_bits. cbv [Z.ones]. + rewrite Z.shiftl_mul_pow2, <-Z.sub_1_r by auto using half_bits_nonneg. + lia. + Qed. + + Lemma wordmax_half_bits_squared : wordmax_half_bits * wordmax_half_bits = wordmax. + Proof. + subst wordmax half_bits wordmax_half_bits. + rewrite <-Z.pow_add_r by Z.zero_bounds. + rewrite Z.add_diag, Z.mul_div_eq by omega. + f_equal; lia. + Qed. + +(* + Section interp. + Context {interp_cast : zrange -> Z -> Z}. + Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)). + Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)). + Local Notation low x := (Z.land x (wordmax_half_bits - 1)). + Local Notation high x := (x >> half_bits). + Local Notation shift x imm := ((x << imm) mod wordmax). + + Definition interp_ident {s d} (idc : ident s d) : type.interp s -> type.interp d := + match idc with + | add imm => fun x => Z.add_get_carry_full wordmax (fst x) (shift (snd x) imm) + | addc imm => fun x => Z.add_with_get_carry_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm) + | sub imm => fun x => Z.sub_get_borrow_full wordmax (fst x) (shift (snd x) imm) + | subb imm => fun x => Z.sub_with_get_borrow_full wordmax (fst (fst x)) (snd (fst x)) (shift (snd x) imm) + | mulll => fun x => low (fst x) * low (snd x) + | mullh => fun x => low (fst x) * high (snd x) + | mulhl => fun x => high (fst x) * low (snd x) + | mulhh => fun x => high (fst x) * high (snd x) + | rshi n => fun x => Z.rshi wordmax (fst x) (snd x) n + | selc => fun x => Z.zselect (fst (fst x)) (snd (fst x)) (snd x) + | selm => fun x => Z.zselect (Z.cc_m wordmax (fst (fst x))) (snd (fst x)) (snd x) + | sell => fun x => Z.zselect (Z.land (fst (fst x)) 1) (snd (fst x)) (snd x) + | addm => fun x => Z.add_modulo (fst (fst x)) (snd (fst x)) (snd x) + end. + + Fixpoint interp {t} (e : @expr type.interp ident t) : type.interp t := + match e with + | Scalar t s => interp_scalar s + | LetInAppIdentZ s d r idc x f => + interp (f (interp_cast r (interp_ident idc (interp_scalar x)))) + | LetInAppIdentZZ s d r idc x f => + interp (f (interp_cast2 r (interp_ident idc (interp_scalar x)))) + end. + End interp. + + Section proofs. + Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z) + (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1). + Context {interp_cast : zrange -> Z -> Z} {interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x}. + Local Notation interp_scalar := (interp_scalar (interp_cast:=interp_cast)). + Local Notation interp_cast2 := (interp_cast2 (interp_cast:=interp_cast)). + + Local Notation word_range := (r[0~>wordmax-1])%zrange. + Local Notation half_word_range := (r[0~>wordmax_half_bits-1])%zrange. + Local Notation flag_range := (r[0~>1])%zrange. + + Definition in_word_range (r : zrange) := is_tighter_than_bool r word_range = true. + Definition in_flag_range (r : zrange) := is_tighter_than_bool r flag_range = true. + + Fixpoint get_range_var (t : type) : type.interp t -> range_type t := + match t with + | type.type_primitive type.Z => + fun x => {| lower := x; upper := x |} + | type.prod a b => + fun x => (get_range_var a (fst x), get_range_var b (snd x)) + | _ => fun _ => tt + end. + + Fixpoint get_range {t} (x : @scalar type.interp t) : range_type t := + match x with + | Var t v => get_range_var t v + | TT => tt + | Nil _ => tt + | Pair _ _ x y => (get_range x, get_range y) + | Cast r _ => r + | Cast2 r _ => r + | Fst _ _ p => fst (get_range p) + | Snd _ _ p => snd (get_range p) + | Shiftr n x => ZRange.map (fun y => Z.shiftr y n) (get_range x) + | Shiftl n x => ZRange.map (fun y => Z.shiftl y n) (get_range x) + | Land n x => r[0~>n]%zrange + | CC_m n x => ZRange.map (Z.cc_m n) (get_range x) + | Primitive type.Z x => {| lower := x; upper := x |} + | Primitive p x => tt + end. + + Fixpoint has_range {t} : range_type t -> type.interp t -> Prop := + match t with + | type.type_primitive type.Z => + fun r x => + lower r <= x <= upper r + | type.prod a b => + fun r x => + has_range (fst r) (fst x) /\ has_range (snd r) (snd x) + | _ => fun _ _ => True + end. + + Inductive ok_scalar : forall {t}, @scalar type.interp t -> Prop := + | sc_ok_var : forall t v, ok_scalar (Var t v) + | sc_ok_unit : ok_scalar TT + | sc_ok_nil : forall t, ok_scalar (Nil t) + | sc_ok_pair : forall A B x y, + @ok_scalar A x -> + @ok_scalar B y -> + ok_scalar (Pair x y) + | sc_ok_cast : forall r (x : scalar type.Z), + ok_scalar x -> + is_tighter_than_bool (get_range x) r = true -> + ok_scalar (Cast r x) + | sc_ok_cast2 : forall r (x : scalar (type.prod type.Z type.Z)), + ok_scalar x -> + is_tighter_than_bool (fst (get_range x)) (fst r) = true -> + is_tighter_than_bool (snd (get_range x)) (snd r) = true -> + ok_scalar (Cast2 r x) + | sc_ok_fst : + forall A B p, @ok_scalar (A * B) p -> ok_scalar (Fst p) + | sc_ok_snd : + forall A B p, @ok_scalar (A * B) p -> ok_scalar (Snd p) + | sc_ok_shiftr : + forall n x, 0 <= n -> ok_scalar x -> ok_scalar (Shiftr n x) + | sc_ok_shiftl : + forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Shiftl n x) + | sc_ok_land : + forall n x, 0 <= n -> 0 <= lower (@get_range type.Z x) -> ok_scalar x -> ok_scalar (Land n x) + | sc_ok_cc_m : + forall x, ok_scalar x -> ok_scalar (CC_m wordmax x) + | sc_ok_prim : forall p x, ok_scalar (@Primitive _ p x) + . + + Inductive is_halved : scalar type.Z -> Prop := + | is_halved_lower : + forall x : scalar type.Z, + in_word_range (get_range x) -> + is_halved (Cast half_word_range (Land (wordmax_half_bits - 1) x)) + | is_halved_upper : + forall x : scalar type.Z, + in_word_range (get_range x) -> + is_halved (Cast half_word_range (Shiftr half_bits x)) + | is_halved_constant : + forall y z, + constant_to_scalar consts z = Some y -> + is_halved y -> + is_halved (Primitive (t:=type.Z) z) + . + + Inductive ok_ident : forall s d, scalar s -> range_type d -> ident.ident s d -> Prop := + | ok_add : + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident _ + (type.prod type.Z type.Z) + (Pair x y) + (word_range, flag_range) + (ident.Z.add_get_carry_concrete wordmax) + | ok_addc : + forall (c x y : scalar type.Z) outr, + in_flag_range (get_range c) -> + in_word_range (get_range x) -> + in_word_range (get_range y) -> + lower outr = 0 -> + (0 <= upper (get_range c) + upper (get_range x) + upper (get_range y) <= upper outr \/ outr = word_range) -> + ok_ident _ + (type.prod type.Z type.Z) + (Pair (Pair c x) y) + (outr, flag_range) + (ident.Z.add_with_get_carry_concrete wordmax) + | ok_sub : + forall x y : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident _ + (type.prod type.Z type.Z) + (Pair x y) + (word_range, flag_range) + (ident.Z.sub_get_borrow_concrete wordmax) + | ok_subb : + forall b x y : scalar type.Z, + in_flag_range (get_range b) -> + in_word_range (get_range x) -> + in_word_range (get_range y) -> + ok_ident _ + (type.prod type.Z type.Z) + (Pair (Pair b x) y) + (word_range, flag_range) + (ident.Z.sub_with_get_borrow_concrete wordmax) + | ok_rshi : + forall (x : scalar (type.prod type.Z type.Z)) n outr, + in_word_range (fst (get_range x)) -> + in_word_range (snd (get_range x)) -> + (* note : using [outr] rather than [word_range] allows for cases where the result has been put in a smaller word size. *) + lower outr = 0 -> + 0 <= n -> + ((0 <= (upper (snd (get_range x)) + upper (fst (get_range x)) * wordmax) / 2^n <= upper outr) + \/ outr = word_range) -> + ok_ident (type.prod type.Z type.Z) type.Z x outr (ident.Z.rshi_concrete wordmax n) + | ok_selc : + forall (x : scalar (type.prod type.Z type.Z)) (y z : scalar type.Z), + in_flag_range (snd (get_range x)) -> + in_word_range (get_range y) -> + in_word_range (get_range z) -> + ok_ident _ + type.Z + (Pair (Pair (Cast flag_range (Snd x)) y) z) + word_range + ident.Z.zselect + | ok_selm : + forall x y z : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + in_word_range (get_range z) -> + ok_ident _ + type.Z + (Pair (Pair (Cast flag_range (CC_m wordmax x)) y) z) + word_range + ident.Z.zselect + | ok_sell : + forall x y z : scalar type.Z, + in_word_range (get_range x) -> + in_word_range (get_range y) -> + in_word_range (get_range z) -> + ok_ident _ + type.Z + (Pair (Pair (Cast flag_range (Land 1 x)) y) z) + word_range + ident.Z.zselect + | ok_addm : + forall (x : scalar (type.prod (type.prod type.Z type.Z) type.Z)), + in_word_range (fst (fst (get_range x))) -> + in_word_range (snd (fst (get_range x))) -> + in_word_range (snd (get_range x)) -> + upper (fst (fst (get_range x))) + upper (snd (fst (get_range x))) - lower (snd (get_range x)) < wordmax -> + ok_ident _ + type.Z + x + word_range + ident.Z.add_modulo + | ok_mul : + forall x y : scalar type.Z, + is_halved x -> + is_halved y -> + ok_ident (type.prod type.Z type.Z) + type.Z + (Pair x y) + word_range + ident.Z.mul + . + + Inductive ok_expr : forall {t}, @expr type.interp ident.ident t -> Prop := + | ok_of_scalar : forall t s, ok_scalar s -> @ok_expr t (Scalar s) + | ok_letin_z : forall s d r idc x f, + ok_ident _ type.Z x r idc -> + (r <=? word_range)%zrange = true -> + ok_scalar x -> + (forall y, has_range (t:=type.Z) r y -> ok_expr (f y)) -> + ok_expr (@LetInAppIdentZ _ _ s d r idc x f) + | ok_letin_zz : forall s d r idc x f, + ok_ident _ (type.prod type.Z type.Z) x (r, flag_range) idc -> + (r <=? word_range)%zrange = true -> + ok_scalar x -> + (forall y, has_range (t:=type.Z * type.Z) (r, flag_range) y -> ok_expr (f y)) -> + ok_expr (@LetInAppIdentZZ _ _ s d (r, flag_range) idc x f) + . + + Ltac invert H := + inversion H; subst; + repeat match goal with + | H : existT _ _ _ = existT _ _ _ |- _ => apply (Eqdep_dec.inj_pair2_eq_dec _ type.type_eq_dec) in H; subst + end. + + Lemma has_range_get_range_var {t} (v : type.interp t) : + has_range (get_range_var _ v) v. + Proof. + induction t; cbn [get_range_var has_range fst snd]; auto. + destruct p; auto; cbn [upper lower]; omega. + Qed. + + Lemma has_range_loosen r1 r2 (x : Z) : + @has_range type.Z r1 x -> + is_tighter_than_bool r1 r2 = true -> + @has_range type.Z r2 x. + Proof. + cbv [is_tighter_than_bool has_range]; intros; + match goal with H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H end; + Z.ltb_to_lt; omega. + Qed. + + Lemma interp_cast_noop x r : + @has_range type.Z r x -> + interp_cast r x = x. + Proof. cbv [has_range]; intros; auto. Qed. + + Lemma interp_cast2_noop x r : + @has_range (type.prod type.Z type.Z) r x -> + interp_cast2 r x = x. + Proof. + cbv [has_range interp_cast2]; intros. + rewrite !interp_cast_correct by tauto. + destruct x; reflexivity. + Qed. + + Lemma has_range_shiftr n (x : scalar type.Z) : + 0 <= n -> + has_range (get_range x) (interp_scalar x) -> + @has_range type.Z (ZRange.map (fun y : Z => y >> n) (get_range x)) (interp_scalar x >> n). + Proof. cbv [has_range]; intros; cbn. auto using Z.shiftr_le with omega. Qed. + Hint Resolve has_range_shiftr : has_range. + + Lemma has_range_shiftl n r x : + 0 <= n -> 0 <= lower r -> + @has_range type.Z r x -> + @has_range type.Z (ZRange.map (fun y : Z => y << n) r) (x << n). + Proof. cbv [has_range]; intros; cbn. auto using Z.shiftl_le_mono with omega. Qed. + Hint Resolve has_range_shiftl : has_range. + + Lemma has_range_land n (x : scalar type.Z) : + 0 <= n -> 0 <= lower (get_range x) -> + has_range (get_range x) (interp_scalar x) -> + @has_range type.Z (r[0~>n])%zrange (Z.land (interp_scalar x) n). + Proof. + cbv [has_range]; intros; cbn. + split; [ apply Z.land_nonneg | apply Z.land_upper_bound_r ]; omega. + Qed. + Hint Resolve has_range_land : has_range. + + Lemma has_range_interp_scalar {t} (x : scalar t) : + ok_scalar x -> + has_range (get_range x) (interp_scalar x). + Proof. + induction 1; cbn [interp_scalar get_range]; + auto with has_range; + try solve [try inversion IHok_scalar; cbn [has_range]; + auto using has_range_get_range_var]; [ | | | ]. + { rewrite interp_cast_noop by eauto using has_range_loosen. + eapply has_range_loosen; eauto. } + { inversion IHok_scalar. + rewrite interp_cast2_noop; + cbn [has_range]; split; eapply has_range_loosen; eauto. } + { cbn. cbv [has_range] in *. + pose proof wordmax_gt_2. + rewrite !Z.cc_m_eq by omega. + split; apply Z.div_le_mono; Z.zero_bounds; omega. } + { destruct p; cbn [has_range upper lower]; auto; omega. } + Qed. + Hint Resolve has_range_interp_scalar : has_range. + + Lemma has_word_range_interp_scalar (x : scalar type.Z) : + ok_scalar x -> + in_word_range (get_range x) -> + @has_range type.Z word_range (interp_scalar x). + Proof. eauto using has_range_loosen, has_range_interp_scalar. Qed. + + Lemma in_word_range_nonneg r : in_word_range r -> 0 <= lower r. + Proof. + cbv [in_word_range is_tighter_than_bool]. + rewrite andb_true_iff; intuition. + Qed. + + Lemma in_word_range_upper_nonneg r x : @has_range type.Z r x -> in_word_range r -> 0 <= upper r. + Proof. + cbv [in_word_range is_tighter_than_bool]; cbn. + rewrite andb_true_iff; intuition. + Z.ltb_to_lt. omega. + Qed. + + Lemma has_word_range_shiftl n r x : + 0 <= n -> upper r * 2 ^ n <= wordmax - 1 -> + @has_range type.Z r x -> + in_word_range r -> + @has_range type.Z word_range (x << n). + Proof. + intros. + eapply has_range_loosen; + [ apply has_range_shiftl; eauto using in_word_range_nonneg with has_range; omega | ]. + cbv [is_tighter_than_bool]. cbn. + apply andb_true_iff; split; apply Z.leb_le; + [ apply Z.shiftl_nonneg; solve [auto using in_word_range_nonneg] | ]. + rewrite Z.shiftl_mul_pow2 by omega. + auto. + Qed. + + Lemma has_range_rshi r n x y : + 0 <= n -> + 0 <= x -> + 0 <= y -> + lower r = 0 -> + (0 <= (y + x * wordmax) / 2^n <= upper r \/ r = word_range) -> + @has_range type.Z r (Z.rshi wordmax x y n). + Proof. + pose proof wordmax_gt_2. + intros. cbv [has_range]. + rewrite Z.rshi_correct by omega. + match goal with |- context [?x mod ?m] => + pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + split; [lia|]. + intuition. + { destruct (Z_lt_dec (upper r) wordmax); [ | lia]. + rewrite Z.mod_small by (split; Z.zero_bounds; omega). + omega. } + { subst r. cbn [upper]. omega. } + Qed. + + Lemma in_word_range_spec r : + (0 <= lower r /\ upper r <= wordmax - 1) + <-> in_word_range r. + Proof. + intros; cbv [in_word_range is_tighter_than_bool]. + rewrite andb_true_iff. + intuition; apply Z.leb_le; cbn [upper lower]; try omega. + Qed. + + Ltac destruct_scalar := + match goal with + | x : scalar (type.prod (type.prod _ _) _) |- _ => + match goal with |- context [interp_scalar x] => + destruct (interp_scalar x) as [ [? ?] ?]; + destruct (get_range x) as [ [? ?] ?] + end + | x : scalar (type.prod _ _) |- _ => + match goal with |- context [interp_scalar x] => + destruct (interp_scalar x) as [? ?]; destruct (get_range x) as [? ?] + end + end. + + Ltac extract_ok_scalar' level x := + match goal with + | H : ok_scalar (Pair (Pair (?f (?g x)) _) _) |- _ => + match (eval compute in (4 <=? level)) with + | true => invert H; extract_ok_scalar' 3 x + | _ => fail + end + | H : ok_scalar (Pair (?f (?g x)) _) |- _ => + match (eval compute in (3 <=? level)) with + | true => invert H; extract_ok_scalar' 2 x + | _ => fail + end + | H : ok_scalar (Pair _ (?f (?g x))) |- _ => + match (eval compute in (3 <=? level)) with + | true => invert H; extract_ok_scalar' 2 x + | _ => fail + end + | H : ok_scalar (?f (?g x)) |- _ => + match (eval compute in (2 <=? level)) with + | true => invert H; extract_ok_scalar' 1 x + | _ => fail + end + | H : ok_scalar (Pair (Pair x _) _) |- _ => + match (eval compute in (2 <=? level)) with + | true => invert H; extract_ok_scalar' 1 x + | _ => fail + end + | H : ok_scalar (Pair (Pair _ x) _) |- _ => + match (eval compute in (2 <=? level)) with + | true => invert H; extract_ok_scalar' 1 x + | _ => fail + end + | H : ok_scalar (?g x) |- _ => invert H + | H : ok_scalar (Pair x _) |- _ => invert H + | H : ok_scalar (Pair _ x) |- _ => invert H + end. + + Ltac extract_ok_scalar := + match goal with |- ok_scalar ?x => extract_ok_scalar' 4 x; assumption end. + + Lemma has_half_word_range_shiftr r x : + in_word_range r -> + @has_range type.Z r x -> + @has_range type.Z half_word_range (x >> half_bits). + Proof. + cbv [in_word_range is_tighter_than_bool]. + rewrite andb_true_iff. + cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt. + { apply Z.shiftr_nonneg. omega. } + { pose proof half_bits_nonneg. + pose proof half_bits_squared. + assert (x >> half_bits < wordmax_half_bits); [|omega]. + rewrite Z.shiftr_div_pow2 by auto. + apply Z.div_lt_upper_bound; Z.zero_bounds. + subst wordmax_half_bits half_bits. + rewrite <-Z.pow_add_r by omega. + rewrite Z.add_diag, Z.mul_div_eq, log2wordmax_even by omega. + autorewrite with zsimplify_fast. subst wordmax. omega. } + Qed. + + Lemma has_half_word_range_land r x : + in_word_range r -> + @has_range type.Z r x -> + @has_range type.Z half_word_range (x &' (wordmax_half_bits - 1)). + Proof. + pose proof wordmax_half_bits_pos. + cbv [in_word_range is_tighter_than_bool]. + rewrite andb_true_iff. + cbn [has_range upper lower]; intros; intuition; Z.ltb_to_lt. + { apply Z.land_nonneg; omega. } + { apply Z.land_upper_bound_r; omega. } + Qed. + + Section constant_to_scalar. + Lemma constant_to_scalar_single_correct s x z : + 0 <= x <= wordmax - 1 -> + constant_to_scalar_single x z = Some s -> interp_scalar s = z. + Proof. + cbv [constant_to_scalar_single]. + break_match; try discriminate; intros; Z.ltb_to_lt; subst; + try match goal with H : Some _ = Some _ |- _ => inversion H; subst end; + cbn [interp_scalar]; apply interp_cast_noop. + { apply has_half_word_range_shiftr with (r:=r[x~>x]%zrange); + cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega. + apply andb_true_iff; split; apply Z.leb_le; omega. } + { apply has_half_word_range_land with (r:=r[x~>x]%zrange); + cbv [in_word_range is_tighter_than_bool upper lower has_range]; try omega. + apply andb_true_iff; split; apply Z.leb_le; omega. } + Qed. + + Lemma constant_to_scalar_correct s z : + constant_to_scalar consts z = Some s -> interp_scalar s = z. + Proof. + cbv [constant_to_scalar]. + apply fold_right_invariant; try discriminate. + intros until 2; break_match; eauto using constant_to_scalar_single_correct. + Qed. + + Lemma constant_to_scalar_single_cases x y z : + @constant_to_scalar_single type.interp x z = Some y -> + (y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive (t:=type.Z) x))) + \/ (y = Cast half_word_range (Shiftr half_bits (Primitive (t:=type.Z) x))). + Proof. + cbv [constant_to_scalar_single]. + break_match; try discriminate; intros; Z.ltb_to_lt; subst; + try match goal with H : Some _ = Some _ |- _ => inversion H; subst end; + tauto. + Qed. + + Lemma constant_to_scalar_cases y z : + @constant_to_scalar type.interp consts z = Some y -> + (exists x, + @has_range type.Z word_range x + /\ y = Cast half_word_range (Land (wordmax_half_bits - 1) (Primitive x))) + \/ (exists x, + @has_range type.Z word_range x + /\ y = Cast half_word_range (Shiftr half_bits (Primitive x))). + Proof. + cbv [constant_to_scalar]. + apply fold_right_invariant; try discriminate. + intros until 2; break_match; eauto; intros. + match goal with H : constant_to_scalar_single _ _ = _ |- _ => + destruct (constant_to_scalar_single_cases _ _ _ H); subst end. + { left; eexists; split; eauto. + apply consts_ok; auto. } + { right; eexists; split; eauto. + apply consts_ok; auto. } + Qed. + + Lemma ok_scalar_constant_to_scalar y z : constant_to_scalar consts z = Some y -> ok_scalar y. + Proof. + pose proof wordmax_half_bits_pos. pose proof half_bits_nonneg. + let H := fresh in + intro H; apply constant_to_scalar_cases in H; destruct H as [ [? ?] | [? ?] ]; intuition; subst; + cbn [has_range lower upper] in *; repeat constructor; cbn [lower get_range]; try apply Z.leb_refl; try omega. + assert (in_word_range r[x~>x]) by (apply in_word_range_spec; cbn [lower upper]; omega). + pose proof (has_half_word_range_shiftr r[x~>x] x ltac:(assumption) ltac:(cbv [has_range lower upper]; omega)). + cbn [has_range ZRange.map is_tighter_than_bool lower upper] in *. + apply andb_true_iff; cbn [lower upper]; split; apply Z.leb_le; omega. + Qed. + End constant_to_scalar. + Hint Resolve ok_scalar_constant_to_scalar. + + Lemma is_halved_has_range x : + ok_scalar x -> + is_halved x -> + @has_range type.Z half_word_range (interp_scalar x). + Proof. + intro; pose proof (has_range_interp_scalar x ltac:(assumption)). + induction 1; cbn [interp_scalar] in *; intros; try assumption; [ ]. + rewrite <-(constant_to_scalar_correct y z) by assumption. + eauto using has_range_interp_scalar. + Qed. + + Lemma ident_interp_has_range s d x r idc: + ok_scalar x -> + ok_ident s d x r idc -> + has_range r (ident.interp idc (interp_scalar x)). + Proof. + intro. + pose proof (has_range_interp_scalar x ltac:(assumption)). + pose proof wordmax_gt_2. + induction 1; cbn [ident.interp ident.gen_interp]; intros; try destruct_scalar; + repeat match goal with + | H : _ && _ = true |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt + | H : _ /\ _ |- _ => destruct H + | H : is_halved _ |- _ => apply is_halved_has_range in H; [ | extract_ok_scalar ] + | _ => progress subst + | _ => progress (cbv [in_word_range in_flag_range is_tighter_than_bool] in * ) + | _ => progress (cbn [interp_scalar get_range has_range upper lower fst snd] in * ) + end. + { + autorewrite with to_div_mod. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_between_0_if by omega. + split; break_match; lia. } + { + autorewrite with to_div_mod. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_between_0_if by omega. + match goal with H : _ \/ _ |- _ => destruct H; subst end. + { split; break_match; try lia. + destruct (Z_lt_dec (upper outr) wordmax). + { match goal with |- _ <= ?y mod _ <= ?u => + assert (y <= u) by nia end. + rewrite Z.mod_small by omega. omega. } + { match goal with|- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + omega. } } + { split; break_match; cbn; lia. } } + { + autorewrite with to_div_mod. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_sub_small by omega. + split; break_match; lia. } + { + autorewrite with to_div_mod. + match goal with |- context [?a - ?b - ?c] => replace (a - b - c) with (a - (b + c)) by ring end. + match goal with |- context[?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite Z.div_sub_small by omega. + split; break_match; lia. } + { apply has_range_rshi; try nia; [ ]. + match goal with H : context [upper ?ra + upper ?rb * wordmax] |- context [?a + ?b * wordmax] => + assert ((a + b * wordmax) / 2^n <= (upper ra + upper rb * wordmax) / 2^n) by (apply Z.div_le_mono; Z.zero_bounds; nia) + end. + match goal with H : _ \/ ?P |- _ \/ ?P => destruct H; [left|tauto] end. + split; Z.zero_bounds; nia. } + { rewrite Z.zselect_correct. break_match; omega. } + { cbn [interp_scalar fst snd get_range] in *. + rewrite Z.zselect_correct. break_match; omega. } + { cbn [interp_scalar fst snd get_range] in *. + rewrite Z.zselect_correct. break_match; omega. } + { rewrite Z.add_modulo_correct. + break_match; Z.ltb_to_lt; omega. } + { cbn [interp_scalar has_range fst snd get_range upper lower] in *. + pose proof half_bits_squared. nia. } + Qed. + + Lemma has_flag_range_cc_m r x : + @has_range type.Z r x -> + in_word_range r -> + @has_range type.Z flag_range (Z.cc_m wordmax x). + Proof. + cbv [has_range in_word_range is_tighter_than_bool]. + cbn [upper lower]; rewrite andb_true_iff; intros. + match goal with H : _ /\ _ |- _ => destruct H; Z.ltb_to_lt end. + pose proof wordmax_gt_2. pose proof wordmax_even. + pose proof (Z.cc_m_small wordmax x). omega. + Qed. + + Lemma has_flag_range_cc_m' (x : scalar type.Z) : + ok_scalar x -> + in_word_range (get_range x) -> + @has_range type.Z flag_range (Z.cc_m wordmax (interp_scalar x)). + Proof. eauto using has_flag_range_cc_m with has_range. Qed. + + Lemma has_flag_range_land r x : + @has_range type.Z r x -> + in_word_range r -> + @has_range type.Z flag_range (Z.land x 1). + Proof. + cbv [has_range in_word_range is_tighter_than_bool]. + cbn [upper lower]; rewrite andb_true_iff; intuition; Z.ltb_to_lt. + { apply Z.land_nonneg. left; omega. } + { apply Z.land_upper_bound_r; omega. } + Qed. + + Lemma has_flag_range_land' (x : scalar type.Z) : + ok_scalar x -> + in_word_range (get_range x) -> + @has_range type.Z flag_range (Z.land (interp_scalar x) 1). + Proof. eauto using has_flag_range_land with has_range. Qed. + + Ltac rewrite_cast_noop_in_mul := + repeat match goal with + | _ => rewrite interp_cast_noop with (r:=half_word_range) in * + by (eapply has_range_loosen; auto using has_range_land, has_range_interp_scalar) + | _ => rewrite interp_cast_noop with (r:=half_word_range) in * + by (eapply has_range_loosen; try apply has_range_shiftr; auto using has_range_interp_scalar; + cbn [ZRange.map get_range] in *; auto) + | _ => rewrite interp_cast_noop by assumption + end. + + Lemma is_halved_cases x : + is_halved x -> + ok_scalar x -> + (exists y, + invert_lower consts x = Some y + /\ invert_upper consts x = None + /\ interp_scalar y &' (wordmax_half_bits - 1) = interp_scalar x) + \/ (exists y, + invert_lower consts x = None + /\ invert_upper consts x = Some y + /\ interp_scalar y >> half_bits = interp_scalar x). + Proof. + induction 1; intros; cbn; rewrite ?Z.eqb_refl; cbn. + { left. eexists; repeat split; auto. + rewrite interp_cast_noop; [ reflexivity | ]. + apply has_half_word_range_land with (r:=get_range x); auto. + apply has_range_interp_scalar; extract_ok_scalar. } + { right. eexists; repeat split; auto. + rewrite interp_cast_noop; [ reflexivity | ]. + apply has_half_word_range_shiftr with (r:=get_range x); auto. + apply has_range_interp_scalar; extract_ok_scalar. } + { match goal with H : constant_to_scalar _ _ = Some _ |- _ => + rewrite H; + let P := fresh in + destruct (constant_to_scalar_cases _ _ H) as [ [? [? ?] ] | [? [? ?] ] ]; + subst; cbn; rewrite ?Z.eqb_refl; cbn + end. + { left; eexists; repeat split; auto. + erewrite <-constant_to_scalar_correct by eassumption. + subst. cbn. + rewrite interp_cast_noop; [ reflexivity | ]. + eapply has_half_word_range_land with (r:=word_range); auto. + cbv [in_word_range is_tighter_than_bool]. + rewrite !Z.leb_refl; reflexivity. } + { right; eexists; repeat split; auto. + erewrite <-constant_to_scalar_correct by eassumption. + subst. cbn. + rewrite interp_cast_noop; [ reflexivity | ]. + eapply has_half_word_range_shiftr with (r:=word_range); auto. + cbv [in_word_range is_tighter_than_bool]. + rewrite !Z.leb_refl; reflexivity. } } + Qed. + + Lemma halved_mul_range x y : + ok_scalar (Pair x y) -> + is_halved x -> + is_halved y -> + 0 <= interp_scalar x * interp_scalar y < wordmax. + Proof. + intro Hok; invert Hok. intros. + repeat match goal with H : _ |- _ => apply is_halved_has_range in H; [|assumption] end. + cbv [has_range lower upper] in *. + pose proof half_bits_squared. nia. + Qed. + + Lemma of_straightline_ident_mul_correct r t x y g : + is_halved x -> + is_halved y -> + ok_scalar (Pair x y) -> + (word_range <=? r)%zrange = true -> + @has_range type.Z word_range (ident.interp ident.Z.mul (interp_scalar (Pair x y))) -> + @interp interp_cast _ (of_straightline_ident dummy_arrow consts ident.Z.mul t r (Pair x y) g) = + @interp interp_cast _ (g (ident.interp ident.Z.mul (interp_scalar (Pair x y)))). + Proof. + intros Hx Hy Hok ? ?; invert Hok; cbn [interp_scalar of_straightline_ident]; + destruct (is_halved_cases x Hx ltac:(assumption)) as [ [? [Pxlow [Pxhigh Pxi] ] ] | [? [Pxlow [Pxhigh Pxi] ] ] ]; + rewrite ?Pxlow, ?Pxhigh; + destruct (is_halved_cases y Hy ltac:(assumption)) as [ [? [Pylow [Pyhigh Pyi] ] ] | [? [Pylow [Pyhigh Pyi] ] ] ]; + rewrite ?Pylow, ?Pyhigh; + cbn; rewrite Pxi, Pyi; assert (0 <= interp_scalar x * interp_scalar y < wordmax) by (auto using halved_mul_range); + rewrite interp_cast_noop by (cbv [is_tighter_than_bool] in *; cbn [has_range upper lower] in *; rewrite andb_true_iff in *; intuition; Z.ltb_to_lt; lia); reflexivity. + Qed. + + Lemma has_word_range_mod_small x: + @has_range type.Z word_range x -> + x mod wordmax = x. + Proof. + cbv [has_range upper lower]. + intros. apply Z.mod_small; omega. + Qed. + + Lemma half_word_range_le_word_range r : + upper r = wordmax_half_bits - 1 -> + lower r = 0 -> + (r <=? word_range)%zrange = true. + Proof. + pose proof wordmax_half_bits_le_wordmax. + destruct r; cbv [is_tighter_than_bool ZRange.lower ZRange.upper]. + intros; subst. + apply andb_true_iff; split; Z.ltb_to_lt; lia. + Qed. + + Lemma and_shiftl_half_bits_eq x : + (x &' (wordmax_half_bits - 1)) << half_bits = x << half_bits mod wordmax. + Proof. + rewrite ones_half_bits. + rewrite Z.land_ones, !Z.shiftl_mul_pow2 by auto using half_bits_nonneg. + rewrite <-wordmax_half_bits_squared. + subst wordmax_half_bits. + rewrite Z.mul_mod_distr_r_full. + reflexivity. + Qed. + + Lemma in_word_range_word_range : in_word_range word_range. + Proof. + cbv [in_word_range is_tighter_than_bool]. + rewrite !Z.leb_refl; reflexivity. + Qed. + + Lemma invert_shift_correct (s : scalar type.Z) x imm : + ok_scalar s -> + invert_shift consts s = Some (x, imm) -> + interp_scalar s = (interp_scalar x << imm) mod wordmax. + Proof. + intros Hok ?; invert Hok; + try match goal with H : ok_scalar ?x, H' : context[Cast _ ?x] |- _ => + invert H end; + try match goal with H : ok_scalar ?x, H' : context[Shiftl _ ?x] |- _ => + invert H end; + try match goal with H : ok_scalar ?x, H' : context[Shiftl _ (Cast _ ?x)] |- _ => + invert H end; + try (cbn [invert_shift invert_upper invert_upper'] in *; discriminate); + repeat match goal with + | _ => progress (cbn [invert_shift invert_lower invert_lower' invert_upper invert_upper' interp_scalar fst snd] in * ) + | _ => rewrite interp_cast_noop by eauto using has_half_word_range_land, has_half_word_range_shiftr, in_word_range_word_range, has_range_loosen + | H : ok_scalar (Shiftr _ _) |- _ => apply has_range_interp_scalar in H + | H : ok_scalar (Shiftl _ _) |- _ => apply has_range_interp_scalar in H + | H : ok_scalar (Land _ _) |- _ => apply has_range_interp_scalar in H + | H : context [if ?x then _ else _] |- _ => + let Heq := fresh in case_eq x; intro Heq; rewrite Heq in H + | H : context [match @constant_to_scalar ?v ?consts ?x with _ => _ end] |- _ => + let Heq := fresh in + case_eq (@constant_to_scalar v consts x); intros until 0; intro Heq; rewrite Heq in *; [|discriminate]; + destruct (constant_to_scalar_cases _ _ Heq) as [ [? [? ?] ] | [? [? ?] ] ]; subst; + pose proof (ok_scalar_constant_to_scalar _ _ Heq) + | H : constant_to_scalar _ _ = Some _ |- _ => erewrite <-(constant_to_scalar_correct _ _ H) + | H : _ |- _ => rewrite andb_true_iff in H; destruct H; Z.ltb_to_lt + | H : Some _ = Some _ |- _ => progress (invert H) + | _ => rewrite has_word_range_mod_small by eauto using has_range_loosen, half_word_range_le_word_range + | _ => rewrite has_word_range_mod_small by + (eapply has_range_loosen with (r1:=half_word_range); + [ eapply has_half_word_range_shiftr with (r:=word_range) | ]; + eauto using in_word_range_word_range, half_word_range_le_word_range) + | _ => rewrite and_shiftl_half_bits_eq + | _ => progress subst + | _ => reflexivity + | _ => discriminate + end. + Qed. + + Local Ltac solve_commutative_replace := + match goal with + | |- @eq (_ * _) ?x ?y => + replace x with (fst x, snd x) by (destruct x; reflexivity); + replace y with (fst y, snd y) by (destruct y; reflexivity) + end; autorewrite with to_div_mod; solve [repeat (f_equal; try ring)]. + + Fixpoint is_tighter_than_bool_range_type t : range_type t -> range_type t -> bool := + match t with + | type.type_primitive type.Z => (fun r1 r2 => (r1 <=? r2)%zrange) + | type.prod a b => fun r1 r2 => + (is_tighter_than_bool_range_type a (fst r1) (fst r2)) + && (is_tighter_than_bool_range_type b (snd r1) (snd r2)) + | _ => fun _ _ => true + end. + + Definition range_ok {t} : range_type t -> Prop := + match t with + | type.type_primitive type.Z => fun r => in_word_range r + | type.prod type.Z type.Z => fun r => in_word_range (fst r) /\ snd r = flag_range + | _ => fun _ => False + end. + + Lemma of_straightline_ident_correct s d t x r r' (idc : ident.ident s d) g : + ok_ident s d x r idc -> + range_ok r' -> + is_tighter_than_bool_range_type d r r' = true -> + ok_scalar x -> + @interp interp_cast _ (of_straightline_ident dummy_arrow consts idc t r' x g) = + @interp interp_cast _ (g (ident.interp idc (interp_scalar x))). + Proof. + intros. + pose proof wordmax_half_bits_pos. + pose proof (ident_interp_has_range _ _ x r idc ltac:(assumption) ltac:(assumption)). + match goal with H : ok_ident _ _ _ _ _ |- _ => induction H end; + try solve [auto using of_straightline_ident_mul_correct]; + cbv [is_tighter_than_bool_range_type is_tighter_than_bool range_ok] in *; + cbn [of_straightline_ident ident.interp ident.gen_interp + invert_selm invert_sell] in *; + intros; rewrite ?Z.eqb_refl; cbn [andb]; + try match goal with |- context [invert_shift] => break_match end; + cbn [interp interp_ident]; try destruct_scalar; + repeat match goal with + | _ => progress (cbn [fst snd interp_scalar] in * ) + | _ => progress break_match; [ ] + | _ => progress autorewrite with zsimplify_fast + | _ => progress Z.ltb_to_lt + | H : _ /\ _ |- _ => destruct H + | _ => rewrite andb_true_iff in * + | _ => rewrite interp_cast_noop with (r:=flag_range) in * + by (apply has_flag_range_cc_m'; auto; extract_ok_scalar) + | _ => rewrite interp_cast_noop with (r:=flag_range) in * + by (apply has_flag_range_land'; auto; extract_ok_scalar) + | H : _ = (_,_) |- _ => progress (inversion H; subst) + | H : invert_shift _ _ = Some _ |- _ => + apply invert_shift_correct in H; [|extract_ok_scalar]; + rewrite <-H + | H : has_range ?r (?f ?x ?y) |- context [?f ?y ?x] => + replace (f y x) with (f x y) by solve_commutative_replace + | _ => rewrite has_word_range_mod_small + by (eapply has_range_loosen; + [apply has_range_interp_scalar; extract_ok_scalar|]; + assumption) + | _ => rewrite interp_cast_noop by (cbn [has_range fst snd] in *; split; lia) + | _ => rewrite interp_cast2_noop by (cbn [has_range fst snd] in *; split; lia) + | _ => reflexivity + end. + Qed. + + Lemma of_straightline_correct {t} (e : expr t) : + ok_expr e -> + @interp interp_cast _ (of_straightline dummy_arrow consts e) + = Straightline.expr.interp (interp_ident:=@ident.interp) (interp_cast:=interp_cast) e. + Proof. + induction 1; cbn [of_straightline]; intros; + repeat match goal with + | _ => progress cbn [Straightline.expr.interp] + | _ => erewrite of_straightline_ident_correct + by (cbv [range_ok is_tighter_than_bool_range_type]; + eauto using in_word_range_word_range; + try apply andb_true_iff; auto) + | _ => rewrite interp_cast_noop by eauto using has_range_loosen, ident_interp_has_range + | _ => rewrite interp_cast2_noop by eauto using has_range_loosen, ident_interp_has_range + | H : forall y, has_range _ y -> interp _ = _ |- _ => rewrite H by eauto using has_range_loosen, ident_interp_has_range + | _ => reflexivity + end. + Qed. + End proofs. + + Section no_interp_cast. + Context (dummy_arrow : forall s d, type.interp (s -> d)%ctype) (consts : list Z) + (consts_ok : forall x, In x consts -> 0 <= x <= wordmax - 1). + + Local Arguments interp _ {_} _. + Local Arguments interp_scalar _ {_} _. + + Local Ltac tighter_than_to_le := + repeat match goal with + | _ => progress (cbv [is_tighter_than_bool] in * ) + | _ => rewrite andb_true_iff in * + | H : _ /\ _ |- _ => destruct H + end; Z.ltb_to_lt. + + Lemma replace_interp_cast_scalar {t} (x : scalar t) interp_cast interp_cast' + (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x) + (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) : + ok_scalar x -> + interp_scalar interp_cast x = interp_scalar interp_cast' x. + Proof. + induction 1; cbn [interp_scalar Straightline.expr.interp_scalar]; + repeat match goal with + | _ => progress (cbv [has_range interp_cast2] in * ) + | _ => progress tighter_than_to_le + | H : ok_scalar _ |- _ => apply (has_range_interp_scalar (interp_cast_correct:=interp_cast_correct)) in H + | _ => rewrite <-IHok_scalar + | _ => rewrite interp_cast_correct by omega + | _ => rewrite interp_cast'_correct by omega + | _ => congruence + end. + Qed. + + Lemma replace_interp_cast {t} (e : expr t) interp_cast interp_cast' + (interp_cast_correct : forall r x, lower r <= x <= upper r -> interp_cast r x = x) + (interp_cast'_correct : forall r x, lower r <= x <= upper r -> interp_cast' r x = x) : + ok_expr consts e -> + interp interp_cast (of_straightline dummy_arrow consts e) = + interp interp_cast' (of_straightline dummy_arrow consts e). + Proof. + induction 1; intros; cbn [of_straightline interp]. + { apply replace_interp_cast_scalar; auto. } + { erewrite !of_straightline_ident_correct by (eauto; cbv [range_ok]; apply in_word_range_word_range). + rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. + eauto using ident_interp_has_range. } + { erewrite !of_straightline_ident_correct by + (eauto; try solve [cbv [range_ok]; split; auto using in_word_range_word_range]; + cbv [is_tighter_than_bool_range_type]; apply andb_true_iff; split; auto). + rewrite replace_interp_cast_scalar with (interp_cast'0:=interp_cast') by auto. + eauto using ident_interp_has_range. } + Qed. + End no_interp_cast. +*) + End with_wordmax. +(* + Definition of_Expr {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) + (var : type -> Type) (x : var s) dummy_arrow : @Straightline.expr.expr var ident d := + @of_straightline log2wordmax var dummy_arrow consts _ (Straightline.of_Expr e var x dummy_arrow). +*) + Definition interp_cast_mod w r x := if (lower r =? 0) + then if (upper r =? 2^w - 1) + then x mod (2^w) + else if (upper r =? 1) + then x mod 2 + else x + else x. + + Lemma interp_cast_mod_correct w r x : + lower r <= x <= upper r -> + interp_cast_mod w r x = x. + Proof. + cbv [interp_cast_mod]. + intros; break_match; rewrite ?andb_true_iff in *; intuition; Z.ltb_to_lt; + apply Z.mod_small; omega. + Qed. +(* + Lemma of_Expr_correct {s d} (log2wordmax : Z) (consts : list Z) (e : Expr (s -> d)) + (e' : (type.interp s -> Uncurried.expr.expr d)) + (x : type.interp s) dummy_arrow : + e type.interp = Abs e' -> + 1 < log2wordmax -> + log2wordmax mod 2 = 0 -> + Straightline.expr.ok_expr (e' x) -> + (forall x0 : Z, In x0 consts -> 0 <= x0 <= 2 ^ log2wordmax - 1) -> + ok_expr log2wordmax consts + (of_uncurried (dummy_arrow:=dummy_arrow) (depth (fun _ : type => unit) (fun _ : type => tt) (e _)) (e' x)) -> + (depth type.interp (@DefaultValue.type.default) (e' x) <= depth (fun _ : type => unit) (fun _ : type => tt) (e _))%nat -> + @interp log2wordmax (interp_cast_mod log2wordmax) _ (of_Expr log2wordmax consts e type.interp x dummy_arrow) = @Uncurried.expr.interp _ (@ident.interp) _ (e type.interp) x. + Proof. + intro He'; intros; cbv [of_Expr Straightline.of_Expr]. + rewrite He'; cbn [invert_Abs expr.interp]. + assert (forall r z, lower r <= z <= upper r -> ident.cast ident.cast_outside_of_range r z = z) as interp_cast_correct. + { cbv [ident.cast]; intros; break_match; rewrite ?andb_true_iff, ?andb_false_iff in *; intuition; Z.ltb_to_lt; omega. } + erewrite replace_interp_cast with (interp_cast':=ident.cast ident.cast_outside_of_range) by auto using interp_cast_mod_correct. + rewrite of_straightline_correct by auto. + erewrite Straightline.expr.of_uncurried_correct by eassumption. + reflexivity. + Qed. +*) + Notation LetInAppIdentZ S D r eidc x f + := (expr.LetIn + (A:=type.base (base.type.type_base base.type.Z)) + (B:=type.base D) + (expr.App + (s:=type.base (base.type.type_base base.type.Z)) + (d:=type.base (base.type.type_base base.type.Z)) + (expr.Ident (ident.Z_cast r)) + (expr.App + (s:=type.base S) + (d:=type.base (base.type.type_base base.type.Z)) + eidc + x)) + f). + Notation LetInAppIdentZZ S D r eidc x f + := (expr.LetIn + (A:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + (B:=type.base D) + (expr.App + (s:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + (d:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + (expr.Ident (ident.Z_cast2 r)) + (expr.App + (s:=type.base S) + (d:=type.base (base.type.prod (base.type.type_base base.type.Z) (base.type.type_base base.type.Z))) + eidc + x)) + f). + Module Notations. + Import PrintingNotations. + (*Import Straightline.expr.*) + + Local Open Scope expr_scope. + Local Notation "'tZ'" := (base.type.type_base base.type.Z). + Notation "'RegZero'" := (expr.Ident (ident.Literal 0)). + Notation "$ x" := (#(ident.Z_cast uint256) @ (#ident.fst @ (#(ident.Z_cast2 (uint256,bool)%core) @ (expr.Var x)))) (at level 10, format "$ x"). + Notation "$ x" := (#(ident.Z_cast uint128) @ (#ident.fst @ (#(ident.Z_cast2 (uint128,bool)%core) @ (expr.Var x)))) (at level 10, format "$ x"). + Notation "$ x ₁" := (#(ident.Z_cast uint256) @ (#ident.fst @ (expr.Var x))) (at level 10, format "$ x ₁"). + Notation "$ x ₂" := (#(ident.Z_cast uint256) @ (#ident.snd @ (expr.Var x))) (at level 10, format "$ x ₂"). + Notation "$ x" := (#(ident.Z_cast uint256) @ (expr.Var x)) (at level 10, format "$ x"). + Notation "$ x" := (#(ident.Z_cast uint128) @ (expr.Var x)) (at level 10, format "$ x"). + Notation "$ x" := (#(ident.Z_cast bool) @ (expr.Var x)) (at level 10, format "$ x"). + Notation "carry{ $ x }" := (#(ident.Z_cast bool) @ (#ident.snd @ (#(ident.Z_cast2 (uint256, bool)%core) @ (expr.Var x)))) + (at level 10, format "carry{ $ x }"). + Notation "Lower{ x }" := (#(ident.Z_cast uint128) @ (#(ident.Z_land 340282366920938463463374607431768211455) @ x)) + (at level 10, format "Lower{ x }"). + Notation "f @( y , x1 , x2 ); g " + := (LetInAppIdentZZ _ _ (uint256, bool)%core f (x1, x2) (fun y => g)) + (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g "). + Notation "f @( y , x1 , x2 , x3 ); g " + := (LetInAppIdentZZ _ _ (uint256, bool)%core f (#ident.pair @ (#ident.pair @ x1 @ x2) @ x3) (fun y => g)) + (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). + Notation "f @( y , x1 , x2 , x3 ); '#128' g " + := (LetInAppIdentZZ _ _ (uint128, bool)%core f (#ident.pair @ (#ident.pair @ x1 @ x2) @ x3) (fun y => g)) + (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '#128' '//' g "). + Notation "f @( y , x1 , x2 ); g " + := (LetInAppIdentZ _ _ uint256 f (#ident.pair @ x1 @ x2) (fun y => g)) + (at level 10, g at level 200, format "f @( y , x1 , x2 ); '//' g "). + Notation "f @( y , x1 , x2 , x3 ); g " + := (LetInAppIdentZ _ _ uint256 f (#ident.pair @ (#ident.pair @ x1 x2) x3) (fun y => g)) + (at level 10, g at level 200, format "f @( y , x1 , x2 , x3 ); '//' g "). + (* special cases for when the ident constructor takes a constant argument *) + Notation "add@( y , x1 , x2 , n ); g" + := (LetInAppIdentZZ _ _ (uint256, bool) (#(ident.fancy_add 256 n)) (#ident.pair @ x1 x2) (fun y => g)) + (at level 10, g at level 200, format "add@( y , x1 , x2 , n ); '//' g"). + Notation "addc@( y , x1 , x2 , x3 , n ); g" + := (LetInAppIdentZZ _ _ (uint256, bool) (#(ident.fancy_addc 256 n)) (#ident.pair @ (#ident.pair @ x1 x2) x3) (fun y => g)) + (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '//' g"). + Notation "addc@( y , x1 , x2 , x3 , n ); '#128' g" + := (LetInAppIdentZZ _ _ (uint128, bool) (#(ident.fancy_addc 256 n)) (#ident.pair @ (#ident.pair @ x1 x2) x3) (fun y => g)) + (at level 10, g at level 200, format "addc@( y , x1 , x2 , x3 , n ); '#128' '//' g"). + Notation "sub@( y , x1 , x2 , n ); g" + := (LetInAppIdentZZ _ _ (uint256, bool) (#(ident.fancy_sub 256 n)) (#ident.pair @ x1 x2) (fun y => g)) + (at level 10, g at level 200, format "sub@( y , x1 , x2 , n ); '//' g"). + Notation "subb@( y , x1 , x2 , x3 , n ); g" + := (LetInAppIdentZZ _ _ (uint256, bool) (#(ident.fancy_subb 256 n)) (#ident.pair @ (#ident.pair @ x1 x2) x3) (fun y => g)) + (at level 10, g at level 200, format "subb@( y , x1 , x2 , x3 , n ); '//' g"). + Notation "rshi@( y , x1 , x2 , n ); g" + := (LetInAppIdentZ _ _ _ (#(ident.fancy_rshi 256 n)) (#ident.pair @ x1 x2) (fun y => g)) + (at level 10, g at level 200, format "rshi@( y , x1 , x2 , n ); '//' g "). + (*Notation "'ret' $ x" := (Scalar (expr.Var x)) (at level 10, format "'ret' $ x").*) + Notation "( x , y )" := (#ident.pair @ x @ y) (at level 10, left associativity). + End Notations. +(* + Module Tactics. + Ltac ok_expr_step' := + match goal with + | _ => assumption + | |- _ <= _ <= _ \/ @eq zrange _ _ => + right; lazy; try split; congruence + | |- _ <= _ <= _ \/ @eq zrange _ _ => + left; lazy; try split; congruence + | |- context [PreFancy.ok_ident] => constructor + | |- context [PreFancy.ok_scalar] => constructor; try omega + | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] + | |- context [PreFancy.is_halved] => constructor + | |- context [PreFancy.in_word_range] => lazy; reflexivity + | |- context [PreFancy.in_flag_range] => lazy; reflexivity + | |- context [PreFancy.get_range] => + cbn [PreFancy.get_range lower upper fst snd ZRange.map] + | x : type.interp (type.prod _ _) |- _ => destruct x + | |- (_ <=? _)%zrange = true => + match goal with + | |- context [PreFancy.get_range_var] => + cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower] in *; cbn; + apply andb_true_iff; split; apply Z.leb_le + | _ => lazy + end; omega || reflexivity + | |- @eq zrange _ _ => lazy; reflexivity + | |- _ <= _ => omega + | |- _ <= _ <= _ => omega + end; intros. + + Ltac ok_expr_step := + match goal with + | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' + end; intros; cbn [Nat.max]. + End Tactics. + *) + Notation interp w := (@expr.interp base.type ident.ident base.interp (@ident.gen_interp (PreFancy.interp_cast_mod w))). + Notation Interp w := (@expr.Interp base.type ident.ident base.interp (@ident.gen_interp (PreFancy.interp_cast_mod w))). +End PreFancy. + +Module Fancy. + (*Import Straightline.expr.*) + + Module CC. + Inductive code : Type := + | C : code + | M : code + | L : code + | Z : code + . + + Record state := + { cc_c : bool; cc_m : bool; cc_l : bool; cc_z : bool }. + + Definition code_dec (x y : code) : {x = y} + {x <> y}. + Proof. destruct x, y; try apply (left eq_refl); right; congruence. Defined. + + Definition update (to_write : list code) (result : BinInt.Z) (cc_spec : code -> BinInt.Z -> bool) (old_state : state) + : state := + {| + cc_c := if (In_dec code_dec C to_write) + then cc_spec C result + else old_state.(cc_c); + cc_m := if (In_dec code_dec M to_write) + then cc_spec M result + else old_state.(cc_m); + cc_l := if (In_dec code_dec L to_write) + then cc_spec L result + else old_state.(cc_l); + cc_z := if (In_dec code_dec Z to_write) + then cc_spec Z result + else old_state.(cc_z) + |}. + + End CC. + + Record instruction := + { + num_source_regs : nat; + writes_conditions : list CC.code; + spec : tuple Z num_source_regs -> CC.state -> Z + }. + + Section expr. + Context {name : Type} (name_eqb : name -> name -> bool) (wordmax : Z) (cc_spec : CC.code -> Z -> bool). + + Inductive expr := + | Ret : name -> expr + | Instr (i : instruction) + (rd : name) (* destination register *) + (args : tuple name i.(num_source_regs)) (* source registers *) + (cont : expr) (* next line *) + : expr + . + + Fixpoint interp (e : expr) (cc : CC.state) (ctx : name -> Z) : Z := + match e with + | Ret n => ctx n + | Instr i rd args cont => + let result := i.(spec) (Tuple.map ctx args) cc in + let new_cc := CC.update i.(writes_conditions) result cc_spec cc in + let new_ctx := (fun n : name => if name_eqb n rd then result mod wordmax else ctx n) in + interp cont new_cc new_ctx + end. + End expr. + + Section ISA. + Import CC. + + (* For the C flag, we have to consider cases with a negative result (like the one returned by an underflowing borrow). + In these cases, we want to set the C flag to true. *) + Definition cc_spec (x : CC.code) (result : BinInt.Z) : bool := + match x with + | CC.C => if result <? 0 then true else Z.testbit result 256 + | CC.M => Z.testbit result 255 + | CC.L => Z.testbit result 0 + | CC.Z => result =? 0 + end. + + Local Definition lower128 x := (Z.land x (Z.ones 128)). + Local Definition upper128 x := (Z.shiftr x 128). + Local Notation "x '[C]'" := (if x.(cc_c) then 1 else 0) (at level 20). + Local Notation "x '[M]'" := (if x.(cc_m) then 1 else 0) (at level 20). + Local Notation "x '[L]'" := (if x.(cc_l) then 1 else 0) (at level 20). + Local Notation "x '[Z]'" := (if x.(cc_z) then 1 else 0) (at level 20). + Local Notation "'int'" := (BinInt.Z). + Local Notation "x << y" := ((x << y) mod (2^256)) : Z_scope. (* truncating left shift *) + + + (* Note: In the specification document, argument order gets a bit + confusing. Like here, r0 is always the first argument "source 0" + and r1 the second. But the specification of MUL128LU is: + (R[RS1][127:0] * R[RS0][255:128]) + + while the specification of SUB is: + (R[RS0] - shift(R[RS1], imm)) + + In the SUB case, r0 is really treated the first argument, but in + MUL128LU the order seems to be reversed; rather than low-high, we + take the high part of the first argument r0 and the low parts of + r1. This is also true for MUL128UL. *) + + Definition ADD (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 + (r1 << imm)) + |}. + + Definition ADDC (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 + (r1 << imm) + cc[C]) + |}. + + Definition SUB (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 - (r1 << imm)) + |}. + + Definition SUBC (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [C; M; L; Z]; + spec := (fun '(r0, r1) cc => + r0 - (r1 << imm) - cc[C]) + |}. + + + Definition MUL128LL : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (lower128 r0) * (lower128 r1)) + |}. + + Definition MUL128LU : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (lower128 r1) * (upper128 r0)) (* see note *) + |}. + + Definition MUL128UL : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (upper128 r1) * (lower128 r0)) (* see note *) + |}. + + Definition MUL128UU : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (upper128 r0) * (upper128 r1)) + |}. + + (* Note : Unlike the other operations, the output of RSHI is + truncated in the specification. This is not strictly necessary, + since the interpretation function truncates the output + anyway. However, it is useful to make the definition line up + exactly with Z.rshi. *) + Definition RSHI (imm : int) : instruction := + {| + num_source_regs := 2; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1) cc => + (((2^256 * r0) + r1) >> imm) mod (2^256)) + |}. + + Definition SELC : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[C] =? 1 then r0 else r1) + |}. + + Definition SELM : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[M] =? 1 then r0 else r1) + |}. + + Definition SELL : instruction := + {| + num_source_regs := 2; + writes_conditions := []; + spec := (fun '(r0, r1) cc => + if cc[L] =? 1 then r0 else r1) + |}. + + (* TODO : treat the MOD register specially, like CC *) + Definition ADDM : instruction := + {| + num_source_regs := 3; + writes_conditions := [M; L; Z]; + spec := (fun '(r0, r1, MOD) cc => + let ra := r0 + r1 in + if ra >=? MOD + then ra - MOD + else ra) + |}. + + End ISA. + + Module Registers. + Inductive register : Type := + | r0 : register + | r1 : register + | r2 : register + | r3 : register + | r4 : register + | r5 : register + | r6 : register + | r7 : register + | r8 : register + | r9 : register + | r10 : register + | r11 : register + | r12 : register + | r13 : register + | r14 : register + | r15 : register + | r16 : register + | r17 : register + | r18 : register + | r19 : register + | r20 : register + | r21 : register + | r22 : register + | r23 : register + | r24 : register + | r25 : register + | r26 : register + | r27 : register + | r28 : register + | r29 : register + | r30 : register + | RegZero : register (* r31 *) + | RegMod : register + . + + Definition reg_dec (x y : register) : {x = y} + {x <> y}. + Proof. destruct x, y; try (apply left; congruence); right; congruence. Defined. + Definition reg_eqb x y := if reg_dec x y then true else false. + + Lemma reg_eqb_neq x y : x <> y -> reg_eqb x y = false. + Proof. cbv [reg_eqb]; break_match; congruence. Qed. + Lemma reg_eqb_refl x : reg_eqb x x = true. + Proof. cbv [reg_eqb]; break_match; congruence. Qed. + End Registers. + + Section of_prefancy. + Local Notation cexpr := (@Compilers.expr.expr base.type ident.ident). + Context (name : Type) (name_succ : name -> name) (error : name) (consts : Z -> option name). + + Fixpoint base_var (t : base.type) : Type := + match t with + | base.type.Z => name + | base.type.prod a b => base_var a * base_var b + | _ => unit + end. + Fixpoint var (t : type.type base.type) : Type := + match t with + | type.base t => base_var t + | type.arrow s d => var s -> var d + end. + Fixpoint base_error {t} : base_var t + := match t with + | base.type.Z => error + | base.type.prod A B => (@base_error A, @base_error B) + | _ => tt + end. + Fixpoint make_error {t} : var t + := match t with + | type.base _ => base_error + | type.arrow s d => fun _ => @make_error d + end. + + Fixpoint of_prefancy_scalar {t} (s : @cexpr var t) : var t + := match s in expr.expr t return var t with + | Compilers.expr.Var t v => v + | expr.App s d f x => @of_prefancy_scalar _ f (@of_prefancy_scalar _ x) + | expr.Ident t idc + => match idc in ident.ident t return var t with + | ident.Literal base.type.Z v => match consts v with + | Some n => n + | None => error + end + | ident.pair A B => fun a b => (a, b)%core + | ident.fst A B => fun v => fst v + | ident.snd A B => fun v => snd v + | ident.Z_cast _ => fun v => v + | ident.Z_cast2 _ => fun v => v + | _ => make_error + end + | expr.Abs s d f => make_error + | expr.LetIn A B x f => make_error + end%expr_pat%etype. + + (* Note : some argument orders are reversed for MUL128LU, MUL128UL, SELC, SELM, and SELL *) + Local Notation tZ := base.type.Z. + Definition of_prefancy_ident {s d : base.type} (idc : ident.ident (s -> d)) + : @cexpr var s -> option {i : instruction & tuple name i.(num_source_regs) } := + match idc in ident.ident t return match t return Type with + | type.arrow (type.base s) (type.base d) + => @cexpr var s + | _ => unit + end + -> option {i : instruction & tuple name i.(num_source_regs) } + with + | ident.fancy_add log2wordmax imm + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (ADD imm) (of_prefancy_scalar args)) + else None + | ident.fancy_addc log2wordmax imm + => fun args : @cexpr var (tZ * tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (ADDC imm) (of_prefancy_scalar ((#ident.snd @ (#ident.fst @ args)), (#ident.snd @ args)))) + else None + | ident.fancy_sub log2wordmax imm + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (SUB imm) (of_prefancy_scalar args)) + else None + | ident.fancy_subb log2wordmax imm + => fun args : @cexpr var (tZ * tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (SUBC imm) (of_prefancy_scalar ((#ident.snd @ (#ident.fst @ args)), (#ident.snd @ args)))) + else None + | ident.fancy_mulll log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128LL (of_prefancy_scalar args)) + else None + | ident.fancy_mullh log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128LU (of_prefancy_scalar ((#ident.snd @ args), (#ident.fst @ args)))) + else None + | ident.fancy_mulhl log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128UL (of_prefancy_scalar ((#ident.snd @ args), (#ident.fst @ args)))) + else None + | ident.fancy_mulhh log2wordmax + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ MUL128UU (of_prefancy_scalar args)) + else None + | ident.fancy_rshi log2wordmax imm + => fun args : @cexpr var (tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ (RSHI imm) (of_prefancy_scalar args)) + else None + | ident.fancy_selc + => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ SELC (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) + | ident.fancy_selm log2wordmax + => fun args : @cexpr var (tZ * tZ * tZ) => + if Z.eqb log2wordmax 256 + then Some (existT _ SELM (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) + else None + | ident.fancy_sell + => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ SELL (of_prefancy_scalar ((#ident.snd @ args), (#ident.snd @ (#ident.fst @ args))))) + | ident.fancy_addm + => fun args : @cexpr var (tZ * tZ * tZ) => Some (existT _ ADDM (of_prefancy_scalar args)) + | _ => fun _ => None + end. + + Local Notation "x <- y ; f" := (match y with Some x => f | None => Ret error end). + Definition of_prefancy_step + (of_prefancy : forall (next_name : name) {t} (e : @cexpr var t), @expr name) + (next_name : name) {t} (e : @cexpr var t) : @expr name + := let default _ := (e' <- type.try_transport base.try_make_transport_cps (@cexpr var) t tZ e; + Ret (of_prefancy_scalar e')) in + match e with + | PreFancy.LetInAppIdentZ s d r eidc x f + => idc <- invert_expr.invert_Ident eidc; + instr_args <- @of_prefancy_ident s tZ idc x; + let i : instruction := projT1 instr_args in + let args : tuple name i.(num_source_regs) := projT2 instr_args in + Instr i next_name args (@of_prefancy (name_succ next_name) _ (f next_name)) + | PreFancy.LetInAppIdentZZ s d r eidc x f + => idc <- invert_expr.invert_Ident eidc; + instr_args <- @of_prefancy_ident s (tZ * tZ) idc x; + let i : instruction := projT1 instr_args in + let args : tuple name i.(num_source_regs) := projT2 instr_args in + Instr i next_name args (@of_prefancy (name_succ next_name) _ (f (next_name, error))) (* we pass the error code as the carry register, because it cannot be read from directly. *) + | _ => default tt + end. + Fixpoint of_prefancy (next_name : name) {t} (e : @cexpr var t) : @expr name + := @of_prefancy_step of_prefancy next_name t e. + End of_prefancy. + + Section allocate_registers. + Context (reg name : Type) (name_eqb : name -> name -> bool) (error : reg). + Fixpoint allocate (e : @expr name) (reg_list : list reg) (name_to_reg : name -> reg) : @expr reg := + match e with + | Ret n => Ret (name_to_reg n) + | Instr i rd args cont => + match reg_list with + | r :: reg_list' => Instr i r (Tuple.map name_to_reg args) (allocate cont reg_list' (fun n => if name_eqb n rd then r else name_to_reg n)) + | nil => Ret error + end + end. + End allocate_registers. + + Definition test_prog : @expr positive := + Instr (ADD (128)) 3%positive (1, 2)%positive + (Instr (ADDC 0) 4%positive (3,1)%positive + (Ret 4%positive)). + + Definition x1 := 2^256 - 1. + Definition x2 := 2^128 - 1. + Definition wordmax := 2^256. + Definition expected := + let r3' := (x1 + (x2 << 128)) in + let r3 := r3' mod wordmax in + let c := r3' / wordmax in + let r4' := (r3 + x1 + c) in + r4' mod wordmax. + Definition actual := + interp Pos.eqb + (2^256) cc_spec test_prog {|CC.cc_c:=false; CC.cc_m:=false; CC.cc_l:=false; CC.cc_z:=false|} + (fun n => if n =? 1%positive + then x1 + else if n =? 2%positive + then x2 + else 0). + Lemma test_prog_ok : expected = actual. + Proof. reflexivity. Qed. + + Definition of_Expr {t} next_name (consts : Z -> option positive) (consts_list : list Z) + (e : expr.Expr t) + (x : type.for_each_lhs_of_arrow (var positive) t) + : positive -> @expr positive := + fun error => + @of_prefancy positive Pos.succ error consts next_name _ (invert_expr.smart_App_curried (e _) x). + +End Fancy. + +Module Prod. + Import Fancy. Import Registers. + + Definition Mul256 (out src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := + Instr MUL128LL out (src1, src2) + (Instr MUL128UL tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr MUL128LU tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) cont)))). + Definition Mul256x256 (out outHigh src1 src2 tmp : register) (cont : Fancy.expr) : Fancy.expr := + Instr MUL128LL out (src1, src2) + (Instr MUL128UU outHigh (src1, src2) + (Instr MUL128UL tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) + (Instr MUL128LU tmp (src1, src2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont))))))). + + Definition MontRed256 lo hi y t1 t2 scratch RegPInv : @Fancy.expr register := + Mul256 y lo RegPInv t1 + (Mul256x256 t1 t2 y RegMod scratch + (Instr (ADD 0) lo (lo, t1) + (Instr (ADDC 0) hi (hi, t2) + (Instr SELC y (RegMod, RegZero) + (Instr (SUB 0) lo (hi, y) + (Instr ADDM lo (lo, RegZero, RegMod) + (Ret lo))))))). + + (* Barrett reduction -- this is only the "reduce" part, excluding the initial multiplication. *) + Definition MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : @Fancy.expr register := + let q1Bottom256 := scratchp1 in + let muSelect := scratchp2 in + let q2 := scratchp3 in + let q2High := scratchp4 in + let q2High2 := scratchp5 in + let q3 := scratchp1 in + let r2 := scratchp2 in + let r2High := scratchp3 in + let maybeM := scratchp1 in + Instr SELM muSelect (RegMuLow, RegZero) + (Instr (RSHI 255) q1Bottom256 (xHigh, x) + (Mul256x256 q2 q2High q1Bottom256 RegMuLow scratchp5 + (Instr (RSHI 255) q2High2 (RegZero, xHigh) + (Instr (ADD 0) q2High (q2High, q1Bottom256) + (Instr (ADDC 0) q2High2 (q2High2, RegZero) + (Instr (ADD 0) q2High (q2High, muSelect) + (Instr (ADDC 0) q2High2 (q2High2, RegZero) + (Instr (RSHI 1) q3 (q2High2, q2High) + (Mul256x256 r2 r2High RegMod q3 scratchp4 + (Instr (SUB 0) muSelect (x, r2) + (Instr (SUBC 0) xHigh (xHigh, r2High) + (Instr SELL maybeM (RegMod, RegZero) + (Instr (SUB 0) q3 (muSelect, maybeM) + (Instr ADDM x (q3, RegZero, RegMod) + (Ret x))))))))))))))). +End Prod. + +Module ProdEquiv. + Import Fancy. Import Registers. + + Definition interp256 := Fancy.interp reg_eqb (2^256) cc_spec. + Lemma interp_step i rd args cont cc ctx : + interp256 (Instr i rd args cont) cc ctx = + let result := spec i (Tuple.map ctx args) cc in + let new_cc := CC.update (writes_conditions i) result cc_spec cc in + let new_ctx := fun n => if reg_eqb n rd then result mod wordmax else ctx n in interp256 cont new_cc new_ctx. + Proof. reflexivity. Qed. + + (* TODO : move *) + Lemma tuple_map_ext {A B} (f g : A -> B) n (t : tuple A n) : + (forall x : A, f x = g x) -> + Tuple.map f t = Tuple.map g t. + Proof. + destruct n; [reflexivity|]; cbn in *. + induction n; cbn in *; intro H; auto; [ ]. + rewrite IHn by assumption. + rewrite H; reflexivity. + Qed. + + Lemma interp_state_equiv e : + forall cc ctx cc' ctx', + cc = cc' -> (forall r, ctx r = ctx' r) -> + interp256 e cc ctx = interp256 e cc' ctx'. + Proof. + induction e; intros; subst; cbn; [solve[auto]|]. + apply IHe; rewrite tuple_map_ext with (g:=ctx') by auto; + [reflexivity|]. + intros; break_match; auto. + Qed. + Lemma cc_overwrite_full x1 x2 l1 cc : + CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec (CC.update l1 x1 cc_spec cc) = CC.update [CC.C; CC.M; CC.L; CC.Z] x2 cc_spec cc. + Proof. + cbv [CC.update]. cbn [CC.cc_c CC.cc_m CC.cc_l CC.cc_z]. + break_match; try match goal with H : ~ In _ _ |- _ => cbv [In] in H; tauto end. + reflexivity. + Qed. + + Lemma tuple_map_ext_In {A B} (f g : A -> B) n (t : tuple A n) : + (forall x, In x (to_list n t) -> f x = g x) -> + Tuple.map f t = Tuple.map g t. + Proof. + destruct n; [reflexivity|]; cbn in *. + induction n; cbn in *; intro H; auto; [ ]. + destruct t. + rewrite IHn by auto using in_cons. + rewrite H; auto using in_eq. + Qed. + + Definition value_unused r e : Prop := + forall x cc ctx, interp256 e cc ctx = interp256 e cc (fun r' => if reg_eqb r' r then x else ctx r'). + + Lemma value_unused_skip r i rd args cont (Hcont: value_unused r cont) : + r <> rd -> + (~ In r (Tuple.to_list _ args)) -> + value_unused r (Instr i rd args cont). + Proof. + cbv [value_unused] in *; intros. + rewrite !interp_step; cbv zeta. + rewrite Hcont with (x:=x). + match goal with |- ?lhs = ?rhs => + match lhs with context [Tuple.map ?f ?t] => + match rhs with context [Tuple.map ?g ?t] => + rewrite (tuple_map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) + end end end. + apply interp_state_equiv; [ congruence | ]. + { intros; cbv [reg_eqb] in *; break_match; congruence. } + Qed. + + Lemma value_unused_overwrite r i args cont : + (~ In r (Tuple.to_list _ args)) -> + value_unused r (Instr i r args cont). + Proof. + cbv [value_unused]; intros; rewrite !interp_step; cbv zeta. + match goal with |- ?lhs = ?rhs => + match lhs with context [Tuple.map ?f ?t] => + match rhs with context [Tuple.map ?g ?t] => + rewrite (tuple_map_ext_In f g) by (intros; cbv [reg_eqb]; break_match; congruence) + end end end. + apply interp_state_equiv; [ congruence | ]. + { intros; cbv [reg_eqb] in *; break_match; congruence. } + Qed. + + Lemma value_unused_ret r r' : + r <> r' -> + value_unused r (Ret r'). + Proof. + cbv - [reg_dec]; intros. + break_match; congruence. + Qed. + + Ltac remember_results := + repeat match goal with |- context [(spec ?i ?args ?flags) mod ?w] => + let x := fresh "x" in + let y := fresh "y" in + let Heqx := fresh "Heqx" in + remember (spec i args flags) as x eqn:Heqx; + remember (x mod w) as y + end. + + Ltac do_interp_step := + rewrite interp_step; cbn - [interp spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + remember_results. + + Lemma interp_Mul256 out src1 src2 tmp tmp2 cont cc ctx: + out <> src1 -> + out <> src2 -> + out <> tmp -> + out <> tmp2 -> + src1 <> src2 -> + src1 <> tmp -> + src1 <> tmp2 -> + src2 <> tmp -> + src2 <> tmp2 -> + tmp <> tmp2 -> + value_unused tmp cont -> + value_unused tmp2 cont -> + interp256 (Prod.Mul256 out src1 src2 tmp cont) cc ctx = + interp256 ( + Instr MUL128LU tmp (src1, src2) + (Instr MUL128UL tmp2 (src1, src2) + (Instr MUL128LL out (src1, src2) + (Instr (ADD 128) out (out, tmp2) + (Instr (ADD 128) out (out, tmp) cont))))) cc ctx. + Proof. + intros; cbv [Prod.Mul256]. + repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU ADD] in * ). + + match goal with H : value_unused tmp _ |- _ => erewrite H end. + match goal with H : value_unused tmp2 _ |- _ => erewrite H end. + apply interp_state_equiv. + { rewrite !cc_overwrite_full. + f_equal. subst. lia. } + { intros; cbv [reg_eqb]. + repeat (break_match_step ltac:(fun _ => idtac); try congruence); reflexivity. } + Qed. + + Lemma interp_Mul256x256 out outHigh src1 src2 tmp tmp2 cont cc ctx: + out <> src1 -> + out <> outHigh -> + out <> src2 -> + out <> tmp -> + out <> tmp2 -> + outHigh <> src1 -> + outHigh <> src2 -> + outHigh <> tmp -> + outHigh <> tmp2 -> + src1 <> src2 -> + src1 <> tmp -> + src1 <> tmp2 -> + src2 <> tmp -> + src2 <> tmp2 -> + tmp <> tmp2 -> + value_unused tmp cont -> + value_unused tmp2 cont -> + interp256 (Prod.Mul256x256 out outHigh src1 src2 tmp cont) cc ctx = + interp256 ( + Instr MUL128LL out (src1, src2) + (Instr MUL128LU tmp (src1, src2) + (Instr MUL128UL tmp2 (src1, src2) + (Instr MUL128UU outHigh (src1, src2) + (Instr (ADD 128) out (out, tmp2) + (Instr (ADDC (-128)) outHigh (outHigh, tmp2) + (Instr (ADD 128) out (out, tmp) + (Instr (ADDC (-128)) outHigh (outHigh, tmp) cont)))))))) cc ctx. + Proof. + intros; cbv [Prod.Mul256x256]. + repeat (do_interp_step; cbn [spec MUL128LL MUL128UL MUL128LU MUL128UU ADD ADDC] in * ). + + match goal with H : value_unused tmp _ |- _ => erewrite H end. + match goal with H : value_unused tmp2 _ |- _ => erewrite H end. + apply interp_state_equiv. + { rewrite !cc_overwrite_full. + f_equal. + subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. + lia. } + { intros; cbv [reg_eqb]. + repeat (break_match_step ltac:(fun _ => idtac); try congruence); try reflexivity; [ ]. + subst. cbn - [Z.add Z.modulo Z.testbit Z.mul Z.shiftl Fancy.lower128 Fancy.upper128]. + lia. } + Qed. + + Lemma mulll_comm rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LL rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma mulhh_comm rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UU rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma mullh_mulhl rd x y cont cc ctx : + ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128LU rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr Fancy.MUL128UL rd (y, x) cont) cc ctx. + Proof. rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.mul_comm. reflexivity. Qed. + + Lemma add_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADD 0) rd (y, x) cont) cc ctx. + Proof. + intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite Z.add_comm. + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. + Qed. + + Lemma addc_comm rd x y cont cc ctx : + 0 <= ctx x < 2^256 -> + 0 <= ctx y < 2^256 -> + ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (x, y) cont) cc ctx = ProdEquiv.interp256 (Fancy.Instr (Fancy.ADDC 0) rd (y, x) cont) cc ctx. + Proof. + intros; rewrite !ProdEquiv.interp_step. cbn - [Fancy.interp]. rewrite (Z.add_comm (ctx x)). + rewrite !(Z.mod_small (ctx _)) by (cbn in *; omega). reflexivity. + Qed. + + (* Tactics to help prove that something in Fancy is line-by-line equivalent to something in PreFancy *) + Ltac push_value_unused := + repeat match goal with + | |- ~ In _ _ => cbn; intuition; congruence + | _ => apply ProdEquiv.value_unused_overwrite + | _ => apply ProdEquiv.value_unused_skip; [ | congruence | ] + | _ => apply ProdEquiv.value_unused_ret; congruence + end. + + Ltac remember_single_result := + match goal with |- context [(Fancy.spec ?i ?args ?cc) mod ?w] => + let x := fresh "x" in + let y := fresh "y" in + let Heqx := fresh "Heqx" in + remember (Fancy.spec i args cc) as x eqn:Heqx; + remember (x mod w) as y + end. + Ltac step_both_sides := + match goal with |- ProdEquiv.interp256 (Fancy.Instr ?i ?rd1 ?args1 _) _ ?ctx1 = ProdEquiv.interp256 (Fancy.Instr ?i ?rd2 ?args2 _) _ ?ctx2 => + rewrite (ProdEquiv.interp_step i rd1 args1); rewrite (ProdEquiv.interp_step i rd2 args2); + cbn - [Fancy.interp Fancy.spec]; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence; + remember_single_result; + lazymatch goal with + | |- context [Fancy.spec i _ _] => + let Heqa1 := fresh in + let Heqa2 := fresh in + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx1 args1) eqn:Heqa1; + remember (Tuple.map (n:=i.(Fancy.num_source_regs)) ctx2 args2) eqn:Heqa2; + cbn in Heqa1; cbn in Heqa2; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa1 by congruence; + repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl in Heqa2 by congruence; + let a1 := match type of Heqa1 with _ = ?a1 => a1 end in + let a2 := match type of Heqa2 with _ = ?a2 => a2 end in + (fail 1 "arguments to " i " do not match; LHS has " a1 " and RHS has " a2) + | _ => idtac + end + end. +End ProdEquiv. + +(* Lemmas to help prove that a fancy and prefancy expression have the +same meaning -- should be replaced eventually with a proof of fancy +passes in general. *) + +Module Fancy_PreFancy_Equiv. + Import Fancy.Registers. + + Lemma interp_cast_mod_eq w u x: u = 2^w - 1 -> ident.cast (PreFancy.interp_cast_mod w) r[0 ~> u] x = x mod 2^w. + Proof. + cbv [ident.cast PreFancy.interp_cast_mod upper lower]; intros; subst. + rewrite !Z.eqb_refl. + break_innermost_match; Bool.split_andb; Z.ltb_to_lt; Z.rewrite_mod_small; reflexivity. + Qed. + Lemma interp_cast_mod_flag w x: ident.cast (PreFancy.interp_cast_mod w) r[0 ~> 1] x = x mod 2. + Proof. + cbv [ident.cast PreFancy.interp_cast_mod upper lower]. + break_match; Bool.split_andb; Z.ltb_to_lt; Z.rewrite_mod_small; subst; try omega. + f_equal; omega. + Qed. + + Lemma interp_equivZ {s} w u (Hu : u = 2^w-1) i rd regs e cc ctx idc args f : + (Fancy.spec i (Tuple.map ctx regs) cc + = ident.gen_interp (PreFancy.interp_cast_mod w) (t:=type.arrow _ base.type.Z) idc (PreFancy.interp w args)) -> + ( let r := Fancy.spec i (Tuple.map ctx regs) cc in + Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e + (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc) + (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) = + @PreFancy.interp w base.type.Z (f (r mod 2 ^ w))) -> + Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx + = @PreFancy.interp w base.type.Z + (@PreFancy.LetInAppIdentZ s _ (r[0~>2^w-1])%zrange (#idc) args f). + Proof. + cbv zeta; intros spec_eq next_eq. + cbn [Fancy.interp PreFancy.interp]. + cbv [Let_In]. + rewrite next_eq. + cbn in *. + rewrite <-spec_eq. + rewrite interp_cast_mod_eq by omega. + reflexivity. + Qed. + + Lemma interp_equivZZ {s} w (Hw : 2 < 2 ^ w) u (Hu : u = 2^w - 1) i rd regs e cc ctx idc args f : + ((Fancy.spec i (Tuple.map ctx regs) cc) mod 2 ^ w + = fst (ident.gen_interp (PreFancy.interp_cast_mod w) (t:=type.arrow _ (base.type.Z*base.type.Z)) idc (PreFancy.interp w args))) -> + ((if Fancy.cc_spec Fancy.CC.C(Fancy.spec i (Tuple.map ctx regs) cc) then 1 else 0) + = snd (ident.gen_interp (PreFancy.interp_cast_mod w) (t:=type.arrow _ (base.type.Z*base.type.Z)) idc (PreFancy.interp w args)) mod 2) -> + ( let r := Fancy.spec i (Tuple.map ctx regs) cc in + Fancy.interp reg_eqb (2 ^ w) Fancy.cc_spec e + (Fancy.CC.update (Fancy.writes_conditions i) r Fancy.cc_spec cc) + (fun n : register => if reg_eqb n rd then r mod 2 ^ w else ctx n) = + @PreFancy.interp w base.type.Z + (f (r mod 2 ^ w, if (Fancy.cc_spec Fancy.CC.C r) then 1 else 0))) -> + Fancy.interp reg_eqb (2^w) Fancy.cc_spec (Fancy.Instr i rd regs e) cc ctx + = @PreFancy.interp w base.type.Z + (@PreFancy.LetInAppIdentZZ s _ (r[0~>u], r[0~>1])%zrange (#idc) args f). + Proof. + cbv zeta; intros spec_eq1 spec_eq2 next_eq. + cbn [Fancy.interp PreFancy.interp]. + cbv [Let_In]. + cbn [ident.gen_interp]; Prod.eta_expand. + rewrite next_eq. + rewrite interp_cast_mod_eq by omega. + rewrite interp_cast_mod_flag by omega. + cbn -[Fancy.cc_spec] in *. + rewrite <-spec_eq1, <-spec_eq2. + rewrite Z.mod_mod by omega. + reflexivity. + Qed. +End Fancy_PreFancy_Equiv. + +Module Barrett256. + + Definition M := Eval lazy in (2^256-2^224+2^192+2^96-1). + Definition machine_wordsize := 256. + + Derive barrett_red256 + SuchThat (BarrettReduction.rbarrett_red_correctT M machine_wordsize barrett_red256) + As barrett_red256_correct. + Proof. Time solve_rbarrett_red machine_wordsize. Time Qed. + + Definition muLow := Eval lazy in (2 ^ (2 * machine_wordsize) / M) mod (2^machine_wordsize). + (* + Definition barrett_red256_prefancy' := PreFancy.of_Expr machine_wordsize [M; muLow] barrett_red256. + + Derive barrett_red256_prefancy + SuchThat (barrett_red256_prefancy = barrett_red256_prefancy' type.interp) + As barrett_red256_prefancy_eq. + Proof. lazy - [type.interp]; reflexivity. Qed. + *) + + Lemma barrett_reduce_correct_specialized : + forall (xLow xHigh : Z), + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. + apply BarrettReduction.barrett_reduce_correct; cbv [machine_wordsize M muLow] in *; + try omega; + try match goal with + | |- context [weight] => intros; cbv [weight]; autorewrite with zsimplify; auto using Z.pow_mul_r with omega + end; lazy; try split; congruence. + Qed. + + (* + (* Note: If this is not factored out, then for some reason Qed takes forever in barrett_red256_correct_full. *) + Lemma barrett_red256_correct_proj2 : + forall xy : type.interp base.interp (base.type.prod base.type.Z base.type.Z), + ZRange.type.option.is_bounded_by + (t:=base.type.prod base.type.Z base.type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + xy = true -> + type.app_curried (t:=type.arrow (base.type.prod base.type.Z base.type.Z) base.type.Z) (expr.Interp (@ident.interp) barrett_red256) xy = type.app_curried (t:=type.arrow (base.type.prod base.type.Z base.type.Z) base.type.Z) (fun xy => BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 (fst xy) (snd xy)) xy. + Proof. intros; destruct (barrett_red256_correct xy); assumption. Qed. + Lemma barrett_red256_correct_proj2' : + forall x y : Z, + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + (x, y) = true -> + expr.Interp (@ident.interp) barrett_red256 (x, y) = BarrettReduction.barrett_reduce machine_wordsize M muLow 2 2 x y. + Proof. intros; rewrite barrett_red256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed. + *) + Strategy -100 [type.app_curried]. + Lemma barrett_red256_correct_full : + forall (xLow xHigh : Z), + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + PreFancy.Interp 256 barrett_red256 xLow xHigh = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. + rewrite <-barrett_reduce_correct_specialized by assumption. + destruct (barrett_red256_correct (xLow, (xHigh, tt))) as [H1 H2]. + { cbn -[Z.pow]. + rewrite !andb_true_iff. + assert (M < 2^machine_wordsize) by (vm_compute; reflexivity). + repeat apply conj; Z.ltb_to_lt; trivial; omega. } + { etransitivity; [ eapply H2 | ]. (* need Strategy -100 [type.app_curried]. for this to be fast *) + generalize BarrettReduction.barrett_reduce; vm_compute; reflexivity. } + Qed. + + (* + Import PreFancy.Tactics. (* for ok_expr_step *) + Lemma barrett_red256_prefancy_correct : + forall xLow xHigh dummy_arrow, + 0 <= xLow < 2 ^ machine_wordsize -> + 0 <= xHigh < M -> + @PreFancy.interp machine_wordsize (PreFancy.interp_cast_mod machine_wordsize) type.Z (barrett_red256_prefancy (xLow, xHigh) dummy_arrow) = (xLow + 2 ^ machine_wordsize * xHigh) mod M. + Proof. + intros. rewrite barrett_red256_prefancy_eq; cbv [barrett_red256_prefancy']. + erewrite PreFancy.of_Expr_correct. + { apply barrett_red256_correct_full; try assumption; reflexivity. } + { reflexivity. } + { lazy; reflexivity. } + { lazy; reflexivity. } + { repeat constructor. } + { cbv [In M muLow]; intros; intuition; subst; cbv; congruence. } + { let r := (eval compute in (2 ^ machine_wordsize)) in + replace (2^machine_wordsize) with r in * by reflexivity. + cbv [M muLow machine_wordsize] in *. + assert (lower r[0~>1] = 0) by reflexivity. + repeat (ok_expr_step; [ ]). + ok_expr_step. + lazy; congruence. + constructor. + constructor. } + { lazy. omega. } + Qed. + *) + Definition barrett_red256_fancy' (xLow xHigh RegMuLow RegMod RegZero error : positive) := + Fancy.of_Expr 3%positive + (fun z => if z =? muLow then Some RegMuLow else if z =? M then Some RegMod else if z =? 0 then Some RegZero else None) + [M; muLow] + barrett_red256 + (xLow, (xHigh, tt)) + error. + Derive barrett_red256_fancy + SuchThat (forall xLow xHigh RegMuLow RegMod RegZero, + barrett_red256_fancy xLow xHigh RegMuLow RegMod RegZero = barrett_red256_fancy' xLow xHigh RegMuLow RegMod RegZero) + As barrett_red256_fancy_eq. + Proof. + intros. + lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB Fancy.SUBC + Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU + Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. + reflexivity. + Qed. + + Import Fancy.Registers. + + Definition barrett_red256_alloc' xLow xHigh RegMuLow := + fun errorP errorR => + Fancy.allocate register + positive Pos.eqb + errorR + (barrett_red256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) + [r2;r3;r4;r5;r6;r7;r8;r9;r10;r5;r11;r6;r12;r13;r14;r15;r16;r17;r18;r19;r20;r21;r22;r23;r24;r25;r26;r27;r28;r29] + (fun n => if n =? 1000 then xLow + else if n =? 1001 then xHigh + else if n =? 1002 then RegMuLow + else if n =? 1003 then RegMod + else if n =? 1004 then RegZero + else errorR). + Derive barrett_red256_alloc + SuchThat (barrett_red256_alloc = barrett_red256_alloc') + As barrett_red256_alloc_eq. + Proof. + intros. + cbv [barrett_red256_alloc' barrett_red256_fancy]. + cbn. subst barrett_red256_alloc. + reflexivity. + Qed. + + Set Printing Depth 1000. + Import ProdEquiv. + + Local Ltac solve_bounds := + match goal with + | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega + | _ => assumption + end. + + Lemma barrett_red256_alloc_equivalent errorP errorR cc_start_state start_context : + forall x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg, + NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> + 0 <= start_context x < 2^machine_wordsize -> + 0 <= start_context xHigh < 2^machine_wordsize -> + 0 <= start_context RegMuLow < 2^machine_wordsize -> + ProdEquiv.interp256 (barrett_red256_alloc r0 r1 r30 errorP errorR) cc_start_state + (fun r => if reg_eqb r r0 + then start_context x + else if reg_eqb r r1 + then start_context xHigh + else if reg_eqb r r30 + then start_context RegMuLow + else start_context r) + = ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context. + Proof. + intros. + let r := eval compute in (2^machine_wordsize) in + replace (2^machine_wordsize) with r in * by reflexivity. + cbv [Prod.MulMod barrett_red256_alloc]. + + (* Extract proofs that no registers are equal to each other *) + repeat match goal with + | H : NoDup _ |- _ => inversion H; subst; clear H + | H : ~ In _ _ |- _ => cbv [In] in H + | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H + | H : ~ False |- _ => clear H + end. + + step_both_sides. + + (* TODO: To prove equivalence between these two, we need to either relocate the RSHI instructions so they're in the same places or use instruction commutativity to push them down. *) + + Admitted. + + Import Fancy_PreFancy_Equiv. + + Definition interp_equivZZ_256 {s} := + @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). + Definition interp_equivZ_256 {s} := + @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). + + Local Ltac simplify_op_equiv start_ctx := + cbn - [Fancy.spec (*PreFancy.interp_ident*) ident.gen_interp Fancy.cc_spec Z.shiftl]; + repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end; + cbv - [ + Z.rshi Z.cc_m Fancy.CC.cc_m + Z.add_with_get_carry_full Z.add_get_carry_full + Z.sub_get_borrow_full Z.sub_with_get_borrow_full + Z.le Z.lt Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl + Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones + fst snd]; cbn [fst snd]; + try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega); + autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega; + let r := (eval compute in (2 ^ 256)) in + replace (2^256) with r in * by reflexivity; + repeat match goal with + | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H + | |- context [?x <? 0] => rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds) + | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega) + | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] => + replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring + end. + + Local Ltac solve_nonneg ctx := + match goal with x := (Fancy.spec _ _ _) |- _ => subst x end; + simplify_op_equiv ctx; Z.zero_bounds. + + Local Ltac generalize_result := + let v := fresh "v" in intro v; generalize v; clear v; intro v. + + Local Ltac generalize_result_nonneg ctx := + let v := fresh "v" in + let v_nonneg := fresh "v_nonneg" in + intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg]. + + Local Ltac step_abs := + match goal with + | [ |- context G[expr.interp ?ident_interp (expr.Abs ?f) ?x] ] + => let G' := context G[expr.interp ident_interp (f x)] in + change G'; cbv beta + end. + Local Ltac step ctx := + repeat step_abs; + match goal with + | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ => + apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx] + | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZ _ _ _ _ _ _) ] + => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result] + | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZZ _ _ _ _ _ _) ] + => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result] + end. + + (* TODO: move this lemma to ZUtil *) + Lemma testbit_neg_eq_if x n : + 0 <= n -> + - (2 ^ n) <= x < 2 ^ n -> + Z.b2z (if x <? 0 then true else Z.testbit x n) = - (x / 2 ^ n) mod 2. + Proof. + intros. break_match; Z.ltb_to_lt. + { autorewrite with zsimplify. reflexivity. } + { autorewrite with zsimplify. + rewrite Z.bits_above_pow2 by omega. + reflexivity. } + Qed. + + Lemma prod_barrett_red256_correct : + forall (cc_start_state : Fancy.CC.state) (* starting carry flags *) + (start_context : register -> Z) (* starting register values *) + (x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 extra_reg : register), (* registers to use in computation *) + NoDup [x; xHigh; RegMuLow; scratchp1; scratchp2; scratchp3; scratchp4; scratchp5; extra_reg; RegMod; RegZero] -> (* registers are unique *) + 0 <= start_context x < 2^machine_wordsize -> + 0 <= start_context xHigh < M -> + start_context RegMuLow = muLow -> + start_context RegMod = M -> + start_context RegZero = 0 -> + cc_start_state.(Fancy.CC.cc_m) = (Z.cc_m (2^256) (start_context xHigh) =? 1) -> + let X := start_context x + 2^machine_wordsize * start_context xHigh in + ProdEquiv.interp256 (Prod.MulMod x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5) cc_start_state start_context = X mod M. + Proof. + intros. subst X. + assert (0 <= start_context xHigh < 2^machine_wordsize) by (cbv [M] in *; cbn; omega). + let r := (eval compute in (2 ^ machine_wordsize)) in + replace (2^machine_wordsize) with r in * by reflexivity. + cbv [M muLow] in *. + + rewrite <-barrett_red256_correct_full by auto. + rewrite <-barrett_red256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) + by (auto; cbn; auto with omega). + cbv [ProdEquiv.interp256]. + let r := (eval compute in (2 ^ 256)) in + replace (2^256) with r in * by reflexivity. + cbv [barrett_red256_alloc barrett_red256 expr.Interp]. + + step start_context. + { match goal with H : Fancy.CC.cc_m _ = _ |- _ => rewrite H end. + match goal with |- context [Z.cc_m ?s ?x] => + pose proof (Z.cc_m_small s x ltac:(reflexivity) ltac:(reflexivity) ltac:(omega)); + let H := fresh in + assert (Z.cc_m s x = 1 \/ Z.cc_m s x = 0) as H by omega; + destruct H as [H | H]; rewrite H in * + end; repeat (change (0 =? 1) with false || change (?x =? ?x) with true || cbv beta iota); + break_innermost_match; Z.ltb_to_lt; try congruence. } + apply interp_equivZ_256; [ simplify_op_equiv start_context | ]. (* apply manually instead of using [step] to allow a custom bounds proof *) + { rewrite Z.rshi_correct by omega. + autorewrite with zsimplify_fast. + rewrite Z.shiftr_div_pow2 by omega. + break_innermost_match; Z.ltb_to_lt; try omega. + do 2 f_equal; omega. } + + (* Special case to remember the bound for the output of RSHI *) + let v := fresh "v" in + let v_bound := fresh "v_bound" in + intro v; assert (0 <= v <= 1) as v_bound; [ |generalize v v_bound; clear v v_bound; intros v v_bound]. + { solve_nonneg start_context. autorewrite with zsimplify_fast. + rewrite Z.shiftr_div_pow2 by omega. + rewrite Z.mod_small by admit. + split; [Z.zero_bounds|]. + apply Z.lt_succ_r. + apply Z.div_lt_upper_bound; try lia; admit. } +(* + step start_context. + { rewrite Z.rshi_correct by omega. + rewrite Z.shiftr_div_pow2 by omega. + repeat (f_equal; try ring). } + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; + [ rewrite Z.mod_small with (b:=2) by (rewrite Z.mod_small by omega; omega); (* Here we make use of the bound of RSHI *) + reflexivity + | rewrite Z.mod_small with (b:=2) by (rewrite Z.mod_small by omega; omega); (* Here we make use of the bound of RSHI *) + reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context. + { rewrite Z.rshi_correct by omega. + rewrite Z.shiftr_div_pow2 by omega. + repeat (f_equal; try ring). } + + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + + step start_context. + { reflexivity. } + { autorewrite with zsimplify_fast. + match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite <-testbit_neg_eq_if with (n:=256) by (cbn; omega). + reflexivity. } + step start_context. + { reflexivity. } + { autorewrite with zsimplify_fast. + rewrite Z.mod_small with (a:=(if (if _ <? 0 then true else _) then _ else _)) (b:=2) by (break_innermost_match; omega). + match goal with |- context [?a - ?b - ?c] => replace (a - b - c) with (a - (b + c)) by ring end. + match goal with |- context [?x mod ?m] => pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite <-testbit_neg_eq_if with (n:=256) by (break_innermost_match; cbn; omega). + reflexivity. } + step start_context. + { rewrite Z.bit0_eqb. + match goal with |- context [(?x mod ?m) &' 1] => + replace (x mod m) with (x &' Z.ones 256) by (rewrite Z.land_ones by omega; reflexivity) end. + rewrite <-Z.land_assoc. + rewrite Z.land_ones with (n:=1) by omega. + cbn. + match goal with |- context [?x mod 2] => + let H := fresh in + assert (x mod 2 = 0 \/ x mod 2 = 1) as H + by (pose proof (Z.mod_pos_bound x 2 ltac:(omega)); omega); + destruct H as [H | H]; rewrite H + end; reflexivity. } + step start_context. + { reflexivity. } + { autorewrite with zsimplify_fast. + repeat match goal with |- context [?x mod ?m] => unique pose proof (Z.mod_pos_bound x m ltac:(omega)) end. + rewrite <-testbit_neg_eq_if with (n:=256) by (cbn; omega). + reflexivity. } + step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. + reflexivity. +*) + Admitted. + + Import PrintingNotations. + Set Printing Width 1000. + Open Scope expr_scope. + Print barrett_red256. + (* +barrett_red256 = fun var : type -> Type => λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, + expr_let x0 := SELM (x₂, 0, 26959946667150639793205513449348445388433292963828203772348655992835) in + expr_let x1 := RSHI (0, x₂, 255) in + expr_let x2 := RSHI (x₂, x₁, 255) in + expr_let x3 := 79228162514264337589248983038 *₂₅₆ (uint128)(x2 >> 128) in + expr_let x4 := 79228162514264337589248983038 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in + expr_let x5 := 340282366841710300930663525764514709507 *₂₅₆ (uint128)(x2 >> 128) in + expr_let x6 := 340282366841710300930663525764514709507 *₂₅₆ ((uint128)(x2) & 340282366920938463463374607431768211455) in + expr_let x7 := ADD_256 ((uint256)(((uint128)(x5) & 340282366920938463463374607431768211455) << 128), x6) in + expr_let x8 := ADDC_256 (x7₂, (uint128)(x5 >> 128), x3) in + expr_let x9 := ADD_256 ((uint256)(((uint128)(x4) & 340282366920938463463374607431768211455) << 128), x7₁) in + expr_let x10 := ADDC_256 (x9₂, (uint128)(x4 >> 128), x8₁) in + expr_let x11 := ADD_256 (x2, x10₁) in + expr_let x12 := ADDC_128 (x11₂, 0, x1) in + expr_let x13 := ADD_256 (x0, x11₁) in + expr_let x14 := ADDC_128 (x13₂, 0, x12₁) in + expr_let x15 := RSHI (x14₁, x13₁, 1) in + expr_let x16 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x15 >> 128) in + expr_let x17 := 79228162514264337593543950335 *₂₅₆ (uint128)(x15 >> 128) in + expr_let x18 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in + expr_let x19 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x15) & 340282366920938463463374607431768211455) in + expr_let x20 := ADD_256 ((uint256)(((uint128)(x18) & 340282366920938463463374607431768211455) << 128), x19) in + expr_let x21 := ADDC_256 (x20₂, (uint128)(x18 >> 128), x16) in + expr_let x22 := ADD_256 ((uint256)(((uint128)(x17) & 340282366920938463463374607431768211455) << 128), x20₁) in + expr_let x23 := ADDC_256 (x22₂, (uint128)(x17 >> 128), x21₁) in + expr_let x24 := SUB_256 (x₁, x22₁) in + expr_let x25 := SUBB_256 (x24₂, x₂, x23₁) in + expr_let x26 := SELL (x25₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let x27 := SUB_256 (x24₁, x26) in + ADDM (x27₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) + : Expr (type.uncurry (type.type_primitive type.Z -> type.type_primitive type.Z -> type.type_primitive type.Z)) + *) + + Import PreFancy. + Import PreFancy.Notations. + (* +Local Notation "'RegMod'" := (Straightline.expr.Primitive (t:=type.Z) 115792089210356248762697446949407573530086143415290314195533631308867097853951). + Local Notation "'RegMuLow'" := (Straightline.expr.Primitive (t:=type.Z) 26959946667150639793205513449348445388433292963828203772348655992835). + *) + (* + Print barrett_red256_prefancy. +*) + (* + selm@(y, $x₂, RegZero, RegMuLow); + rshi@(y0, RegZero, $x₂,255); + rshi@(y1, $x₂, $x₁,255); + mulhh@(y2, RegMuLow, $y1); + mulhl@(y3, RegMuLow, $y1); + mullh@(y4, RegMuLow, $y1); + mulll@(y5, RegMuLow, $y1); + add@(y6, $y5, $y4, 128); + addc@(y7, carry{$y6}, $y2, $y4, -128); + add@(y8, $y6, $y3, 128); + addc@(y9, carry{$y8}, $y7, $y3, -128); + add@(y10, $y1, $y9, 0); + addc@(y11, carry{$y10}, RegZero, $y0, 0); #128 + add@(y12, $y, $y10, 0); + addc@(y13, carry{$y12}, RegZero, $y11, 0); #128 + rshi@(y14, $y13, $y12,1); + mulhh@(y15, RegMod, $y14); + mullh@(y16, RegMod, $y14); + mulhl@(y17, RegMod, $y14); + mulll@(y18, RegMod, $y14); + add@(y19, $y18, $y17, 128); + addc@(y20, carry{$y19}, $y15, $y17, -128); + add@(y21, $y19, $y16, 128); + addc@(y22, carry{$y21}, $y20, $y16, -128); + sub@(y23, $x₁, $y21, 0); + subb@(y24, carry{$y23}, $x₂, $y22, 0); + sell@(y25, $y24, RegZero, RegMod); + sub@(y26, $y23, $y25, 0); + addm@(y27, $y26, RegZero, RegMod); + ret $y27 + *) +End Barrett256. + +Module Montgomery256. + + Definition N := Eval lazy in (2^256-2^224+2^192+2^96-1). + Definition N':= (115792089210356248768974548684794254293921932838497980611635986753331132366849). + Definition R := Eval lazy in (2^256). + Definition R' := 115792089183396302114378112356516095823261736990586219612555396166510339686400. + Definition machine_wordsize := 256. + + Derive montred256 + SuchThat (MontgomeryReduction.rmontred_correctT N R N' machine_wordsize montred256) + As montred256_correct. + Proof. Time solve_rmontred machine_wordsize. Time Qed. + + (* + Definition montred256_prefancy' := PreFancy.of_Expr machine_wordsize [N;N'] montred256. + + Derive montred256_prefancy + SuchThat (montred256_prefancy = montred256_prefancy' type.interp) + As montred256_prefancy_eq. + Proof. lazy - [type.interp]; reflexivity. Qed. +*) + + Lemma montred'_correct_specialized R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 (lo, hi) = ((lo + R * hi) * R') mod N. + Proof. + intros. + apply MontgomeryReduction.montred'_correct with (T:=lo + R * hi) (R':=R'); + try match goal with + | |- context[R'] => assumption + | |- context [lo] => + try assumption; progress autorewrite with zsimplify cancel_pair; reflexivity + end; lazy; try split; congruence. + Qed. + + (* + (* Note: If this is not factored out, then for some reason Qed takes forever in montred256_correct_full. *) + Lemma montred256_correct_proj2 : + forall xy : type.interp (type.prod type.Z type.Z), + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + xy = true -> + expr.Interp (@ident.interp) montred256 xy = app_curried (t:=type.arrow (type.prod type.Z type.Z) type.Z) (MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2) xy. + Proof. intros; destruct (montred256_correct xy); assumption. Qed. + Lemma montred256_correct_proj2' : + forall xy : type.interp (type.prod type.Z type.Z), + ZRange.type.option.is_bounded_by + (t:=type.prod type.Z type.Z) + (Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange, Some r[0 ~> 2 ^ machine_wordsize - 1]%zrange) + xy = true -> + expr.Interp (@ident.interp) montred256 xy = MontgomeryReduction.montred' N R N' (Z.log2 R) 2 2 xy. + Proof. intros; rewrite montred256_correct_proj2 by assumption; unfold app_curried; exact eq_refl. Qed. +*) + Lemma montred256_correct_full R' (R'_correct : Z.equiv_modulo N (R * R') 1) : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + PreFancy.Interp 256 montred256 (lo, hi) = ((lo + R * hi) * R') mod N. + Proof. + intros. + rewrite <-montred'_correct_specialized by assumption. + destruct (montred256_correct ((lo, hi), tt)) as [H2 H3]. + { cbn -[Z.pow]. + rewrite !andb_true_iff. + repeat apply conj; Z.ltb_to_lt; trivial; cbv [R N machine_wordsize] in *; lia. } + { etransitivity; [ eapply H3 | ]. (* need Strategy -100 [type.app_curried]. for this to be fast *) + generalize MontgomeryReduction.montred'; vm_compute; reflexivity. } + Qed. + + (* + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step' := + match goal with + | _ => assumption + | |- _ <= _ <= _ \/ @eq zrange _ _ => + right; lazy; try split; congruence + | |- _ <= _ <= _ \/ @eq zrange _ _ => + left; lazy; try split; congruence + | |- lower r[0~>_]%zrange = 0 => reflexivity + | |- context [PreFancy.ok_ident] => constructor + | |- context [PreFancy.ok_scalar] => constructor; try omega + | |- context [PreFancy.is_halved] => eapply PreFancy.is_halved_constant; [lazy; reflexivity | ] + | |- context [PreFancy.is_halved] => constructor + | |- context [PreFancy.in_word_range] => lazy; reflexivity + | |- context [PreFancy.in_flag_range] => lazy; reflexivity + | |- context [PreFancy.get_range] => + cbn [PreFancy.get_range lower upper fst snd ZRange.map] + | x : type.interp (type.prod _ _) |- _ => destruct x + | |- (_ <=? _)%zrange = true => + match goal with + | |- context [PreFancy.get_range_var] => + cbv [is_tighter_than_bool PreFancy.has_range fst snd upper lower R N] in *; cbn; + apply andb_true_iff; split; apply Z.leb_le + | _ => lazy + end; omega || reflexivity + | |- @eq zrange _ _ => lazy; reflexivity + | |- _ <= _ => cbv [machine_wordsize]; omega + | |- _ <= _ <= _ => cbv [machine_wordsize]; omega + end; intros. + + (* TODO : maybe move these ok_expr tactics somewhere else *) + Ltac ok_expr_step := + match goal with + | |- context [PreFancy.ok_expr] => constructor; cbn [fst snd]; repeat ok_expr_step' + end; intros; cbn [Nat.max].*) + + (* + Lemma montred256_prefancy_correct : + forall (lo hi : Z), + 0 <= lo < R -> 0 <= hi < R -> 0 <= lo + R * hi < R * N -> + @PreFancy.interp machine_wordsize base.type.Z (montred256 _ @ (##lo,##hi)) = ((lo + R * hi) * R') mod N. + Proof. + intros. + + rewrite montred256_prefancy_eq; cbv [montred256_prefancy']. + erewrite PreFancy.of_Expr_correct. + { apply montred256_correct_full; try assumption; reflexivity. } + { reflexivity. } + { lazy; reflexivity. } + { lazy; reflexivity. } + { repeat constructor. } + { cbv [In N N']; intros; intuition; subst; cbv; congruence. } + { assert (340282366920938463463374607431768211455 * 2 ^ 128 <= 2 ^ machine_wordsize - 1) as shiftl_128_ok by (lazy; congruence). + repeat (ok_expr_step; [ ]). + ok_expr_step. + lazy; congruence. + constructor. + constructor. } + { lazy. omega. } + Qed. +*) + + Definition montred256_fancy' (lo hi RegMod RegPInv RegZero error : positive) := + Fancy.of_Expr 3%positive + (fun z => if z =? N then Some RegMod else if z =? N' then Some RegPInv else if z =? 0 then Some RegZero else None) + [N;N'] + montred256 + ((lo, hi)%positive, tt) + error. + Derive montred256_fancy + SuchThat (forall RegMod RegPInv RegZero, + montred256_fancy RegMod RegPInv RegZero = montred256_fancy' RegMod RegPInv RegZero) + As montred256_fancy_eq. + Proof. + intros. + lazy - [Fancy.ADD Fancy.ADDC Fancy.SUB + Fancy.MUL128LL Fancy.MUL128LU Fancy.MUL128UL Fancy.MUL128UU + Fancy.RSHI Fancy.SELC Fancy.SELM Fancy.SELL Fancy.ADDM]. + reflexivity. + Qed. + + Import Fancy.Registers. + + Definition montred256_alloc' lo hi RegPInv := + fun errorP errorR => + Fancy.allocate register + positive Pos.eqb + errorR + (montred256_fancy 1000%positive 1001%positive 1002%positive 1003%positive 1004%positive errorP) + [r2;r3;r4;r5;r6;r7;r8;r9;r10;r11;r12;r13;r14;r15;r16;r17;r18;r19;r20] + (fun n => if n =? 1000 then lo + else if n =? 1001 then hi + else if n =? 1002 then RegMod + else if n =? 1003 then RegPInv + else if n =? 1004 then RegZero + else errorR). + Derive montred256_alloc + SuchThat (montred256_alloc = montred256_alloc') + As montred256_alloc_eq. + Proof. + intros. + cbv [montred256_alloc' montred256_fancy]. + cbn. subst montred256_alloc. + reflexivity. + Qed. + + Import ProdEquiv. + + Local Ltac solve_bounds := + match goal with + | H : ?a = ?b mod ?c |- 0 <= ?a < ?c => rewrite H; apply Z.mod_pos_bound; omega + | _ => assumption + end. + + Lemma montred256_alloc_equivalent errorP errorR cc_start_state start_context : + forall lo hi y t1 t2 scratch RegPInv extra_reg, + NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> + 0 <= start_context lo < R -> + 0 <= start_context hi < R -> + 0 <= start_context RegPInv < R -> + ProdEquiv.interp256 (montred256_alloc r0 r1 r30 errorP errorR) cc_start_state + (fun r => if reg_eqb r r0 + then start_context lo + else if reg_eqb r r1 + then start_context hi + else if reg_eqb r r30 + then start_context RegPInv + else start_context r) + = ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context. + Proof. + intros. cbv [R] in *. + cbv [Prod.MontRed256 montred256_alloc]. + + (* Extract proofs that no registers are equal to each other *) + repeat match goal with + | H : NoDup _ |- _ => inversion H; subst; clear H + | H : ~ In _ _ |- _ => cbv [In] in H + | H : ~ (_ \/ _) |- _ => apply Decidable.not_or in H; destruct H + | H : ~ False |- _ => clear H + end. + + rewrite ProdEquiv.interp_Mul256 with (tmp2:=extra_reg) by (congruence || push_value_unused). + + rewrite mullh_mulhl. step_both_sides. + rewrite mullh_mulhl. step_both_sides. + (* + step_both_sides. + step_both_sides. + + rewrite ProdEquiv.interp_Mul256x256 with (tmp2:=extra_reg) by (congruence || push_value_unused). + + rewrite mulll_comm. step_both_sides. + step_both_sides. + step_both_sides. + rewrite mulhh_comm. step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + + + rewrite add_comm by (cbn; solve_bounds). step_both_sides. + rewrite addc_comm by (cbn; solve_bounds). step_both_sides. + step_both_sides. + step_both_sides. + step_both_sides. + + cbn; repeat progress rewrite ?reg_eqb_neq, ?reg_eqb_refl by congruence. + reflexivity.*) + Admitted. + + Import Fancy_PreFancy_Equiv. + + Definition interp_equivZZ_256 {s} := + @interp_equivZZ s 256 ltac:(cbv; congruence) 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). + Definition interp_equivZ_256 {s} := + @interp_equivZ s 256 115792089237316195423570985008687907853269984665640564039457584007913129639935 ltac:(reflexivity). + + Local Ltac simplify_op_equiv start_ctx := + cbn - [Fancy.spec ident.gen_interp Fancy.cc_spec]; + repeat match goal with H : start_ctx _ = _ |- _ => rewrite H end; + cbv - [ + Z.add_with_get_carry_full + Z.add_get_carry_full Z.sub_get_borrow_full + Z.le Z.ltb Z.leb Z.geb Z.eqb Z.land Z.shiftr Z.shiftl + Z.add Z.mul Z.div Z.sub Z.modulo Z.testbit Z.pow Z.ones + fst snd]; cbn [fst snd]; + try (replace (2 ^ (256 / 2) - 1) with (Z.ones 128) by reflexivity; rewrite !Z.land_ones by omega); + autorewrite with to_div_mod; rewrite ?Z.mod_mod, <-?Z.testbit_spec' by omega; + repeat match goal with + | H : 0 <= ?x < ?m |- context [?x mod ?m] => rewrite (Z.mod_small x m) by apply H + | |- context [?x <? 0] => rewrite (proj2 (Z.ltb_ge x 0)) by (break_match; Z.zero_bounds) + | _ => rewrite Z.mod_small with (b:=2) by (break_match; omega) + | |- context [ (if Z.testbit ?a ?n then 1 else 0) + ?b + ?c] => + replace ((if Z.testbit a n then 1 else 0) + b + c) with (b + c + (if Z.testbit a n then 1 else 0)) by ring + end. + + Local Ltac solve_nonneg ctx := + match goal with x := (Fancy.spec _ _ _) |- _ => subst x end; + simplify_op_equiv ctx; Z.zero_bounds. + + Local Ltac generalize_result := + let v := fresh "v" in intro v; generalize v; clear v; intro v. + + Local Ltac generalize_result_nonneg ctx := + let v := fresh "v" in + let v_nonneg := fresh "v_nonneg" in + intro v; assert (0 <= v) as v_nonneg; [solve_nonneg ctx |generalize v v_nonneg; clear v v_nonneg; intros v v_nonneg]. + + Local Ltac step_abs := + match goal with + | [ |- context G[expr.interp ?ident_interp (expr.Abs ?f) ?x] ] + => let G' := context G[expr.interp ident_interp (f x)] in + change G'; cbv beta + end. + Local Ltac step ctx := + repeat step_abs; + match goal with + | |- Fancy.interp _ _ _ (Fancy.Instr (Fancy.ADD _) _ _ (Fancy.Instr (Fancy.ADDC _) _ _ _)) _ _ = _ => + apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result_nonneg ctx] + | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZ _ _ _ _ _ _) ] + => apply interp_equivZ_256; [simplify_op_equiv ctx | generalize_result] + | [ |- _ = expr.interp _ (PreFancy.LetInAppIdentZZ _ _ _ _ _ _) ] + => apply interp_equivZZ_256; [ simplify_op_equiv ctx | simplify_op_equiv ctx | generalize_result] + end. + + (* TODO: move this lemma to ZUtil *) + Lemma testbit_neg_eq_if x y n : + 0 <= n -> + 0 <= x < 2 ^ n -> + 0 <= y < 2 ^ n -> + Z.b2z (if (x - y) <? 0 then true else Z.testbit (x - y) n) = - ((x - y) / 2 ^ n) mod 2. + Proof. + intros. rewrite Z.sub_pos_bound_div_eq by omega. + break_innermost_match; Z.ltb_to_lt; try lia; try reflexivity; [ ]. + rewrite Z.testbit_eqb, Z.div_between_0_if by omega. + break_innermost_match; Z.ltb_to_lt; try lia; reflexivity. + Qed. + + Local Ltac break_ifs := + repeat (break_innermost_match_step; Z.ltb_to_lt; try (exfalso; omega); []). + + Lemma prod_montred256_correct : + forall (cc_start_state : Fancy.CC.state) (* starting carry flags can be anything *) + (start_context : register -> Z) (* starting register values *) + (lo hi y t1 t2 scratch RegPInv extra_reg : register), (* registers to use in computation *) + NoDup [lo; hi; y; t1; t2; scratch; RegPInv; extra_reg; RegMod; RegZero] -> (* registers must be distinct *) + start_context RegPInv = N' -> (* RegPInv needs to hold the inverse of the modulus *) + start_context RegMod = N -> (* RegMod needs to hold the modulus *) + start_context RegZero = 0 -> (* RegZero needs to hold zero *) + (0 <= start_context lo < R) -> (* low half of the input is in bounds (R=2^256) *) + (0 <= start_context hi < R) -> (* high half of the input is in bounds (R=2^256) *) + let x := (start_context lo) + R * (start_context hi) in (* x is the input (split into two registers) *) + (0 <= x < R * N) -> (* input precondition *) + (ProdEquiv.interp256 (Prod.MontRed256 lo hi y t1 t2 scratch RegPInv) cc_start_state start_context = (x * R') mod N). + Proof. + intros. subst x. cbv [N R N'] in *. + rewrite <-montred256_correct_full by (auto; vm_compute; reflexivity). + rewrite <-montred256_alloc_equivalent with (errorR := RegZero) (errorP := 1%positive) (extra_reg:=extra_reg) + by (cbv [R]; auto with omega). + cbv [ProdEquiv.interp256]. + cbv [montred256_alloc montred256 expr.Interp]. + + step start_context; [ break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | ]. + (*step start_context; [ break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | break_ifs; reflexivity | ]. + step start_context; [ break_ifs; reflexivity | break_ifs; reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ reflexivity | reflexivity | ]. + step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. + step start_context; [ reflexivity | | ]. + { + let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity. + rewrite !Z.shiftl_0_r, !Z.mod_mod by omega. + apply testbit_neg_eq_if; + let r := eval cbv in (2^256) in replace (2^256) with r by reflexivity; + auto using Z.mod_pos_bound with omega. } + step start_context; [ break_innermost_match; Z.ltb_to_lt; omega | ]. + reflexivity. + *) + Admitted. + + Import PrintingNotations. + Set Printing Width 10000. + + Print montred256. +(* +montred256 = fun var : type -> Type => (λ x : var (type.type_primitive type.Z * type.type_primitive type.Z)%ctype, + expr_let x0 := 79228162514264337593543950337 *₂₅₆ (uint128)(x₁ >> 128) in + expr_let x1 := 340282366841710300986003757985643364352 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in + expr_let x2 := 79228162514264337593543950337 *₂₅₆ ((uint128)(x₁) & 340282366920938463463374607431768211455) in + expr_let x3 := ADD_256 ((uint256)(((uint128)(x1) & 340282366920938463463374607431768211455) << 128), x2) in + expr_let x4 := ADD_256 ((uint256)(((uint128)(x0) & 340282366920938463463374607431768211455) << 128), x3₁) in + expr_let x5 := 79228162514264337593543950335 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in + expr_let x6 := 79228162514264337593543950335 *₂₅₆ (uint128)(x4₁ >> 128) in + expr_let x7 := 340282366841710300967557013911933812736 *₂₅₆ ((uint128)(x4₁) & 340282366920938463463374607431768211455) in + expr_let x8 := 340282366841710300967557013911933812736 *₂₅₆ (uint128)(x4₁ >> 128) in + expr_let x9 := ADD_256 ((uint256)(((uint128)(x7) & 340282366920938463463374607431768211455) << 128), x5) in + expr_let x10 := ADDC_256 (x9₂, (uint128)(x7 >> 128), x8) in + expr_let x11 := ADD_256 ((uint256)(((uint128)(x6) & 340282366920938463463374607431768211455) << 128), x9₁) in + expr_let x12 := ADDC_256 (x11₂, (uint128)(x6 >> 128), x10₁) in + expr_let x13 := ADD_256 (x11₁, x₁) in + expr_let x14 := ADDC_256 (x13₂, x12₁, x₂) in + expr_let x15 := SELC (x14₂, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951) in + expr_let x16 := SUB_256 (x14₁, x15) in + ADDM (x16₁, 0, 115792089210356248762697446949407573530086143415290314195533631308867097853951))%expr + : Expr (type.uncurry (type.type_primitive type.Z * type.type_primitive type.Z -> type.type_primitive type.Z)) +*) + + Import PreFancy. + Import PreFancy.Notations. + Local Notation "'RegMod'" := (expr.Ident (ident.Literal 115792089210356248762697446949407573530086143415290314195533631308867097853951)). + Local Notation "'RegPInv'" := (expr.Ident (ident.Literal 115792089210356248768974548684794254293921932838497980611635986753331132366849)). + Local Open Scope expr_scope. + Local Notation mulhl := (#(fancy_mulhl 256)). + Local Notation mulhh := (#(fancy_mulhh 256)). + Local Notation mulll := (#(fancy_mulll 256)). + Local Notation mullh := (#(fancy_mullh 256)). + Local Notation selc := (#(fancy_selc)). + Local Notation addm := (#(fancy_addm)). + Notation add n := (#(fancy_add 256 n)). + Notation addc n := (#(fancy_addc 256 n)). + + Print montred256. + (* +montred256 = +fun var : type -> Type => +λ x : var (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype), +mulhl@(x0, x₁, RegPInv); +mullh@(x1, x₁, RegPInv); +mulll@(x2, x₁, RegPInv); +(add 128)@(x3, x2, Lower{x1}); +(add 128)@(x4, x3₁, Lower{x0}); +mulll@(x5, RegMod, x4₁); +mullh@(x6, RegMod, x4₁); +mulhl@(x7, RegMod, x4₁); +mulhh@(x8, RegMod, x4₁); +(add 128)@(x9, x5, Lower{x7}); +(addc (-128))@(x10, carry{$x9}, x8, x7); +(add 128)@(x11, x9₁, Lower{x6}); +(addc (-128))@(x12, carry{$x11}, x10₁, x6); +(add 0)@(x13, x11₁, x₁); +(addc 0)@(x14, carry{$x13}, x12₁, x₂); +selc@(x15, (carry{$x14}, RegZero), RegMod); +#(fancy_sub 256 0)@(x16, x14₁, x15); +addm@(x17, (x16₁, RegZero), RegMod); +x17 + : Expr + (type.base (base.type.type_base base.type.Z * base.type.type_base base.type.Z)%etype -> + type.base (base.type.type_base base.type.Z))%ptype + *) + +End Montgomery256. + +Local Notation "i rd x y ; cont" := (Fancy.Instr i rd (x, y) cont) (at level 40, cont at level 200, format "i rd x y ; '//' cont"). +Local Notation "i rd x y z ; cont" := (Fancy.Instr i rd (x, y, z) cont) (at level 40, cont at level 200, format "i rd x y z ; '//' cont"). + +Import Fancy.Registers. +Import Fancy. + +Import Barrett256 Montgomery256. + +(*** Montgomery Reduction ***) + +(* Status: Code in final form is proven correct modulo admits in compiler portions. *) + +(* Montgomery Code : *) +Eval cbv beta iota delta [Prod.MontRed256 Prod.Mul256 Prod.Mul256x256] in Prod.MontRed256. +(* + = fun lo hi y t1 t2 scratch RegPInv : register => + MUL128LL y lo RegPInv; + MUL128UL t1 lo RegPInv; + ADD 128 y y t1; + MUL128LU t1 lo RegPInv; + ADD 128 y y t1; + MUL128LL t1 y RegMod; + MUL128UU t2 y RegMod; + MUL128UL scratch y RegMod; + ADD 128 t1 t1 scratch; + ADDC (-128) t2 t2 scratch; + MUL128LU scratch y RegMod; + ADD 128 t1 t1 scratch; + ADDC (-128) t2 t2 scratch; + ADD 0 lo lo t1; + ADDC 0 hi hi t2; + SELC y RegMod RegZero; + SUB 0 lo hi y; + ADDM lo lo RegZero RegMod; + Ret lo + *) + +(* Uncomment to see proof statement and remaining admitted statements, +or search for "prod_montred256_correct" to see comments on the proof +preconditions. *) +(* +Check Montgomery256.prod_montred256_correct. +Print Assumptions Montgomery256.prod_montred256_correct. +*) + +(*** Barrett Reduction ***) + +(* Status: Code is proven correct modulo admits in compiler +portions. However, unlike for Montgomery, this code is not proven +equivalent to the register-allocated and efficiently-scheduled +reference (Prod.MulMod). This proof is currently admitted and would +require either fiddling with code generation to make instructions come +out in the right order or reasoning about which instructions +commute. *) + +(* Barrett reference code: *) +Eval cbv beta iota delta [Prod.MulMod Prod.Mul256x256] in Prod.MulMod. +(* + = fun x xHigh RegMuLow scratchp1 scratchp2 scratchp3 scratchp4 scratchp5 : register => + let q1Bottom256 := scratchp1 in + let muSelect := scratchp2 in + let q2 := scratchp3 in + let q2High := scratchp4 in + let q2High2 := scratchp5 in + let q3 := scratchp1 in + let r2 := scratchp2 in + let r2High := scratchp3 in + let maybeM := scratchp1 in + SELM muSelect RegMuLow RegZero; + RSHI 255 q1Bottom256 xHigh x; + MUL128LL q2 q1Bottom256 RegMuLow; + MUL128UU q2High q1Bottom256 RegMuLow; + MUL128UL scratchp5 q1Bottom256 RegMuLow; + ADD 128 q2 q2 scratchp5; + ADDC (-128) q2High q2High scratchp5; + MUL128LU scratchp5 q1Bottom256 RegMuLow; + ADD 128 q2 q2 scratchp5; + ADDC (-128) q2High q2High scratchp5; + RSHI 255 q2High2 RegZero xHigh; + ADD 0 q2High q2High q1Bottom256; + ADDC 0 q2High2 q2High2 RegZero; + ADD 0 q2High q2High muSelect; + ADDC 0 q2High2 q2High2 RegZero; + RSHI 1 q3 q2High2 q2High; + MUL128LL r2 RegMod q3; + MUL128UU r2High RegMod q3; + MUL128UL scratchp4 RegMod q3; + ADD 128 r2 r2 scratchp4; + ADDC (-128) r2High r2High scratchp4; + MUL128LU scratchp4 RegMod q3; + ADD 128 r2 r2 scratchp4; + ADDC (-128) r2High r2High scratchp4; + SUB 0 muSelect x r2; + SUBC 0 xHigh xHigh r2High; + SELL maybeM RegMod RegZero; + SUB 0 q3 muSelect maybeM; + ADDM x q3 RegZero RegMod; + Ret x + *) + +(* Barrett generated code (equivalence with reference admitted) *) +Eval cbv beta iota delta [barrett_red256_alloc] in barrett_red256_alloc. +(* + = fun (xLow xHigh RegMuLow : register) (_ : positive) (_ : register) => + SELM r2 RegMuLow RegZero; + RSHI 255 r3 RegZero xHigh; + RSHI 255 r4 xHigh xLow; + MUL128UU r5 RegMuLow r4; + MUL128UL r6 r4 RegMuLow; + MUL128LU r7 r4 RegMuLow; + MUL128LL r8 RegMuLow r4; + ADD 128 r9 r8 r7; + ADDC (-128) r10 r5 r7; + ADD 128 r5 r9 r6; + ADDC (-128) r11 r10 r6; + ADD 0 r6 r4 r11; + ADDC 0 r12 RegZero r3; + ADD 0 r13 r2 r6; + ADDC 0 r14 RegZero r12; + RSHI 1 r15 r14 r13; + MUL128UU r16 RegMod r15; + MUL128LU r17 r15 RegMod; + MUL128UL r18 r15 RegMod; + MUL128LL r19 RegMod r15; + ADD 128 r20 r19 r18; + ADDC (-128) r21 r16 r18; + ADD 128 r22 r20 r17; + ADDC (-128) r23 r21 r17; + SUB 0 r24 xLow r22; + SUBC 0 r25 xHigh r23; + SELL r26 RegMod RegZero; + SUB 0 r27 r24 r26; + ADDM r28 r27 RegZero RegMod; + Ret r28 + *) + +(* Uncomment to see proof statement and remaining admitted statements. *) +(* +Check prod_barrett_red256_correct. +Print Assumptions prod_barrett_red256_correct. +(* The equivalence with generated code is admitted as barrett_red256_alloc_equivalent. *) +*) |