From b5b1eebe2845b0e69d669b51cea9eeb67ea726e3 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Thu, 11 Aug 2016 15:41:00 -0700 Subject: Initial work on an architecture interface for ℤ/nℤ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This provides a cleaner interface for the bottom level implementation, as well as an implementation of multiplying 128x128 -> 256. --- src/BoundedArithmetic/ArchitectureToZLike.v | 38 ++++++++ src/BoundedArithmetic/DoubleBounded.v | 114 ++++++++++++++++++++++++ src/BoundedArithmetic/Interface.v | 131 ++++++++++++++++++++++++++++ 3 files changed, 283 insertions(+) create mode 100644 src/BoundedArithmetic/ArchitectureToZLike.v create mode 100644 src/BoundedArithmetic/DoubleBounded.v create mode 100644 src/BoundedArithmetic/Interface.v (limited to 'src') diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v new file mode 100644 index 000000000..6c92f342f --- /dev/null +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -0,0 +1,38 @@ +(*** Implementing ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.ModularArithmetic.ZBounded. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. + +Local Existing Instances DoubleArchitectureBoundedOps DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps. + +Section ops. + Context {n_over_two : size}. + Local Notation n := (2 * n_over_two)%nat. + Context (ops : ArchitectureBoundedOps n) (mops : ArchitectureBoundedHalfWidthMulOps n) + (modulus : Z). + + Axiom admit : forall {T}, T. + + Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : size) + : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := + { LargeT := @BoundedType (2 * n)%nat _; + SmallT := @BoundedType n _; + modulus_digits := encode modulus; + decode_large := decode; + decode_small := decode; + Mod_SmallBound v := snd v; + DivBy_SmallBound v := fst v; + DivBy_SmallerBound v := ShiftRight smaller_bound_exp v; + Mul x y := @Interface.Mul n _ _ x y; + CarryAdd x y := Interface.CarryAdd false x y; + CarrySubSmall x y := Interface.CarrySub false x y; + ConditionalSubtract b x := admit; + ConditionalSubtractModulus y := admit }. +End ops. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v new file mode 100644 index 000000000..59d961d4a --- /dev/null +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -0,0 +1,114 @@ +(*** Implementing Large Bounded Arithmetic via pairs *) +Require Import Coq.ZArith.ZArith Coq.Lists.List. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Notations. + +Import ListNotations. +Local Open Scope list_scope. +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. +Local Notation eta x := (fst x, snd x). + +Section with_single_ops. + Section single. + Context (n : size) {base_ops : ArchitectureBoundedOps n}. + + Definition ripple_carry {T} (f : bool -> T -> T -> bool * T) + (carry : bool) (xs ys : list T) : bool * list T + := List.fold_right + (fun x_y carry_zs => let '(x, y) := eta x_y in + let '(carry, zs) := eta carry_zs in + let '(carry, z) := eta (f carry x y) in + (carry, z :: zs)) + (carry, nil) + (List.combine xs ys). + + Definition two_list_to_tuple {A B} (x : A * list B) + := match x return match x with + | (a, [b0; b1]) => A * (B * B) + | _ => True + end + with + | (a, [b0; b1]) => (a, (b0, b1)) + | _ => I + end. + + Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat + := { BoundedType := BoundedType * BoundedType (* [(high, low)] *); + decode high_low := (decode (fst high_low) * 2^n + decode (snd high_low))%Z; + encode z := (encode (z / 2^n), encode (z mod 2^n)); + ShiftRight a high_low + := let '(high, low) := eta high_low in + if n <=? a then + (ShiftRight (a - n)%nat (encode 0, fst high), ShiftRight (a - n)%nat high) + else + (ShiftRight a (snd high, fst low), ShiftRight a low); + ShiftLeft a high_low + := let '(high, low) := eta high_low in + if 2 * n <=? a then + let '(high0, low) := eta (ShiftLeft (a - 2 * n)%nat low) in + let '(high_high, high1) := eta (ShiftLeft (a - 2 * n)%nat high) in + ((snd (CarryAdd false high0 high1), low), (encode 0, encode 0)) + else if n <=? a then + let '(high0, low) := eta (ShiftLeft (a - n)%nat low) in + let '(high_high, high1) := eta (ShiftLeft (a - n)%nat high) in + ((high_high, snd (CarryAdd false high0 high1)), (low, encode 0)) + else + let '(high0, low) := eta (ShiftLeft a low) in + let '(high_high, high1) := eta (ShiftLeft a high) in + ((encode 0, high_high), (snd (CarryAdd false high0 high1), low)); + Mod2Pow a high_low + := let '(high, low) := (fst high_low, snd high_low) in + (Mod2Pow (a - n)%nat high, Mod2Pow a low); + CarryAdd carry x_high_low y_high_low + := let '(xhigh, xlow) := eta x_high_low in + let '(yhigh, ylow) := eta y_high_low in + two_list_to_tuple (ripple_carry CarryAdd carry [xhigh; xlow] [yhigh; ylow]); + CarrySub carry x_high_low y_high_low + := let '(xhigh, xlow) := eta x_high_low in + let '(yhigh, ylow) := eta y_high_low in + two_list_to_tuple (ripple_carry CarrySub carry [xhigh; xlow] [yhigh; ylow]) }. + + Definition BoundedOfHalfBounded (x : @BoundedHalfType (2 * n)%nat _) : @BoundedType n _ + := match x with + | UpperHalf x => fst x + | LowerHalf x => snd x + end. + + Local Instance DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps + {base_mops : ArchitectureBoundedFullMulOps n} + : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat := + { HalfWidthMul a b + := Mul (BoundedOfHalfBounded a) (BoundedOfHalfBounded b) }. + End single. + + Local Existing Instance DoubleArchitectureBoundedOps. + + Section full_from_half. + Context (n : size) {base_ops : ArchitectureBoundedOps (2 * n)%nat}. + + Local Infix "*" := HalfWidthMul. + + Local Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps + {base_mops : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat} + : ArchitectureBoundedFullMulOps (2 * n)%nat := + { Mul a b + := let '(a1, a0) := (UpperHalf a, LowerHalf a) in + let '(b1, b0) := (UpperHalf b, LowerHalf b) in + let out := a0 * b0 in + let outHigh := a1 * b1 in + let tmp := a1 * b0 in + let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in + let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in + let tmp := a0 * b1 in + let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in + let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in + (outHigh, out) }. + End full_from_half. + + Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. +End with_single_ops. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v new file mode 100644 index 000000000..e9ade7ad4 --- /dev/null +++ b/src/BoundedArithmetic/Interface.v @@ -0,0 +1,131 @@ +(*** Interface for bounded arithmetic *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Notations. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Definition size := nat. + +Local Coercion Z.of_nat : nat >-> Z. + +Class ArchitectureBoundedOps (n : size) := + { BoundedType : Type (* [n]-bit word *); + decode : BoundedType -> Z; + encode : Z -> BoundedType; + ShiftRight : forall a : size, BoundedType * BoundedType -> BoundedType; + (** given [(high, low)], constructs [(high << (n - a)) + (low >> + a)], i.e., shifts [high * 2ⁿ + low] down by [a] bits *) + ShiftLeft : forall a : size, BoundedType -> BoundedType * BoundedType; + (** given [x], constructs [(((x << a) / 2ⁿ) mod 2ⁿ, (x << a) mod + 2ⁿ], i.e., shifts [x] up by [a] bits, and takes the low [2n] + bits of the result *) + Mod2Pow : forall a : size, BoundedType -> BoundedType (* [mod 2ᵃ] *); + CarryAdd : forall (carry : bool) (x y : BoundedType), bool * BoundedType; + (** Ouputs [(x + y + if carry then 1 else 0) mod 2ⁿ], together + with a boolean that's [true] if the sum is ≥ 2ⁿ, and [false] + if there is no carry *) + CarrySub : forall (carry : bool) (x y : BoundedType), bool * BoundedType; + (** Ouputs [(x - y - if carry then 1 else 0) mod 2ⁿ], together + with a boolean that's [true] if the sum is negative, and [false] + if there is no borrow *) }. + +Inductive BoundedHalfType {n} {ops : ArchitectureBoundedOps n} := +| UpperHalf (_ : BoundedType) +| LowerHalf (_ : BoundedType). + +Definition UnderlyingBounded {n} {ops : ArchitectureBoundedOps n} (x : BoundedHalfType) + := match x with + | UpperHalf v => v + | LowerHalf v => v + end. + +Definition decode_half {n_over_two : size} {ops : ArchitectureBoundedOps (2 * n_over_two)%nat} (x : BoundedHalfType) : Z + := match x with + | UpperHalf v => decode v / 2^n_over_two + | LowerHalf v => (decode v) mod 2^n_over_two + end. + +Class ArchitectureBoundedFullMulOps n {ops : ArchitectureBoundedOps n} := + { Mul : BoundedType -> BoundedType -> BoundedType * BoundedType + (** Outputs [(high, low)] *) }. +Class ArchitectureBoundedHalfWidthMulOps n {ops : ArchitectureBoundedOps n} := + { HalfWidthMul : BoundedHalfType -> BoundedHalfType -> BoundedType }. + +Class ArchitectureBoundedProperties {n} (ops : ArchitectureBoundedOps n) := + { bounded_valid : BoundedType -> Prop; + decode_valid : forall v, + bounded_valid v + -> 0 <= decode v < 2^n; + encode_valid : forall z, + 0 <= z < 2^n + -> bounded_valid (encode z); + encode_correct : forall z, + 0 <= z < 2^n + -> decode (encode z) = z; + ShiftRight_valid : forall a high_low, + bounded_valid (fst high_low) -> bounded_valid (snd high_low) + -> bounded_valid (ShiftRight a high_low); + ShiftRight_correct : forall a high_low, + bounded_valid (fst high_low) -> bounded_valid (snd high_low) + -> decode (ShiftRight a high_low) = (decode (fst high_low) * 2^n + decode (snd high_low)) / 2^a; + ShiftLeft_fst_valid : forall a v, + bounded_valid v + -> bounded_valid (fst (ShiftLeft a v)); + ShiftLeft_snd_valid : forall a v, + bounded_valid v + -> bounded_valid (snd (ShiftLeft a v)); + ShiftLeft_fst_correct : forall a v, + bounded_valid v + -> decode (fst (ShiftLeft a v)) = (decode v * 2^a) mod 2^n; + ShiftLeft_snd_correct : forall a v, + bounded_valid v + -> decode (snd (ShiftLeft a v)) = ((decode v * 2^a) / 2^n) mod 2^n; + Mod2Pow_valid : forall a v, + bounded_valid v + -> bounded_valid (Mod2Pow a v); + Mod2Pow_correct : forall a v, + bounded_valid v + -> decode (Mod2Pow a v) = (decode v) mod 2^a; + CarryAdd_valid : forall c x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (snd (CarryAdd c x y)); + CarryAdd_fst_correct : forall c x y, + bounded_valid x -> bounded_valid y + -> fst (CarryAdd c x y) = (2^n <=? (decode x + decode y + if c then 1 else 0)); + CarryAdd_snd_correct : forall c x y, + bounded_valid x -> bounded_valid y + -> decode (snd (CarryAdd c x y)) = (decode x + decode y + if c then 1 else 0) mod 2^n; + CarrySub_valid : forall c x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (snd (CarrySub c x y)); + CarrySub_fst_correct : forall c x y, + bounded_valid x -> bounded_valid y + -> fst (CarrySub c x y) = ((decode x - decode y - if c then 1 else 0) bounded_valid y + -> decode (snd (CarrySub c x y)) = (decode x - decode y - if c then 1 else 0) mod 2^n }. + +Class ArchitectureBoundedFullMulProperties {n ops} (mops : @ArchitectureBoundedFullMulOps n ops) {props : ArchitectureBoundedProperties ops} := + { Mul_fst_valid : forall x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (fst (Mul x y)); + Mul_snd_valid : forall x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (snd (Mul x y)); + Mul_high_correct : forall x y, + bounded_valid x -> bounded_valid y + -> decode (fst (Mul x y)) = (decode x * decode y) / 2^n; + Mul_low_correct : forall x y, + bounded_valid x -> bounded_valid y + -> decode (snd (Mul x y)) = (decode x * decode y) mod 2^n }. + +Class ArchitectureBoundedHalfWidthMulProperties {n_over_two ops} (mops : @ArchitectureBoundedHalfWidthMulOps (2 * n_over_two)%nat ops) {props : ArchitectureBoundedProperties ops} := + { HalfWidthMul_valid : forall x y, + bounded_valid (UnderlyingBounded x) -> bounded_valid (UnderlyingBounded y) + -> bounded_valid (HalfWidthMul x y); + HalfWidthMul_correct : forall x y, + bounded_valid (UnderlyingBounded x) -> bounded_valid (UnderlyingBounded y) + -> decode (HalfWidthMul x y) = (decode_half x * decode_half y)%Z }. -- cgit v1.2.3 From 07b18ae2cb1122f395bffdf706ad37248bc5d4dc Mon Sep 17 00:00:00 2001 From: Andres Erbsen Date: Fri, 12 Aug 2016 11:21:45 -0400 Subject: alternative machine interface specification proposal --- src/BoundedArithmetic/ArchitectureToZLike.v | 98 +++++++++++++- src/BoundedArithmetic/DoubleBounded.v | 114 +++------------- src/BoundedArithmetic/Interface.v | 193 ++++++++++++---------------- 3 files changed, 190 insertions(+), 215 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index 6c92f342f..ebb21bd4b 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -3,6 +3,8 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BoundedArithmetic.DoubleBounded. Require Import Crypto.ModularArithmetic.ZBounded. +Require Import Coq.Lists.List. +Import ListNotations. Local Open Scope nat_scope. Local Open Scope Z_scope. @@ -10,13 +12,97 @@ Local Open Scope type_scope. Local Coercion Z.of_nat : nat >-> Z. -Local Existing Instances DoubleArchitectureBoundedOps DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps. - -Section ops. +Section fancy_machine_p256_montgomery_foundation. Context {n_over_two : size}. Local Notation n := (2 * n_over_two)%nat. - Context (ops : ArchitectureBoundedOps n) (mops : ArchitectureBoundedHalfWidthMulOps n) - (modulus : Z). + Context (ops : fancy_machine.instructions n) (modulus : Z). + + Definition two_list_to_tuple {A B} (x : A * list B) + := match x return match x with + | (a, [b0; b1]) => A * (B * B) + | _ => True + end + with + | (a, [b0; b1]) => (a, (b0, b1)) + | _ => I + end. + + (* make all machine-specific constructions here, preferrably as + thing wrappers around generic constructions *) + Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat + := { BoundedType := BoundedType * BoundedType (* [(high, low)] *); + decode high_low := (decode (fst high_low) * 2^n + decode (snd high_low))%Z; + encode z := (encode (z / 2^n), encode (z mod 2^n)); + ShiftRight a high_low + := let '(high, low) := eta high_low in + if n <=? a then + (ShiftRight (a - n)%nat (encode 0, fst high), ShiftRight (a - n)%nat high) + else + (ShiftRight a (snd high, fst low), ShiftRight a low); + ShiftLeft a high_low + := let '(high, low) := eta high_low in + if 2 * n <=? a then + let '(high0, low) := eta (ShiftLeft (a - 2 * n)%nat low) in + let '(high_high, high1) := eta (ShiftLeft (a - 2 * n)%nat high) in + ((snd (CarryAdd false high0 high1), low), (encode 0, encode 0)) + else if n <=? a then + let '(high0, low) := eta (ShiftLeft (a - n)%nat low) in + let '(high_high, high1) := eta (ShiftLeft (a - n)%nat high) in + ((high_high, snd (CarryAdd false high0 high1)), (low, encode 0)) + else + let '(high0, low) := eta (ShiftLeft a low) in + let '(high_high, high1) := eta (ShiftLeft a high) in + ((encode 0, high_high), (snd (CarryAdd false high0 high1), low)); + Mod2Pow a high_low + := let '(high, low) := (fst high_low, snd high_low) in + (Mod2Pow (a - n)%nat high, Mod2Pow a low); + CarryAdd carry x_high_low y_high_low + := let '(xhigh, xlow) := eta x_high_low in + let '(yhigh, ylow) := eta y_high_low in + two_list_to_tuple (ripple_carry CarryAdd carry [xhigh; xlow] [yhigh; ylow]); + CarrySub carry x_high_low y_high_low + := let '(xhigh, xlow) := eta x_high_low in + let '(yhigh, ylow) := eta y_high_low in + two_list_to_tuple (ripple_carry CarrySub carry [xhigh; xlow] [yhigh; ylow]) }. + + Definition BoundedOfHalfBounded (x : @BoundedHalfType (2 * n)%nat _) : @BoundedType n _ + := match x with + | UpperHalf x => fst x + | LowerHalf x => snd x + end. + + Local Instance DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps + {base_mops : ArchitectureBoundedFullMulOps n} + : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat := + { HalfWidthMul a b + := Mul (BoundedOfHalfBounded a) (BoundedOfHalfBounded b) }. + End single. + + Local Existing Instance DoubleArchitectureBoundedOps. + + Section full_from_half. + Context (n : size) {base_ops : ArchitectureBoundedOps (2 * n)%nat}. + + Local Infix "*" := HalfWidthMul. + + Local Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps + {base_mops : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat} + : ArchitectureBoundedFullMulOps (2 * n)%nat := + { Mul a b + := let '(a1, a0) := (UpperHalf a, LowerHalf a) in + let '(b1, b0) := (UpperHalf b, LowerHalf b) in + let out := a0 * b0 in + let outHigh := a1 * b1 in + let tmp := a1 * b0 in + let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in + let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in + let tmp := a0 * b1 in + let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in + let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in + (outHigh, out) }. + End full_from_half. + + Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. Axiom admit : forall {T}, T. @@ -36,3 +122,5 @@ Section ops. ConditionalSubtract b x := admit; ConditionalSubtractModulus y := admit }. End ops. + +End with_single_ops. \ No newline at end of file diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 59d961d4a..0d9c8e860 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -13,102 +13,18 @@ Local Open Scope type_scope. Local Coercion Z.of_nat : nat >-> Z. Local Notation eta x := (fst x, snd x). -Section with_single_ops. - Section single. - Context (n : size) {base_ops : ArchitectureBoundedOps n}. - - Definition ripple_carry {T} (f : bool -> T -> T -> bool * T) - (carry : bool) (xs ys : list T) : bool * list T - := List.fold_right - (fun x_y carry_zs => let '(x, y) := eta x_y in - let '(carry, zs) := eta carry_zs in - let '(carry, z) := eta (f carry x y) in - (carry, z :: zs)) - (carry, nil) - (List.combine xs ys). - - Definition two_list_to_tuple {A B} (x : A * list B) - := match x return match x with - | (a, [b0; b1]) => A * (B * B) - | _ => True - end - with - | (a, [b0; b1]) => (a, (b0, b1)) - | _ => I - end. - - Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat - := { BoundedType := BoundedType * BoundedType (* [(high, low)] *); - decode high_low := (decode (fst high_low) * 2^n + decode (snd high_low))%Z; - encode z := (encode (z / 2^n), encode (z mod 2^n)); - ShiftRight a high_low - := let '(high, low) := eta high_low in - if n <=? a then - (ShiftRight (a - n)%nat (encode 0, fst high), ShiftRight (a - n)%nat high) - else - (ShiftRight a (snd high, fst low), ShiftRight a low); - ShiftLeft a high_low - := let '(high, low) := eta high_low in - if 2 * n <=? a then - let '(high0, low) := eta (ShiftLeft (a - 2 * n)%nat low) in - let '(high_high, high1) := eta (ShiftLeft (a - 2 * n)%nat high) in - ((snd (CarryAdd false high0 high1), low), (encode 0, encode 0)) - else if n <=? a then - let '(high0, low) := eta (ShiftLeft (a - n)%nat low) in - let '(high_high, high1) := eta (ShiftLeft (a - n)%nat high) in - ((high_high, snd (CarryAdd false high0 high1)), (low, encode 0)) - else - let '(high0, low) := eta (ShiftLeft a low) in - let '(high_high, high1) := eta (ShiftLeft a high) in - ((encode 0, high_high), (snd (CarryAdd false high0 high1), low)); - Mod2Pow a high_low - := let '(high, low) := (fst high_low, snd high_low) in - (Mod2Pow (a - n)%nat high, Mod2Pow a low); - CarryAdd carry x_high_low y_high_low - := let '(xhigh, xlow) := eta x_high_low in - let '(yhigh, ylow) := eta y_high_low in - two_list_to_tuple (ripple_carry CarryAdd carry [xhigh; xlow] [yhigh; ylow]); - CarrySub carry x_high_low y_high_low - := let '(xhigh, xlow) := eta x_high_low in - let '(yhigh, ylow) := eta y_high_low in - two_list_to_tuple (ripple_carry CarrySub carry [xhigh; xlow] [yhigh; ylow]) }. - - Definition BoundedOfHalfBounded (x : @BoundedHalfType (2 * n)%nat _) : @BoundedType n _ - := match x with - | UpperHalf x => fst x - | LowerHalf x => snd x - end. - - Local Instance DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps - {base_mops : ArchitectureBoundedFullMulOps n} - : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat := - { HalfWidthMul a b - := Mul (BoundedOfHalfBounded a) (BoundedOfHalfBounded b) }. - End single. - - Local Existing Instance DoubleArchitectureBoundedOps. - - Section full_from_half. - Context (n : size) {base_ops : ArchitectureBoundedOps (2 * n)%nat}. - - Local Infix "*" := HalfWidthMul. - - Local Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps - {base_mops : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat} - : ArchitectureBoundedFullMulOps (2 * n)%nat := - { Mul a b - := let '(a1, a0) := (UpperHalf a, LowerHalf a) in - let '(b1, b0) := (UpperHalf b, LowerHalf b) in - let out := a0 * b0 in - let outHigh := a1 * b1 in - let tmp := a1 * b0 in - let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in - let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in - let tmp := a0 * b1 in - let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in - let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in - (outHigh, out) }. - End full_from_half. - - Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. -End with_single_ops. +Section generic_constructions. + Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) + (carry : bool) (xs ys : list T) : bool * list T + := List.fold_right + (fun x_y carry_zs => let '(x, y) := eta x_y in + let '(carry, zs) := eta carry_zs in + let '(carry, z) := eta (f x y carry) in + (carry, z :: zs)) + (carry, nil) + (List.combine xs ys). + + (* TODO: Would it made sense to make generic-width shift operations here? *) + + (* FUTURE: here go proofs about [ripple_carry] with [f] that satisfies [is_add_with_carry] *) +End generic_constructions. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index e9ade7ad4..f7906ccb5 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -11,121 +11,92 @@ Definition size := nat. Local Coercion Z.of_nat : nat >-> Z. -Class ArchitectureBoundedOps (n : size) := - { BoundedType : Type (* [n]-bit word *); - decode : BoundedType -> Z; - encode : Z -> BoundedType; - ShiftRight : forall a : size, BoundedType * BoundedType -> BoundedType; - (** given [(high, low)], constructs [(high << (n - a)) + (low >> - a)], i.e., shifts [high * 2ⁿ + low] down by [a] bits *) - ShiftLeft : forall a : size, BoundedType -> BoundedType * BoundedType; - (** given [x], constructs [(((x << a) / 2ⁿ) mod 2ⁿ, (x << a) mod - 2ⁿ], i.e., shifts [x] up by [a] bits, and takes the low [2n] - bits of the result *) - Mod2Pow : forall a : size, BoundedType -> BoundedType (* [mod 2ᵃ] *); - CarryAdd : forall (carry : bool) (x y : BoundedType), bool * BoundedType; - (** Ouputs [(x + y + if carry then 1 else 0) mod 2ⁿ], together - with a boolean that's [true] if the sum is ≥ 2ⁿ, and [false] - if there is no carry *) - CarrySub : forall (carry : bool) (x y : BoundedType), bool * BoundedType; - (** Ouputs [(x - y - if carry then 1 else 0) mod 2ⁿ], together - with a boolean that's [true] if the sum is negative, and [false] - if there is no borrow *) }. +Section InstructionGallery. + Context (n:Z) (* bit-width of width of W *) + {W : Type}(* previously [BoundedType], [W] for word *) + (decode : W -> Z). + Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) -Inductive BoundedHalfType {n} {ops : ArchitectureBoundedOps n} := -| UpperHalf (_ : BoundedType) -| LowerHalf (_ : BoundedType). + Class is_load_immediate (ldi:imm->W) := + decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. -Definition UnderlyingBounded {n} {ops : ArchitectureBoundedOps n} (x : BoundedHalfType) - := match x with - | UpperHalf v => v - | LowerHalf v => v - end. + Class is_shift_right_doubleword_immediate (shrd:W->W->imm->W) := + decode_shift_right_doubleword : + forall high low count, + 0 <= count < n + -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n. -Definition decode_half {n_over_two : size} {ops : ArchitectureBoundedOps (2 * n_over_two)%nat} (x : BoundedHalfType) : Z - := match x with - | UpperHalf v => decode v / 2^n_over_two - | LowerHalf v => (decode v) mod 2^n_over_two - end. + Class is_shift_left_immediate (shl:W->imm->W) := + decode_shift_left_immediate : + forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. -Class ArchitectureBoundedFullMulOps n {ops : ArchitectureBoundedOps n} := - { Mul : BoundedType -> BoundedType -> BoundedType * BoundedType - (** Outputs [(high, low)] *) }. -Class ArchitectureBoundedHalfWidthMulOps n {ops : ArchitectureBoundedOps n} := - { HalfWidthMul : BoundedHalfType -> BoundedHalfType -> BoundedType }. + Class is_spread_left_immediate (sprl:W->imm->W*W(*high, low*)) := + { + fst_spread_left_immediate : forall r count, 0 <= count < n -> + decode (fst (sprl r count)) = (decode r << count) >> n; + snd_spread_left_immediate : forall r count, 0 <= count < n -> + decode (snd (sprl r count)) = (decode r << count) mod 2^n + }. -Class ArchitectureBoundedProperties {n} (ops : ArchitectureBoundedOps n) := - { bounded_valid : BoundedType -> Prop; - decode_valid : forall v, - bounded_valid v - -> 0 <= decode v < 2^n; - encode_valid : forall z, - 0 <= z < 2^n - -> bounded_valid (encode z); - encode_correct : forall z, - 0 <= z < 2^n - -> decode (encode z) = z; - ShiftRight_valid : forall a high_low, - bounded_valid (fst high_low) -> bounded_valid (snd high_low) - -> bounded_valid (ShiftRight a high_low); - ShiftRight_correct : forall a high_low, - bounded_valid (fst high_low) -> bounded_valid (snd high_low) - -> decode (ShiftRight a high_low) = (decode (fst high_low) * 2^n + decode (snd high_low)) / 2^a; - ShiftLeft_fst_valid : forall a v, - bounded_valid v - -> bounded_valid (fst (ShiftLeft a v)); - ShiftLeft_snd_valid : forall a v, - bounded_valid v - -> bounded_valid (snd (ShiftLeft a v)); - ShiftLeft_fst_correct : forall a v, - bounded_valid v - -> decode (fst (ShiftLeft a v)) = (decode v * 2^a) mod 2^n; - ShiftLeft_snd_correct : forall a v, - bounded_valid v - -> decode (snd (ShiftLeft a v)) = ((decode v * 2^a) / 2^n) mod 2^n; - Mod2Pow_valid : forall a v, - bounded_valid v - -> bounded_valid (Mod2Pow a v); - Mod2Pow_correct : forall a v, - bounded_valid v - -> decode (Mod2Pow a v) = (decode v) mod 2^a; - CarryAdd_valid : forall c x y, - bounded_valid x -> bounded_valid y - -> bounded_valid (snd (CarryAdd c x y)); - CarryAdd_fst_correct : forall c x y, - bounded_valid x -> bounded_valid y - -> fst (CarryAdd c x y) = (2^n <=? (decode x + decode y + if c then 1 else 0)); - CarryAdd_snd_correct : forall c x y, - bounded_valid x -> bounded_valid y - -> decode (snd (CarryAdd c x y)) = (decode x + decode y + if c then 1 else 0) mod 2^n; - CarrySub_valid : forall c x y, - bounded_valid x -> bounded_valid y - -> bounded_valid (snd (CarrySub c x y)); - CarrySub_fst_correct : forall c x y, - bounded_valid x -> bounded_valid y - -> fst (CarrySub c x y) = ((decode x - decode y - if c then 1 else 0) bounded_valid y - -> decode (snd (CarrySub c x y)) = (decode x - decode y - if c then 1 else 0) mod 2^n }. + Class is_mask_keep_low (mkl:W->imm->W) := + decode_mask_keep_low : forall r count, + 0 <= count < n -> decode (mkl r count) = decode r mod 2^count. -Class ArchitectureBoundedFullMulProperties {n ops} (mops : @ArchitectureBoundedFullMulOps n ops) {props : ArchitectureBoundedProperties ops} := - { Mul_fst_valid : forall x y, - bounded_valid x -> bounded_valid y - -> bounded_valid (fst (Mul x y)); - Mul_snd_valid : forall x y, - bounded_valid x -> bounded_valid y - -> bounded_valid (snd (Mul x y)); - Mul_high_correct : forall x y, - bounded_valid x -> bounded_valid y - -> decode (fst (Mul x y)) = (decode x * decode y) / 2^n; - Mul_low_correct : forall x y, - bounded_valid x -> bounded_valid y - -> decode (snd (Mul x y)) = (decode x * decode y) mod 2^n }. + Local Notation bit b := (if b then 1 else 0). + Class is_add_with_carry (adc:W->W->bool->bool*W) := + { + fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; + snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) + }. -Class ArchitectureBoundedHalfWidthMulProperties {n_over_two ops} (mops : @ArchitectureBoundedHalfWidthMulOps (2 * n_over_two)%nat ops) {props : ArchitectureBoundedProperties ops} := - { HalfWidthMul_valid : forall x y, - bounded_valid (UnderlyingBounded x) -> bounded_valid (UnderlyingBounded y) - -> bounded_valid (HalfWidthMul x y); - HalfWidthMul_correct : forall x y, - bounded_valid (UnderlyingBounded x) -> bounded_valid (UnderlyingBounded y) - -> decode (HalfWidthMul x y) = (decode_half x * decode_half y)%Z }. + Class is_sub_with_carry (subc:W->W->bool->bool*W) := + { + fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) W->W) := + decode_mul : forall x y, decode (mul x y) = decode x * decode y mod 2^n. + + Class is_mul_low_low (w:Z) (mulhwll:W->W->W) := + decode_mul_low_low : + forall x y, decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. + Class is_mul_high_low (w:Z) (mulhwhl:W->W->W) := + decode_mul_high_low : + forall x y, decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n. + Class is_mul_high_high (w:Z) (mulhwhh:W->W->W) := + decode_mul_high_high : + forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. +End InstructionGallery. + +Module fancy_machine. + Local Notation imm := Z (only parsing). + Class instructions (n:Z) := + { + W : Type (* [n]-bit word *); + ldi : imm -> W; + shrd : W->W->imm -> W; + sprl : W->imm -> W*W; + mkl : W->imm -> W; + adc : W->W->bool -> bool*W; + subc : W->W->bool -> bool*W + }. + + Class arithmetic {n} (ops:instructions n) := + { + decode : W -> Z; + decode_range : forall x, 0 <= decode x < 2^n; + load_immediate : is_load_immediate n decode ldi; + shift_right_doubleword_immediate : is_shift_right_doubleword_immediate n decode shrd; + spread_left_immediate : is_spread_left_immediate n decode sprl; + mask_keep_low : is_mask_keep_low n decode mkl; + add_with_carry : is_add_with_carry n decode adc; + sub_with_carry : is_sub_with_carry n decode subc + }. + Global Existing Instance load_immediate. + Global Existing Instance shift_right_doubleword_immediate. + Global Existing Instance spread_left_immediate. + Global Existing Instance mask_keep_low. + Global Existing Instance add_with_carry. + Global Existing Instance sub_with_carry. +End fancy_machine. \ No newline at end of file -- cgit v1.2.3 From dc295c74a191d2ad9ab56a4792391a4c68a42e5d Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 12 Aug 2016 11:45:08 -0700 Subject: Rework interface to support rewriting database --- src/BoundedArithmetic/ArchitectureToZLike.v | 35 ++-- src/BoundedArithmetic/DoubleBounded.v | 14 +- src/BoundedArithmetic/Interface.v | 240 ++++++++++++++++++++++------ 3 files changed, 222 insertions(+), 67 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index ebb21bd4b..01387e969 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -13,8 +13,8 @@ Local Open Scope type_scope. Local Coercion Z.of_nat : nat >-> Z. Section fancy_machine_p256_montgomery_foundation. - Context {n_over_two : size}. - Local Notation n := (2 * n_over_two)%nat. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two)%Z. Context (ops : fancy_machine.instructions n) (modulus : Z). Definition two_list_to_tuple {A B} (x : A * list B) @@ -26,7 +26,7 @@ Section fancy_machine_p256_montgomery_foundation. | (a, [b0; b1]) => (a, (b0, b1)) | _ => I end. - +(* (* make all machine-specific constructions here, preferrably as thing wrappers around generic constructions *) Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat @@ -103,24 +103,23 @@ Section fancy_machine_p256_montgomery_foundation. End full_from_half. Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. - +*) Axiom admit : forall {T}, T. - Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : size) + Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := - { LargeT := @BoundedType (2 * n)%nat _; - SmallT := @BoundedType n _; - modulus_digits := encode modulus; - decode_large := decode; + { LargeT := fancy_machine.W * fancy_machine.W; + SmallT := fancy_machine.W; + modulus_digits := ldi modulus; + decode_large := _; decode_small := decode; Mod_SmallBound v := snd v; DivBy_SmallBound v := fst v; - DivBy_SmallerBound v := ShiftRight smaller_bound_exp v; - Mul x y := @Interface.Mul n _ _ x y; - CarryAdd x y := Interface.CarryAdd false x y; - CarrySubSmall x y := Interface.CarrySub false x y; - ConditionalSubtract b x := admit; - ConditionalSubtractModulus y := admit }. -End ops. - -End with_single_ops. \ No newline at end of file + DivBy_SmallerBound v := shrd (fst v) (snd v) smaller_bound_exp; + Mul x y := _ (*mulhwll (ldi 0, x) (ldi 0, y)*); + CarryAdd x y := _ (*adc x y false*); + CarrySubSmall x y := subc x y false; + ConditionalSubtract b x := let v := selc b (ldi 0) (ldi modulus) in snd (subc x v false); + ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }. + Abort. +End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 0d9c8e860..c09006e2b 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -1,6 +1,8 @@ (*** Implementing Large Bounded Arithmetic via pairs *) Require Import Coq.ZArith.ZArith Coq.Lists.List. Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BaseSystem. +Require Import Crypto.BaseSystemProofs. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. @@ -15,7 +17,7 @@ Local Notation eta x := (fst x, snd x). Section generic_constructions. Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) - (carry : bool) (xs ys : list T) : bool * list T + (xs ys : list T) (carry : bool) : bool * list T := List.fold_right (fun x_y carry_zs => let '(x, y) := eta x_y in let '(carry, zs) := eta carry_zs in @@ -24,6 +26,16 @@ Section generic_constructions. (carry, nil) (List.combine xs ys). + Section ripple_carry_adc. + Context {n W} {decode : decoder n W} (adc : add_with_carry W). + + Global Instance ripple_carry_add_with_carry : add_with_carry (list W) + := {| Interface.adc := ripple_carry adc |}. + (* + Global Instance ripple_carry_is_add_with_carry {is_adc : is_add_with_carry adc} + : is_add_with_carry ripple_carry_add_with_carry.*) + End ripple_carry_adc. + (* TODO: Would it made sense to make generic-width shift operations here? *) (* FUTURE: here go proofs about [ripple_carry] with [f] that satisfies [is_add_with_carry] *) diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index f7906ccb5..cda72967c 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -3,100 +3,244 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. -Local Open Scope nat_scope. Local Open Scope Z_scope. Local Open Scope type_scope. -Definition size := nat. +Create HintDb push_decode discriminated. +Create HintDb pull_decode discriminated. +Hint Extern 1 => progress autorewrite with push_decode in * : push_decode. +Hint Extern 1 => progress autorewrite with pull_decode in * : pull_decode. -Local Coercion Z.of_nat : nat >-> Z. +Class decoder (n : Z) W := + { decode : W -> Z }. +Coercion decode : decoder >-> Funclass. +Global Arguments decode {n W _} _. + +Class is_decode {n W} (decode : decoder n W) := + decode_range : forall x, 0 <= decode x < 2^n. Section InstructionGallery. - Context (n:Z) (* bit-width of width of W *) - {W : Type}(* previously [BoundedType], [W] for word *) - (decode : W -> Z). + Context (n : Z) (* bit-width of width of [W] *) + {W : Type} (* bounded type, [W] for word *) + (Wdecoder : decoder n W). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) - Class is_load_immediate (ldi:imm->W) := + Record load_immediate := { ldi :> imm -> W }. + + Class is_load_immediate {ldi : load_immediate} := decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. - Class is_shift_right_doubleword_immediate (shrd:W->W->imm->W) := + Record shift_right_doubleword_immediate := { shrd :> W -> W -> imm -> W }. + + Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := decode_shift_right_doubleword : forall high low count, 0 <= count < n -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n. - Class is_shift_left_immediate (shl:W->imm->W) := + Record shift_left_immediate := { shl :> W -> imm -> W }. + + Class is_shift_left_immediate (shl : shift_left_immediate) := decode_shift_left_immediate : forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. - Class is_spread_left_immediate (sprl:W->imm->W*W(*high, low*)) := + Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }. + + Class is_spread_left_immediate (sprl : spread_left_immediate) := { - fst_spread_left_immediate : forall r count, 0 <= count < n -> - decode (fst (sprl r count)) = (decode r << count) >> n; - snd_spread_left_immediate : forall r count, 0 <= count < n -> - decode (snd (sprl r count)) = (decode r << count) mod 2^n + decode_fst_spread_left_immediate : forall r count, + 0 <= count < n + -> decode (fst (sprl r count)) = (decode r << count) >> n; + decode_snd_spread_left_immediate : forall r count, + 0 <= count < n + -> decode (snd (sprl r count)) = (decode r << count) mod 2^n }. - Class is_mask_keep_low (mkl:W->imm->W) := + Record mask_keep_low := { mkl :> W -> imm -> W }. + + Class is_mask_keep_low (mkl : mask_keep_low) := decode_mask_keep_low : forall r count, 0 <= count < n -> decode (mkl r count) = decode r mod 2^count. Local Notation bit b := (if b then 1 else 0). - Class is_add_with_carry (adc:W->W->bool->bool*W) := + + Record add_with_carry := { adc :> W -> W -> bool -> bool * W }. + + Class is_add_with_carry (adc : add_with_carry) := { - fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; - snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) + bit_fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) }. + Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. + Class is_sub_with_carry (subc:W->W->bool->bool*W) := { - fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) W->W) := - decode_mul : forall x y, decode (mul x y) = decode x * decode y mod 2^n. + Record multiply := { mul :> W -> W -> W }. + + Class is_mul (mul : multiply) := + decode_mul : forall x y, decode (mul x y) = (decode x * decode y) mod 2^n. + + Record multiply_low_low := { mulhwll :> W -> W -> W }. + Record multiply_high_low := { mulhwhl :> W -> W -> W }. + Record multiply_high_high := { mulhwhh :> W -> W -> W }. - Class is_mul_low_low (w:Z) (mulhwll:W->W->W) := + Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := decode_mul_low_low : forall x y, decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. - Class is_mul_high_low (w:Z) (mulhwhl:W->W->W) := + Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := decode_mul_high_low : forall x y, decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n. - Class is_mul_high_high (w:Z) (mulhwhh:W->W->W) := + Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := decode_mul_high_high : forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. + + Record select_conditional := { selc :> bool -> W -> W -> W }. + + Class is_select_conditional (selc : select_conditional) := + decode_select_conditional : forall b x y, + decode (selc b x y) = if b then decode x else decode y. + + Record add_modulo := { addm :> W -> W -> W (* modulus *) -> W }. + + Class is_add_modulo (addm : add_modulo) := + decode_add_modulo : forall x y modulus, + decode (addm x y modulus) = (decode x + decode y) mod (decode modulus). End InstructionGallery. +Global Arguments load_immediate : clear implicits. +Global Arguments shift_right_doubleword_immediate : clear implicits. +Global Arguments shift_left_immediate : clear implicits. +Global Arguments spread_left_immediate : clear implicits. +Global Arguments mask_keep_low : clear implicits. +Global Arguments add_with_carry : clear implicits. +Global Arguments sub_with_carry : clear implicits. +Global Arguments multiply : clear implicits. +Global Arguments multiply_low_low : clear implicits. +Global Arguments multiply_high_low : clear implicits. +Global Arguments multiply_high_high : clear implicits. +Global Arguments select_conditional : clear implicits. +Global Arguments add_modulo : clear implicits. +Global Arguments ldi {_ _} _. +Global Arguments shrd {_ _} _ _ _. +Global Arguments shl {_ _} _ _. +Global Arguments sprl {_ _} _ _. +Global Arguments mkl {_ _} _ _. +Global Arguments adc {_ _} _ _ _. +Global Arguments subc {_ _} _ _ _. +Global Arguments mul {_ _} _ _. +Global Arguments mulhwll {_ _} _ _. +Global Arguments mulhwhl {_ _} _ _. +Global Arguments mulhwhh {_ _} _ _. +Global Arguments selc {_ _} _ _ _. +Global Arguments addm {_ _} _ _ _. + +Existing Class load_immediate. +Existing Class shift_right_doubleword_immediate. +Existing Class shift_left_immediate. +Existing Class spread_left_immediate. +Existing Class mask_keep_low. +Existing Class add_with_carry. +Existing Class sub_with_carry. +Existing Class multiply. +Existing Class multiply_low_low. +Existing Class multiply_high_low. +Existing Class multiply_high_high. +Existing Class select_conditional. +Existing Class add_modulo. + +Global Arguments is_decode {_ _} _. +Global Arguments is_load_immediate {_ _ _} _. +Global Arguments is_shift_right_doubleword_immediate {_ _ _} _. +Global Arguments is_shift_left_immediate {_ _ _} _. +Global Arguments is_spread_left_immediate {_ _ _} _. +Global Arguments is_mask_keep_low {_ _ _} _. +Global Arguments is_add_with_carry {_ _ _} _. +Global Arguments is_sub_with_carry {_ _ _} _. +Global Arguments is_mul {_ _ _} _. +Global Arguments is_mul_low_low {_ _ _} _ _. +Global Arguments is_mul_high_low {_ _ _} _ _. +Global Arguments is_mul_high_high {_ _ _} _ _. +Global Arguments is_select_conditional {_ _ _} _. +Global Arguments is_add_modulo {_ _ _} _. + +Ltac bounded_sovlver_tac := + solve [ eassumption | typeclasses eauto | omega | auto 6 using decode_range with typeclass_instances omega ]. + +Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo using bounded_sovlver_tac : push_decode. + +Ltac push_decode := + repeat first [ erewrite !decode_load_immediate by bounded_sovlver_tac + | erewrite !decode_shift_right_doubleword by bounded_sovlver_tac + | erewrite !decode_shift_left_immediate by bounded_sovlver_tac + | erewrite !decode_fst_spread_left_immediate by bounded_sovlver_tac + | erewrite !decode_snd_spread_left_immediate by bounded_sovlver_tac + | erewrite !decode_mask_keep_low by bounded_sovlver_tac + | erewrite !bit_fst_add_with_carry by bounded_sovlver_tac + | erewrite !decode_snd_add_with_carry by bounded_sovlver_tac + | erewrite !fst_sub_with_carry by bounded_sovlver_tac + | erewrite !decode_snd_sub_with_carry by bounded_sovlver_tac + | erewrite !decode_mul by bounded_sovlver_tac + | erewrite !decode_mul_low_low by bounded_sovlver_tac + | erewrite !decode_mul_high_low by bounded_sovlver_tac + | erewrite !decode_mul_high_high by bounded_sovlver_tac + | erewrite !decode_select_conditional by bounded_sovlver_tac + | erewrite !decode_add_modulo by bounded_sovlver_tac ]. +Ltac pull_decode := + repeat first [ erewrite <- !decode_load_immediate by bounded_sovlver_tac + | erewrite <- !decode_shift_right_doubleword by bounded_sovlver_tac + | erewrite <- !decode_shift_left_immediate by bounded_sovlver_tac + | erewrite <- !decode_fst_spread_left_immediate by bounded_sovlver_tac + | erewrite <- !decode_snd_spread_left_immediate by bounded_sovlver_tac + | erewrite <- !decode_mask_keep_low by bounded_sovlver_tac + | erewrite <- !bit_fst_add_with_carry by bounded_sovlver_tac + | erewrite <- !decode_snd_add_with_carry by bounded_sovlver_tac + | erewrite <- !fst_sub_with_carry by bounded_sovlver_tac + | erewrite <- !decode_snd_sub_with_carry by bounded_sovlver_tac + | erewrite <- !decode_mul by bounded_sovlver_tac + | erewrite <- !decode_mul_low_low by bounded_sovlver_tac + | erewrite <- !decode_mul_high_low by bounded_sovlver_tac + | erewrite <- !decode_mul_high_high by bounded_sovlver_tac + | erewrite <- !decode_select_conditional by bounded_sovlver_tac + | erewrite <- !decode_add_modulo by bounded_sovlver_tac ]. + Module fancy_machine. Local Notation imm := Z (only parsing). - Class instructions (n:Z) := + + Class instructions (n : Z) := { W : Type (* [n]-bit word *); - ldi : imm -> W; - shrd : W->W->imm -> W; - sprl : W->imm -> W*W; - mkl : W->imm -> W; - adc : W->W->bool -> bool*W; - subc : W->W->bool -> bool*W + decode :> decoder n W; + ldi :> load_immediate W; + shrd :> shift_right_doubleword_immediate W; + shl :> shift_left_immediate W; + mkl :> mask_keep_low W; + adc :> add_with_carry W; + subc :> sub_with_carry W; + mulhwll :> multiply_low_low W; + mulhwhl :> multiply_high_low W; + mulhwhh :> multiply_high_high W; + selc :> select_conditional W; + addm :> add_modulo W }. - Class arithmetic {n} (ops:instructions n) := + Class arithmetic {n_over_two} (ops:instructions (2 * n_over_two)) := { - decode : W -> Z; - decode_range : forall x, 0 <= decode x < 2^n; - load_immediate : is_load_immediate n decode ldi; - shift_right_doubleword_immediate : is_shift_right_doubleword_immediate n decode shrd; - spread_left_immediate : is_spread_left_immediate n decode sprl; - mask_keep_low : is_mask_keep_low n decode mkl; - add_with_carry : is_add_with_carry n decode adc; - sub_with_carry : is_sub_with_carry n decode subc + decode_range :> is_decode decode; + load_immediate :> is_load_immediate ldi; + shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd; + shift_left_immediate :> is_shift_left_immediate shl; + mask_keep_low :> is_mask_keep_low mkl; + add_with_carry :> is_add_with_carry adc; + sub_with_carry :> is_sub_with_carry subc; + multiply_low_low :> is_mul_low_low n_over_two mulhwll; + multiply_high_low :> is_mul_high_low n_over_two mulhwhl; + multiply_high_high :> is_mul_high_high n_over_two mulhwhh; + select_conditional :> is_select_conditional selc; + add_modulo :> is_add_modulo addm }. - Global Existing Instance load_immediate. - Global Existing Instance shift_right_doubleword_immediate. - Global Existing Instance spread_left_immediate. - Global Existing Instance mask_keep_low. - Global Existing Instance add_with_carry. - Global Existing Instance sub_with_carry. -End fancy_machine. \ No newline at end of file +End fancy_machine. -- cgit v1.2.3 From 975421705a8e1196cd04c2fc396284bdbd857de7 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 12 Aug 2016 14:06:18 -0700 Subject: Add TODO --- src/BoundedArithmetic/DoubleBounded.v | 1 + src/BoundedArithmetic/Interface.v | 9 +++++++++ 2 files changed, 10 insertions(+) (limited to 'src') diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index c09006e2b..7fa0d4db1 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -3,6 +3,7 @@ Require Import Coq.ZArith.ZArith Coq.Lists.List. Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BaseSystem. Require Import Crypto.BaseSystemProofs. +Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index cda72967c..4a14a160b 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -11,6 +11,15 @@ Create HintDb pull_decode discriminated. Hint Extern 1 => progress autorewrite with push_decode in * : push_decode. Hint Extern 1 => progress autorewrite with pull_decode in * : pull_decode. +(* TODO(from jgross): Try dropping the record wrappers. See + https://github.com/mit-plv/fiat-crypto/pull/52#discussion_r74627992 + and + https://github.com/mit-plv/fiat-crypto/pull/52#discussion_r74658417 + and + https://github.com/mit-plv/fiat-crypto/pull/52#issuecomment-239536847. + The wrappers are here to make [autorewrite] databases feasable and + fast, based on design patterns learned from past experience. There + might be better ways. *) Class decoder (n : Z) W := { decode : W -> Z }. Coercion decode : decoder >-> Funclass. -- cgit v1.2.3 From f9f012e80be8e1a49b8509c0f8c2410b32d53920 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 12 Aug 2016 15:37:28 -0700 Subject: Fix some things --- src/BoundedArithmetic/Interface.v | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 4a14a160b..6ad38288a 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -36,7 +36,7 @@ Section InstructionGallery. Record load_immediate := { ldi :> imm -> W }. - Class is_load_immediate {ldi : load_immediate} := + Class is_load_immediate {ldi : load_immediate} := decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. Record shift_right_doubleword_immediate := { shrd :> W -> W -> imm -> W }. @@ -83,7 +83,7 @@ Section InstructionGallery. Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. - Class is_sub_with_carry (subc:W->W->bool->bool*W) := + Class is_sub_with_carry (subc : sub_with_carry) := { fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) Date: Fri, 12 Aug 2016 14:53:00 -0700 Subject: Add _valid properties --- src/BoundedArithmetic/DoubleBounded.v | 33 +++++++ src/BoundedArithmetic/Interface.v | 171 +++++++++++++++++++++++++--------- 2 files changed, 160 insertions(+), 44 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 7fa0d4db1..ac3748f20 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -4,6 +4,7 @@ Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BaseSystem. Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. @@ -16,7 +17,39 @@ Local Open Scope type_scope. Local Coercion Z.of_nat : nat >-> Z. Local Notation eta x := (fst x, snd x). +(** TODO(jgross): Split off proofs *) Section generic_constructions. + Section decode. + Context {n W} {decode : decoder n W} {k : Z} + {validity : word_validity n W}. + Let limb_widths := repeat n (Z.to_nat k). + (** The list is low to high *) + Global Instance list_decoder : decoder (k * n) (list W) + := { decode w := BaseSystem.decode (base_from_limb_widths limb_widths) (List.map decode w) }. + + Global Instance list_word_validity + : word_validity (k * n) (list W) + := { word_valid w := 0 <= k + /\ List.length w = Z.to_nat k + /\ forall i v, nth_error w i = Some v -> word_valid (nth_default (List.fold_right and True (List.map word_valid w) }. + + Global Instance list_is_decode : is_decode list_decoder. + Proof. + unfold list_decoder, list_word_validity; hnf; simpl. + Print bounded. + SearchAbout (BaseSystem.decode _ _). + +About Z.shiftr_spec. + hnf. + simpl. + decode_range : forall x, word_valid x -> 0 <= decode x < 2^n. + + + + + + + Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) (xs ys : list T) (carry : bool) : bool * list T := List.fold_right diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 6ad38288a..ebd0f302f 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -25,51 +25,82 @@ Class decoder (n : Z) W := Coercion decode : decoder >-> Funclass. Global Arguments decode {n W _} _. -Class is_decode {n W} (decode : decoder n W) := - decode_range : forall x, 0 <= decode x < 2^n. +Class word_validity (n : Z) W := + { word_valid : W -> Prop }. +Existing Class word_valid. +Global Arguments word_valid {n W _} _. + +Class is_decode {n W} (decode : decoder n W) {validity : word_validity n W} := + decode_range : forall x, word_valid x -> 0 <= decode x < 2^n. Section InstructionGallery. Context (n : Z) (* bit-width of width of [W] *) {W : Type} (* bounded type, [W] for word *) - (Wdecoder : decoder n W). + (Wdecoder : decoder n W) + {validity : word_validity n W}. Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) Record load_immediate := { ldi :> imm -> W }. Class is_load_immediate {ldi : load_immediate} := - decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. + { + load_immediate_valid :> forall x, 0 <= x < 2^n -> word_valid (ldi x); + decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x + }. Record shift_right_doubleword_immediate := { shrd :> W -> W -> imm -> W }. Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := - decode_shift_right_doubleword : - forall high low count, - 0 <= count < n - -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n. + { + shift_right_doubleword_valid :> forall high low count, + word_valid high -> word_valid low -> 0 <= count < n + -> word_valid (shrd high low count); + decode_shift_right_doubleword : forall high low count, + word_valid high -> word_valid low -> 0 <= count < n + -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n + }. Record shift_left_immediate := { shl :> W -> imm -> W }. Class is_shift_left_immediate (shl : shift_left_immediate) := - decode_shift_left_immediate : - forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. + { + shift_left_immediate_valid :> forall r count, + word_valid r -> 0 <= count < n + -> word_valid (shl r count); + decode_shift_left_immediate : forall r count, + word_valid r -> 0 <= count < n + -> decode (shl r count) = (decode r << count) mod 2^n + }. Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }. Class is_spread_left_immediate (sprl : spread_left_immediate) := { + fst_spread_left_immediate_valid : forall r count, + word_valid r -> 0 <= count < n + -> word_valid (fst (sprl r count)); + snd_spread_left_immediate_valid : forall r count, + word_valid r -> 0 <= count < n + -> word_valid (snd (sprl r count)); decode_fst_spread_left_immediate : forall r count, - 0 <= count < n - -> decode (fst (sprl r count)) = (decode r << count) >> n; + word_valid r -> 0 <= count < n + -> decode (fst (sprl r count)) = (decode r << count) >> n; decode_snd_spread_left_immediate : forall r count, - 0 <= count < n + word_valid r -> 0 <= count < n -> decode (snd (sprl r count)) = (decode r << count) mod 2^n }. Record mask_keep_low := { mkl :> W -> imm -> W }. Class is_mask_keep_low (mkl : mask_keep_low) := - decode_mask_keep_low : forall r count, - 0 <= count < n -> decode (mkl r count) = decode r mod 2^count. + { + mask_keep_low_valid :> forall r count, + word_valid r -> 0 <= count < n + -> word_valid (mkl r count); + decode_mask_keep_low : forall r count, + word_valid r -> 0 <= count < n + -> decode (mkl r count) = decode r mod 2^count + }. Local Notation bit b := (if b then 1 else 0). @@ -77,48 +108,99 @@ Section InstructionGallery. Class is_add_with_carry (adc : add_with_carry) := { - bit_fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; - decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) + snd_add_with_carry_valid :> forall x y c, + word_valid x -> word_valid y + -> word_valid (snd (adc x y c)); + bit_fst_add_with_carry : forall x y c, + word_valid x -> word_valid y + -> bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry : forall x y c, + word_valid x -> word_valid y + -> decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) }. Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. Class is_sub_with_carry (subc : sub_with_carry) := { - fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) forall x y c, + word_valid x -> word_valid y + -> word_valid (snd (subc x y c)); + fst_sub_with_carry : forall x y c, + word_valid x -> word_valid y + -> fst (subc x y c) = ((decode x - decode y - bit c) word_valid y + -> decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n }. Record multiply := { mul :> W -> W -> W }. Class is_mul (mul : multiply) := - decode_mul : forall x y, decode (mul x y) = (decode x * decode y) mod 2^n. + { + mul_valid :> forall x y, + word_valid x -> word_valid y + -> word_valid (mul x y); + decode_mul : forall x y, + word_valid x -> word_valid y + -> decode (mul x y) = (decode x * decode y) mod 2^n + }. Record multiply_low_low := { mulhwll :> W -> W -> W }. Record multiply_high_low := { mulhwhl :> W -> W -> W }. Record multiply_high_high := { mulhwhh :> W -> W -> W }. Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := - decode_mul_low_low : - forall x y, decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. + { + mul_low_low_valid :> forall x y, + word_valid x -> word_valid y + -> word_valid (mulhwll x y); + decode_mul_low_low : forall x y, + word_valid x -> word_valid y + -> decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n + }. Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := - decode_mul_high_low : - forall x y, decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n. + { + mul_high_low_valid :> forall x y, + word_valid x -> word_valid y + -> word_valid (mulhwhl x y); + decode_mul_high_low : forall x y, + word_valid x -> word_valid y + -> decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n + }. Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := - decode_mul_high_high : - forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. + { + mul_high_high_valid :> forall x y, + word_valid x -> word_valid y + -> word_valid (mulhwhh x y); + decode_mul_high_high : forall x y, + word_valid x -> word_valid y + -> decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n + }. Record select_conditional := { selc :> bool -> W -> W -> W }. Class is_select_conditional (selc : select_conditional) := - decode_select_conditional : forall b x y, - decode (selc b x y) = if b then decode x else decode y. + { + select_conditional_valid : forall b x y, + word_valid x -> word_valid y + -> word_valid (selc b x y); + decode_select_conditional : forall b x y, + word_valid x -> word_valid y + -> decode (selc b x y) = if b then decode x else decode y + }. Record add_modulo := { addm :> W -> W -> W (* modulus *) -> W }. Class is_add_modulo (addm : add_modulo) := - decode_add_modulo : forall x y modulus, - decode (addm x y modulus) = (decode x + decode y) mod (decode modulus). + { + add_modulo_valid : forall x y modulus, + word_valid x -> word_valid y -> word_valid modulus + -> word_valid (addm x y modulus); + decode_add_modulo : forall x y modulus, + word_valid x -> word_valid y -> word_valid modulus + -> decode (addm x y modulus) = (decode x + decode y) mod (decode modulus) + }. End InstructionGallery. Global Arguments load_immediate : clear implicits. @@ -162,20 +244,20 @@ Existing Class multiply_high_high. Existing Class select_conditional. Existing Class add_modulo. -Global Arguments is_decode {_ _} _. -Global Arguments is_load_immediate {_ _ _} _. -Global Arguments is_shift_right_doubleword_immediate {_ _ _} _. -Global Arguments is_shift_left_immediate {_ _ _} _. -Global Arguments is_spread_left_immediate {_ _ _} _. -Global Arguments is_mask_keep_low {_ _ _} _. -Global Arguments is_add_with_carry {_ _ _} _. -Global Arguments is_sub_with_carry {_ _ _} _. -Global Arguments is_mul {_ _ _} _. -Global Arguments is_mul_low_low {_ _ _} _ _. -Global Arguments is_mul_high_low {_ _ _} _ _. -Global Arguments is_mul_high_high {_ _ _} _ _. -Global Arguments is_select_conditional {_ _ _} _. -Global Arguments is_add_modulo {_ _ _} _. +Global Arguments is_decode {_ _} _ {_}. +Global Arguments is_load_immediate {_ _ _ _} _. +Global Arguments is_shift_right_doubleword_immediate {_ _ _ _} _. +Global Arguments is_shift_left_immediate {_ _ _ _} _. +Global Arguments is_spread_left_immediate {_ _ _ _} _. +Global Arguments is_mask_keep_low {_ _ _ _} _. +Global Arguments is_add_with_carry {_ _ _ _} _. +Global Arguments is_sub_with_carry {_ _ _ _} _. +Global Arguments is_mul {_ _ _ _} _. +Global Arguments is_mul_low_low {_ _ _ _} _ _. +Global Arguments is_mul_high_low {_ _ _ _} _ _. +Global Arguments is_mul_high_high {_ _ _ _} _ _. +Global Arguments is_select_conditional {_ _ _ _} _. +Global Arguments is_add_modulo {_ _ _ _} _. Ltac bounded_sovlver_tac := solve [ eassumption | typeclasses eauto | omega | auto 6 using decode_range with typeclass_instances omega ]. @@ -239,6 +321,7 @@ Module fancy_machine. Class arithmetic {n_over_two} (ops:instructions (2 * n_over_two)) := { + word_validity :> word_validity (2 * n_over_two) W; decode_range :> is_decode decode; load_immediate :> is_load_immediate ldi; shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd; -- cgit v1.2.3 From e7554f5525a36699fff33e70ee454cfd0a687808 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 12 Aug 2016 15:35:43 -0700 Subject: Revert "Add _valid properties" This reverts commit 4e77295a689361876b3e45262f8908d1d98c0073. --- src/BoundedArithmetic/DoubleBounded.v | 33 ------- src/BoundedArithmetic/Interface.v | 175 +++++++++------------------------- 2 files changed, 46 insertions(+), 162 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index ac3748f20..7fa0d4db1 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -4,7 +4,6 @@ Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BaseSystem. Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.ModularArithmetic.Pow2BaseProofs. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. @@ -17,39 +16,7 @@ Local Open Scope type_scope. Local Coercion Z.of_nat : nat >-> Z. Local Notation eta x := (fst x, snd x). -(** TODO(jgross): Split off proofs *) Section generic_constructions. - Section decode. - Context {n W} {decode : decoder n W} {k : Z} - {validity : word_validity n W}. - Let limb_widths := repeat n (Z.to_nat k). - (** The list is low to high *) - Global Instance list_decoder : decoder (k * n) (list W) - := { decode w := BaseSystem.decode (base_from_limb_widths limb_widths) (List.map decode w) }. - - Global Instance list_word_validity - : word_validity (k * n) (list W) - := { word_valid w := 0 <= k - /\ List.length w = Z.to_nat k - /\ forall i v, nth_error w i = Some v -> word_valid (nth_default (List.fold_right and True (List.map word_valid w) }. - - Global Instance list_is_decode : is_decode list_decoder. - Proof. - unfold list_decoder, list_word_validity; hnf; simpl. - Print bounded. - SearchAbout (BaseSystem.decode _ _). - -About Z.shiftr_spec. - hnf. - simpl. - decode_range : forall x, word_valid x -> 0 <= decode x < 2^n. - - - - - - - Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) (xs ys : list T) (carry : bool) : bool * list T := List.fold_right diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index ebd0f302f..4a14a160b 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -25,82 +25,51 @@ Class decoder (n : Z) W := Coercion decode : decoder >-> Funclass. Global Arguments decode {n W _} _. -Class word_validity (n : Z) W := - { word_valid : W -> Prop }. -Existing Class word_valid. -Global Arguments word_valid {n W _} _. - -Class is_decode {n W} (decode : decoder n W) {validity : word_validity n W} := - decode_range : forall x, word_valid x -> 0 <= decode x < 2^n. +Class is_decode {n W} (decode : decoder n W) := + decode_range : forall x, 0 <= decode x < 2^n. Section InstructionGallery. Context (n : Z) (* bit-width of width of [W] *) {W : Type} (* bounded type, [W] for word *) - (Wdecoder : decoder n W) - {validity : word_validity n W}. + (Wdecoder : decoder n W). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) Record load_immediate := { ldi :> imm -> W }. - Class is_load_immediate {ldi : load_immediate} := - { - load_immediate_valid :> forall x, 0 <= x < 2^n -> word_valid (ldi x); - decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x - }. + Class is_load_immediate {ldi : load_immediate} := + decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. Record shift_right_doubleword_immediate := { shrd :> W -> W -> imm -> W }. Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := - { - shift_right_doubleword_valid :> forall high low count, - word_valid high -> word_valid low -> 0 <= count < n - -> word_valid (shrd high low count); - decode_shift_right_doubleword : forall high low count, - word_valid high -> word_valid low -> 0 <= count < n - -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n - }. + decode_shift_right_doubleword : + forall high low count, + 0 <= count < n + -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n. Record shift_left_immediate := { shl :> W -> imm -> W }. Class is_shift_left_immediate (shl : shift_left_immediate) := - { - shift_left_immediate_valid :> forall r count, - word_valid r -> 0 <= count < n - -> word_valid (shl r count); - decode_shift_left_immediate : forall r count, - word_valid r -> 0 <= count < n - -> decode (shl r count) = (decode r << count) mod 2^n - }. + decode_shift_left_immediate : + forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }. Class is_spread_left_immediate (sprl : spread_left_immediate) := { - fst_spread_left_immediate_valid : forall r count, - word_valid r -> 0 <= count < n - -> word_valid (fst (sprl r count)); - snd_spread_left_immediate_valid : forall r count, - word_valid r -> 0 <= count < n - -> word_valid (snd (sprl r count)); decode_fst_spread_left_immediate : forall r count, - word_valid r -> 0 <= count < n - -> decode (fst (sprl r count)) = (decode r << count) >> n; + 0 <= count < n + -> decode (fst (sprl r count)) = (decode r << count) >> n; decode_snd_spread_left_immediate : forall r count, - word_valid r -> 0 <= count < n + 0 <= count < n -> decode (snd (sprl r count)) = (decode r << count) mod 2^n }. Record mask_keep_low := { mkl :> W -> imm -> W }. Class is_mask_keep_low (mkl : mask_keep_low) := - { - mask_keep_low_valid :> forall r count, - word_valid r -> 0 <= count < n - -> word_valid (mkl r count); - decode_mask_keep_low : forall r count, - word_valid r -> 0 <= count < n - -> decode (mkl r count) = decode r mod 2^count - }. + decode_mask_keep_low : forall r count, + 0 <= count < n -> decode (mkl r count) = decode r mod 2^count. Local Notation bit b := (if b then 1 else 0). @@ -108,99 +77,48 @@ Section InstructionGallery. Class is_add_with_carry (adc : add_with_carry) := { - snd_add_with_carry_valid :> forall x y c, - word_valid x -> word_valid y - -> word_valid (snd (adc x y c)); - bit_fst_add_with_carry : forall x y c, - word_valid x -> word_valid y - -> bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; - decode_snd_add_with_carry : forall x y c, - word_valid x -> word_valid y - -> decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) + bit_fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) }. Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. - Class is_sub_with_carry (subc : sub_with_carry) := + Class is_sub_with_carry (subc:W->W->bool->bool*W) := { - snd_sub_with_carry_valid :> forall x y c, - word_valid x -> word_valid y - -> word_valid (snd (subc x y c)); - fst_sub_with_carry : forall x y c, - word_valid x -> word_valid y - -> fst (subc x y c) = ((decode x - decode y - bit c) word_valid y - -> decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n + fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) W -> W -> W }. Class is_mul (mul : multiply) := - { - mul_valid :> forall x y, - word_valid x -> word_valid y - -> word_valid (mul x y); - decode_mul : forall x y, - word_valid x -> word_valid y - -> decode (mul x y) = (decode x * decode y) mod 2^n - }. + decode_mul : forall x y, decode (mul x y) = (decode x * decode y) mod 2^n. Record multiply_low_low := { mulhwll :> W -> W -> W }. Record multiply_high_low := { mulhwhl :> W -> W -> W }. Record multiply_high_high := { mulhwhh :> W -> W -> W }. Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := - { - mul_low_low_valid :> forall x y, - word_valid x -> word_valid y - -> word_valid (mulhwll x y); - decode_mul_low_low : forall x y, - word_valid x -> word_valid y - -> decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n - }. + decode_mul_low_low : + forall x y, decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := - { - mul_high_low_valid :> forall x y, - word_valid x -> word_valid y - -> word_valid (mulhwhl x y); - decode_mul_high_low : forall x y, - word_valid x -> word_valid y - -> decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n - }. + decode_mul_high_low : + forall x y, decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := - { - mul_high_high_valid :> forall x y, - word_valid x -> word_valid y - -> word_valid (mulhwhh x y); - decode_mul_high_high : forall x y, - word_valid x -> word_valid y - -> decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n - }. + decode_mul_high_high : + forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. Record select_conditional := { selc :> bool -> W -> W -> W }. Class is_select_conditional (selc : select_conditional) := - { - select_conditional_valid : forall b x y, - word_valid x -> word_valid y - -> word_valid (selc b x y); - decode_select_conditional : forall b x y, - word_valid x -> word_valid y - -> decode (selc b x y) = if b then decode x else decode y - }. + decode_select_conditional : forall b x y, + decode (selc b x y) = if b then decode x else decode y. Record add_modulo := { addm :> W -> W -> W (* modulus *) -> W }. Class is_add_modulo (addm : add_modulo) := - { - add_modulo_valid : forall x y modulus, - word_valid x -> word_valid y -> word_valid modulus - -> word_valid (addm x y modulus); - decode_add_modulo : forall x y modulus, - word_valid x -> word_valid y -> word_valid modulus - -> decode (addm x y modulus) = (decode x + decode y) mod (decode modulus) - }. + decode_add_modulo : forall x y modulus, + decode (addm x y modulus) = (decode x + decode y) mod (decode modulus). End InstructionGallery. Global Arguments load_immediate : clear implicits. @@ -244,20 +162,20 @@ Existing Class multiply_high_high. Existing Class select_conditional. Existing Class add_modulo. -Global Arguments is_decode {_ _} _ {_}. -Global Arguments is_load_immediate {_ _ _ _} _. -Global Arguments is_shift_right_doubleword_immediate {_ _ _ _} _. -Global Arguments is_shift_left_immediate {_ _ _ _} _. -Global Arguments is_spread_left_immediate {_ _ _ _} _. -Global Arguments is_mask_keep_low {_ _ _ _} _. -Global Arguments is_add_with_carry {_ _ _ _} _. -Global Arguments is_sub_with_carry {_ _ _ _} _. -Global Arguments is_mul {_ _ _ _} _. -Global Arguments is_mul_low_low {_ _ _ _} _ _. -Global Arguments is_mul_high_low {_ _ _ _} _ _. -Global Arguments is_mul_high_high {_ _ _ _} _ _. -Global Arguments is_select_conditional {_ _ _ _} _. -Global Arguments is_add_modulo {_ _ _ _} _. +Global Arguments is_decode {_ _} _. +Global Arguments is_load_immediate {_ _ _} _. +Global Arguments is_shift_right_doubleword_immediate {_ _ _} _. +Global Arguments is_shift_left_immediate {_ _ _} _. +Global Arguments is_spread_left_immediate {_ _ _} _. +Global Arguments is_mask_keep_low {_ _ _} _. +Global Arguments is_add_with_carry {_ _ _} _. +Global Arguments is_sub_with_carry {_ _ _} _. +Global Arguments is_mul {_ _ _} _. +Global Arguments is_mul_low_low {_ _ _} _ _. +Global Arguments is_mul_high_low {_ _ _} _ _. +Global Arguments is_mul_high_high {_ _ _} _ _. +Global Arguments is_select_conditional {_ _ _} _. +Global Arguments is_add_modulo {_ _ _} _. Ltac bounded_sovlver_tac := solve [ eassumption | typeclasses eauto | omega | auto 6 using decode_range with typeclass_instances omega ]. @@ -321,7 +239,6 @@ Module fancy_machine. Class arithmetic {n_over_two} (ops:instructions (2 * n_over_two)) := { - word_validity :> word_validity (2 * n_over_two) W; decode_range :> is_decode decode; load_immediate :> is_load_immediate ldi; shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd; -- cgit v1.2.3 From 6897a4f42c86c4a6bfdbab6887276e7334317661 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 23 Aug 2016 15:59:35 -0700 Subject: Hook up the bounded interface, finish proofs --- _CoqProject | 1 + src/BoundedArithmetic/ArchitectureToZLike.v | 115 +------ src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 109 +++++++ src/BoundedArithmetic/DoubleBounded.v | 105 ++++++- src/BoundedArithmetic/DoubleBoundedProofs.v | 354 ++++++++++++++++++++++ src/BoundedArithmetic/Interface.v | 168 +++++++--- 6 files changed, 691 insertions(+), 161 deletions(-) create mode 100644 src/BoundedArithmetic/ArchitectureToZLikeProofs.v create mode 100644 src/BoundedArithmetic/DoubleBoundedProofs.v (limited to 'src') diff --git a/_CoqProject b/_CoqProject index 5c08e1616..caee0caff 100644 --- a/_CoqProject +++ b/_CoqProject @@ -24,6 +24,7 @@ src/Assembly/Vectorize.v src/Assembly/Wordize.v src/BoundedArithmetic/ArchitectureToZLike.v src/BoundedArithmetic/DoubleBounded.v +src/BoundedArithmetic/DoubleBoundedProofs.v src/BoundedArithmetic/Interface.v src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v src/CompleteEdwardsCurve/ExtendedCoordinates.v diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index 01387e969..e30fcfd09 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -3,123 +3,28 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. Require Import Crypto.BoundedArithmetic.DoubleBounded. Require Import Crypto.ModularArithmetic.ZBounded. -Require Import Coq.Lists.List. -Import ListNotations. +Require Import Crypto.Util.Tuple. -Local Open Scope nat_scope. Local Open Scope Z_scope. -Local Open Scope type_scope. - -Local Coercion Z.of_nat : nat >-> Z. Section fancy_machine_p256_montgomery_foundation. Context {n_over_two : Z}. - Local Notation n := (2 * n_over_two)%Z. + Local Notation n := (2 * n_over_two). Context (ops : fancy_machine.instructions n) (modulus : Z). - Definition two_list_to_tuple {A B} (x : A * list B) - := match x return match x with - | (a, [b0; b1]) => A * (B * B) - | _ => True - end - with - | (a, [b0; b1]) => (a, (b0, b1)) - | _ => I - end. -(* - (* make all machine-specific constructions here, preferrably as - thing wrappers around generic constructions *) - Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat - := { BoundedType := BoundedType * BoundedType (* [(high, low)] *); - decode high_low := (decode (fst high_low) * 2^n + decode (snd high_low))%Z; - encode z := (encode (z / 2^n), encode (z mod 2^n)); - ShiftRight a high_low - := let '(high, low) := eta high_low in - if n <=? a then - (ShiftRight (a - n)%nat (encode 0, fst high), ShiftRight (a - n)%nat high) - else - (ShiftRight a (snd high, fst low), ShiftRight a low); - ShiftLeft a high_low - := let '(high, low) := eta high_low in - if 2 * n <=? a then - let '(high0, low) := eta (ShiftLeft (a - 2 * n)%nat low) in - let '(high_high, high1) := eta (ShiftLeft (a - 2 * n)%nat high) in - ((snd (CarryAdd false high0 high1), low), (encode 0, encode 0)) - else if n <=? a then - let '(high0, low) := eta (ShiftLeft (a - n)%nat low) in - let '(high_high, high1) := eta (ShiftLeft (a - n)%nat high) in - ((high_high, snd (CarryAdd false high0 high1)), (low, encode 0)) - else - let '(high0, low) := eta (ShiftLeft a low) in - let '(high_high, high1) := eta (ShiftLeft a high) in - ((encode 0, high_high), (snd (CarryAdd false high0 high1), low)); - Mod2Pow a high_low - := let '(high, low) := (fst high_low, snd high_low) in - (Mod2Pow (a - n)%nat high, Mod2Pow a low); - CarryAdd carry x_high_low y_high_low - := let '(xhigh, xlow) := eta x_high_low in - let '(yhigh, ylow) := eta y_high_low in - two_list_to_tuple (ripple_carry CarryAdd carry [xhigh; xlow] [yhigh; ylow]); - CarrySub carry x_high_low y_high_low - := let '(xhigh, xlow) := eta x_high_low in - let '(yhigh, ylow) := eta y_high_low in - two_list_to_tuple (ripple_carry CarrySub carry [xhigh; xlow] [yhigh; ylow]) }. - - Definition BoundedOfHalfBounded (x : @BoundedHalfType (2 * n)%nat _) : @BoundedType n _ - := match x with - | UpperHalf x => fst x - | LowerHalf x => snd x - end. - - Local Instance DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps - {base_mops : ArchitectureBoundedFullMulOps n} - : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat := - { HalfWidthMul a b - := Mul (BoundedOfHalfBounded a) (BoundedOfHalfBounded b) }. - End single. - - Local Existing Instance DoubleArchitectureBoundedOps. - - Section full_from_half. - Context (n : size) {base_ops : ArchitectureBoundedOps (2 * n)%nat}. - - Local Infix "*" := HalfWidthMul. - - Local Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps - {base_mops : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat} - : ArchitectureBoundedFullMulOps (2 * n)%nat := - { Mul a b - := let '(a1, a0) := (UpperHalf a, LowerHalf a) in - let '(b1, b0) := (UpperHalf b, LowerHalf b) in - let out := a0 * b0 in - let outHigh := a1 * b1 in - let tmp := a1 * b0 in - let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in - let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in - let tmp := a0 * b1 in - let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in - let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in - (outHigh, out) }. - End full_from_half. - - Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. -*) - Axiom admit : forall {T}, T. - Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := - { LargeT := fancy_machine.W * fancy_machine.W; + { LargeT := tuple fancy_machine.W 2; SmallT := fancy_machine.W; modulus_digits := ldi modulus; - decode_large := _; + decode_large := decode; decode_small := decode; - Mod_SmallBound v := snd v; - DivBy_SmallBound v := fst v; - DivBy_SmallerBound v := shrd (fst v) (snd v) smaller_bound_exp; - Mul x y := _ (*mulhwll (ldi 0, x) (ldi 0, y)*); - CarryAdd x y := _ (*adc x y false*); + Mod_SmallBound v := fst v; + DivBy_SmallBound v := snd v; + DivBy_SmallerBound v := shrd (snd v) (fst v) smaller_bound_exp; + Mul x y := mulhwll (W := tuple _ 2) (sprl x 0) (sprl y 0); + CarryAdd x y := adc x y false; CarrySubSmall x y := subc x y false; - ConditionalSubtract b x := let v := selc b (ldi 0) (ldi modulus) in snd (subc x v false); + ConditionalSubtract b x := let v := selc b (ldi modulus) (ldi 0) in snd (subc x v false); ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }. - Abort. End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v new file mode 100644 index 000000000..b7cac2bb3 --- /dev/null +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -0,0 +1,109 @@ +(*** Proving ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.BoundedArithmetic.DoubleBoundedProofs. +Require Import Crypto.BoundedArithmetic.ArchitectureToZLike. +Require Import Crypto.ModularArithmetic.ZBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil Crypto.Util.Tactics. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. + +Section fancy_machine_p256_montgomery_foundation. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two)%Z. + Context (ops : fancy_machine.instructions n) (modulus : Z). + + Local Arguments Z.mul !_ !_. + Local Arguments BaseSystem.decode !_ !_ / . + Local Arguments BaseSystem.accumulate / . + Local Arguments BaseSystem.decode' !_ !_ / . + + Local Ltac introduce_t_step := + match goal with + | [ |- forall x : bool, _ ] => intros [|] + | [ |- True -> _ ] => intros _ + | [ |- _ <= _ < _ -> _ ] => intro + | _ => let x := fresh "x" in + intro x; + try pose proof (decode_range (fst x)); + try pose proof (decode_range (snd x)); + pose proof (decode_range x) + end. + Local Ltac unfolder_t := + progress unfold LargeT, SmallT, modulus_digits, decode_large, decode_small, Mod_SmallBound, DivBy_SmallBound, DivBy_SmallerBound, Mul, CarryAdd, CarrySubSmall, ConditionalSubtract, ConditionalSubtractModulus, ZLikeOps_of_ArchitectureBoundedOps in *. + Local Ltac saturate_context_step := + match goal with + | _ => unique assert (0 <= 2 * n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] + | _ => unique assert (0 <= 2 * (2 * n_over_two)) by (eauto using decode_exponent_nonnegative with typeclass_instances) + end. + Local Ltac pre_t := + repeat first [ tauto + | introduce_t_step + | unfolder_t + | saturate_context_step ]. + Local Ltac post_t_step := + match goal with + | _ => tauto + | _ => progress autorewrite with zsimplify_const in * + | _ => progress push_decode + | _ => progress autorewrite with push_Zpow in * + | _ => progress Z.rewrite_mod_small + | [ |- fst ?x = (?a <=? ?b) :> bool ] + => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); + [ destruct (fst x), (a <=? b); intro; congruence | ] + | [ |- appcontext[let (a, b) := ?x in _] ] + => rewrite (surjective_pairing x); simplify_projections + | _ => progress autorewrite with Zshift_to_pow in * + | _ => progress autorewrite with simpl_tuple_decoder in * + | _ => progress autorewrite with zsimplify + | [ |- _ / ?y = _ / ?y ] => apply f_equal2; omega + | [ |- _ / _ = if _ then _ else _ ] => apply Z.div_between_0_if; auto with zarith omega + end. + Local Ltac post_t := repeat post_t_step. + Local Ltac t := pre_t; post_t. + + Global Instance ZLikeProperties_of_ArchitectureBoundedOps + {arith : fancy_machine.arithmetic ops} + (modulus_in_range : 0 <= modulus < 2^n) + (smaller_bound_exp : Z) + (smaller_bound_smaller : 0 <= smaller_bound_exp < n) + : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp) + := { large_valid v := True; + medium_valid v := 0 <= decode_large v < 2^n * 2^smaller_bound_exp; + small_valid v := True }. + Proof. + (* In 8.5: *) + (* par:t. *) + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + { abstract t. } + Defined. +End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 7fa0d4db1..a368b96a0 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -1,13 +1,11 @@ (*** Implementing Large Bounded Arithmetic via pairs *) -Require Import Coq.ZArith.ZArith Coq.Lists.List. +Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. -Require Import Crypto.BaseSystem. -Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.Pow2Base. -Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.Notations. -Import ListNotations. Local Open Scope list_scope. Local Open Scope nat_scope. Local Open Scope Z_scope. @@ -17,6 +15,17 @@ Local Coercion Z.of_nat : nat >-> Z. Local Notation eta x := (fst x, snd x). Section generic_constructions. + Section decode. + Context {n W} {decode : decoder n W}. + Section with_k. + Context {k : nat}. + Let limb_widths := repeat n k. + (** The list is low to high; the tuple is low to high *) + Local Instance tuple_decoder : decoder (k * n) (tuple W k) + := { decode w := BaseSystem.decode (base_from_limb_widths limb_widths) (List.map decode (List.rev (Tuple.to_list _ w))) }. + End with_k. + End decode. + Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) (xs ys : list T) (carry : bool) : bool * list T := List.fold_right @@ -27,17 +36,91 @@ Section generic_constructions. (carry, nil) (List.combine xs ys). + (** tuple is high to low ([to_list] reverses) *) + Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k + := match k return forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k with + | O => f + | S k' => fun xss yss carry => let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple' _ f k' xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)) + end. + + Definition ripple_carry_tuple {T} (f : T -> T -> bool -> bool * T) k + : forall (xs ys : tuple T k) (carry : bool), bool * tuple T k + := match k return forall (xs ys : tuple T k) (carry : bool), bool * tuple T k with + | O => fun xs ys carry => (carry, tt) + | S k' => ripple_carry_tuple' f k' + end. + Section ripple_carry_adc. Context {n W} {decode : decoder n W} (adc : add_with_carry W). - Global Instance ripple_carry_add_with_carry : add_with_carry (list W) - := {| Interface.adc := ripple_carry adc |}. - (* - Global Instance ripple_carry_is_add_with_carry {is_adc : is_add_with_carry adc} - : is_add_with_carry ripple_carry_add_with_carry.*) + Global Instance ripple_carry_adc {k} : add_with_carry (tuple W k) + := {| Interface.adc := ripple_carry_tuple adc k |}. End ripple_carry_adc. (* TODO: Would it made sense to make generic-width shift operations here? *) - (* FUTURE: here go proofs about [ripple_carry] with [f] that satisfies [is_add_with_carry] *) + Section tuple2. + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W}. + + Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2 + := (shl r count, if count =? 0 then ldi 0 else shr r (n - count)). + + (** Require a [decode] instance to aid typeclass search in + resolving [n] *) + Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W + := {| Interface.sprl := spread_left_from_shift |}. + End spread_left. + + Section full_from_half. + Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W}. + + Section def. + Context (half_n : Z). + Definition mul_double (a b : W) : tuple W 2 + := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in + let tmp := mulhwhl a b in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + let tmp := mulhwhl b a in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + out. + End def. + + Section instances. + Context {half_n : Z} + {ldi : load_immediate W}. + + (** Require a dummy [decoder] for these instances to allow + typeclass inference of the [half_n] argument *) + Global Instance mul_double_multiply_low_low {decode : decoder (2 * half_n) W} + : multiply_low_low (tuple W 2) + := {| Interface.mulhwll a b := mul_double half_n (fst a) (fst b) |}. + Global Instance mul_double_multiply_high_low {decode : decoder (2 * half_n) W} + : multiply_high_low (tuple W 2) + := {| Interface.mulhwhl a b := mul_double half_n (snd a) (fst b) |}. + Global Instance mul_double_multiply_high_high {decode : decoder (2 * half_n) W} + : multiply_high_high (tuple W 2) + := {| Interface.mulhwhh a b := mul_double half_n (snd a) (snd b) |}. + End instances. + End full_from_half. + End tuple2. End generic_constructions. + +Global Arguments tuple_decoder : simpl never. + +Hint Resolve (fun n W decode => (@tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2))) : typeclass_instances. +Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v new file mode 100644 index 000000000..d878a1373 --- /dev/null +++ b/src/BoundedArithmetic/DoubleBoundedProofs.v @@ -0,0 +1,354 @@ +(*** Proofs About Large Bounded Arithmetic via pairs *) +Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.micromega.Psatz. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BaseSystem. +Require Import Crypto.BaseSystemProofs. +Require Import Crypto.ModularArithmetic.Pow2Base. +Require Import Crypto.ModularArithmetic.Pow2BaseProofs. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Util.Tactics. +Require Import Crypto.Util.Notations. + +Import ListNotations. +Local Open Scope list_scope. +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. +Local Coercion Pos.to_nat : positive >-> nat. +Local Notation eta x := (fst x, snd x). + +Section generic_constructions. + Section decode. + Context {n W} {decode : decoder n W}. + Section with_k. + Context {k : nat}. + Local Notation limb_widths := (repeat n k). + + Lemma decode_bounded {isdecode : is_decode decode} w + : 0 <= n -> bounded limb_widths (map decode (rev (to_list k w))). + Proof. + intro. + eapply bounded_uniform; try solve [ eauto using repeat_spec ]. + { distr_length. } + { intros z H'. + apply in_map_iff in H'. + destruct H' as [? [? H'] ]; subst; apply decode_range. } + Qed. + + (** TODO: Clean up this proof *) + Global Instance tuple_is_decode {isdecode : is_decode decode} + : is_decode (tuple_decoder (k := k)). + Proof. + unfold tuple_decoder; hnf; simpl. + intro w. + destruct (zerop k); [ subst | ]. + { unfold BaseSystem.decode, BaseSystem.decode'; simpl; omega. } + assert (0 <= n) + by (destruct k as [ | [|] ]; [ omega | | destruct w ]; + eauto using decode_exponent_nonnegative). + replace (2^(k * n)) with (upper_bound limb_widths) + by (erewrite upper_bound_uniform by eauto using repeat_spec; distr_length). + apply decode_upper_bound; auto using decode_bounded. + { intros ? H'. + apply repeat_spec in H'; omega. } + { distr_length. } + Qed. + End with_k. + + Local Arguments base_from_limb_widths : simpl never. + Local Arguments repeat : simpl never. + Local Arguments Z.mul !_ !_. + Lemma tuple_decoder_S {k} w : 0 <= n -> (tuple_decoder (k := S (S k)) w = tuple_decoder (k := S k) (fst w) + (decode (snd w) << (S k * n)))%Z. + Proof. + intro Hn. + destruct w as [? w]; simpl. + replace (decode w) with (decode w * 1 + 0)%Z by omega. + rewrite map_app, map_cons, map_nil. + erewrite decode_shift_uniform_app by (eauto using repeat_spec; distr_length). + distr_length. + autorewrite with push_skipn natsimplify push_firstn. + reflexivity. + Qed. + Lemma tuple_decoder_O w : tuple_decoder (k := 1) w = decode w. + Proof. + unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat. + simpl. + omega. + Qed. + Lemma tuple_decoder_m1 w : tuple_decoder (k := 0) w = 0. + Proof. reflexivity. Qed. + End decode. + Local Arguments tuple_decoder : simpl never. + Local Opaque tuple_decoder. + Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. + + Hint Extern 1 (decoder _ (tuple ?W 2)) => apply (fun n decode => @tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2)) : typeclass_instances. + + Lemma ripple_carry_tuple_SS {T} f k xss yss carry + : @ripple_carry_tuple T f (S (S k)) xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)). + Proof. reflexivity. Qed. + + Lemma carry_is_good (n z0 z1 k : Z) + : 0 <= n -> + 0 <= k -> + (z1 + z0 >> k) >> n = (z0 + z1 << k) >> (k + n) /\ + (z0 mod 2 ^ k + ((z1 + z0 >> k) mod 2 ^ n) << k)%Z = (z0 + z1 << k) mod (2 ^ k * 2 ^ n). + Proof. + intros. + assert (0 < 2 ^ n) by auto with zarith. + assert (0 < 2 ^ k) by auto with zarith. + assert (0 < 2^n * 2^k) by nia. + autorewrite with Zshift_to_pow push_Zpow. + rewrite <- (Zmod_small ((z0 mod _) + _) (2^k * 2^n)) by (Z.div_mod_to_quot_rem; nia). + rewrite <- !Z.mul_mod_distr_r by lia. + rewrite !(Z.mul_comm (2^k)); pull_Zmod. + split; [ | apply f_equal2 ]; + Z.div_mod_to_quot_rem; nia. + Qed. + + Definition carry_is_good_carry n z0 z1 k H0 H1 := proj1 (@carry_is_good n z0 z1 k H0 H1). + Definition carry_is_good_value n z0 z1 k H0 H1 := proj2 (@carry_is_good n z0 z1 k H0 H1). + + Section ripple_carry_adc. + Context {n W} {decode : decoder n W} (adc : add_with_carry W). + + Lemma ripple_carry_adc_SS k xss yss carry + : ripple_carry_adc (k := S (S k)) adc xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (ripple_carry_adc (k := S k) adc xs ys carry) in + let '(carry, z) := eta (adc x y carry) in + (carry, (zs, z)). + Proof. apply ripple_carry_tuple_SS. Qed. + + Local Existing Instance tuple_decoder. + + Global Instance ripple_carry_is_add_with_carry {k} + {isdecode : is_decode decode} + {is_adc : is_add_with_carry adc} + : is_add_with_carry (ripple_carry_adc (k := k) adc). + Proof. + destruct k as [|k]. + { constructor; simpl; intros; autorewrite with zsimplify; reflexivity. } + { induction k as [|k IHk]. + { cbv [ripple_carry_adc ripple_carry_tuple to_list]. + constructor; simpl @fst; simpl @snd; intros; + autorewrite with simpl_tuple_decoder; + push_decode; + autorewrite with zsimplify; reflexivity. } + { apply Build_is_add_with_carry'; intros x y c. + assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative). + assert (2^n <> 0) by auto with zarith. + assert (0 <= S k * n) by nia. + rewrite !tuple_decoder_S, !ripple_carry_adc_SS by assumption. + simplify_projections; push_decode; generalize_decode. + erewrite carry_is_good_carry, carry_is_good_value by lia. + autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow. + split; apply f_equal2; nia. } } + Qed. + + End ripple_carry_adc. + + Section tuple2. + Section spread_left_correct. + Context {n W} {decode : decoder n W} {sprl : spread_left_immediate W} + {isdecode : is_decode decode}. + Lemma is_spread_left_immediate_alt + : is_spread_left_immediate sprl + <-> (forall r count, 0 <= count < n -> tuple_decoder (k := 2) (sprl r count) = (decode r << count) mod 2^(2*n)). + Proof. + split; intro H; [ | apply Build_is_spread_left_immediate' ]; + intros r count Hc; + [ | specialize (H r count Hc); revert H ]; + pose proof (decode_range r); + assert (0 < 2^n) by auto with zarith; + assert (0 <= 2^count < 2^n) by auto with zarith; + assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia); + autorewrite with simpl_tuple_decoder; push_decode; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. + End spread_left_correct. + + Local Arguments Z.pow !_ !_. + Local Arguments Z.mul !_ !_. + + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr}. + + Lemma spread_left_from_shift_correct + r count + (H : 0 < count < n) + : (decode (shl r count) + decode (shr r (n - count)) << n = decode r << count mod 2^(2*n))%Z. + Proof. + assert (0 <= n - count < n) by lia. + assert (0 < 2^(n-count)) by auto with zarith. + assert (2^count < 2^n) by auto with zarith. + pose proof (decode_range r). + assert (0 <= decode r * 2 ^ count < 2 ^ n * 2^n) by auto with zarith. + simpl. + push_decode; autorewrite with Zshift_to_pow zsimplify. + replace (decode r / 2^(n-count) * 2^n)%Z with ((decode r / 2^(n-count) * 2^(n-count)) * 2^count)%Z + by (rewrite <- Z.mul_assoc; autorewrite with pull_Zpow zsimplify; reflexivity). + rewrite Z.mul_div_eq' by lia. + autorewrite with push_Zmul zsimplify. + rewrite <- Z.mul_mod_distr_r_full, Z.add_sub_assoc. + repeat autorewrite with pull_Zpow zsimplify in *. + reflexivity. + Qed. + + Global Instance is_spread_left_from_shift + : is_spread_left_immediate (sprl_from_shift n). + Proof. + apply is_spread_left_immediate_alt. + intros r count; intros. + pose proof (decode_range r). + assert (0 < 2^n) by auto with zarith. + assert (decode r < 2^n * 2^n) by (generalize dependent (decode r); intros; nia). + autorewrite with simpl_tuple_decoder. + destruct (Z_zerop count). + { subst; autorewrite with Zshift_to_pow zsimplify. + simpl; push_decode. + autorewrite with push_Zpow zsimplify. + reflexivity. } + simpl. + rewrite <- spread_left_from_shift_correct by lia. + autorewrite with zsimplify Zpow_to_shift. + reflexivity. + Qed. + End spread_left. + + Local Opaque ripple_carry_adc. + Section full_from_half. + Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {half_n : Z} + {ldi : load_immediate W} + {decode : decoder (2 * half_n) W} + {ismulhwll : is_mul_low_low half_n mulhwll} + {ismulhwhl : is_mul_high_low half_n mulhwhl} + {ismulhwhh : is_mul_high_high half_n mulhwhh} + {isadc : is_add_with_carry adc} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isldi : is_load_immediate ldi} + {isdecode : is_decode decode}. + + Local Arguments Z.mul !_ !_. + Lemma spread_left_from_shift_half_correct + r + : (decode (shl r half_n) + decode (shr r half_n) * (2^half_n * 2^half_n) + = (decode r * 2^half_n) mod (2^half_n*2^half_n*2^half_n*2^half_n))%Z. + Proof. + destruct (0 (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) + (fun n (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) + using solve [ auto with zarith ] + : simpl_tuple_decoder. + Local Ltac t := + hnf; intros [??] [??]; + assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative; + assert (0 <= half_n) by omega; + simpl @Interface.mulhwhh; simpl @Interface.mulhwhl; simpl @Interface.mulhwll; + rewrite decode_mul_double_mod; push_decode; autorewrite with simpl_tuple_decoder; + simpl; + push_decode; generalize_decode_var; + autorewrite with Zshift_to_pow zsimplify; + autorewrite with push_Zpow in *; Z.rewrite_mod_small; + try reflexivity. + + Global Instance mul_double_is_multiply_low_low : is_mul_low_low (2 * half_n) mul_double_multiply_low_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_low : is_mul_high_low (2 * half_n) mul_double_multiply_high_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_high : is_mul_high_high (2 * half_n) mul_double_multiply_high_high. + Proof. t. Qed. + End full_from_half. + End tuple2. +End generic_constructions. + +Hint Resolve (fun n W decode pf => (@tuple_is_decode n W decode 2 pf : @is_decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2))) : typeclass_instances. +Hint Extern 3 (@is_decode _ (tuple ?W ?k) _) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode pf => (@tuple_is_decode n W decode k pf : @is_decode (kv * n) (tuple W k) (@tuple_decoder n W decode k : decode (kv * n) W))) : typeclass_instances. + +Hint Extern 2 (@is_add_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_adc _ ?adc _)) + => apply (@ripple_carry_is_add_with_carry n W decode adc k) : typeclass_instances. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. +Hint Rewrite + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) + using solve [ auto with zarith ] + : simpl_tuple_decoder. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 4a14a160b..fe64cd37e 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -53,18 +53,32 @@ Section InstructionGallery. decode_shift_left_immediate : forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. - Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }. + Record shift_right_immediate := { shr :> W -> imm -> W }. + + Class is_shift_right_immediate (shr : shift_right_immediate) := + decode_shift_right_immediate : + forall r count, 0 <= count < n -> decode (shr r count) = (decode r >> count). + + Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(low, high)] *) }. Class is_spread_left_immediate (sprl : spread_left_immediate) := { decode_fst_spread_left_immediate : forall r count, - 0 <= count < n - -> decode (fst (sprl r count)) = (decode r << count) >> n; - decode_snd_spread_left_immediate : forall r count, 0 <= count < n - -> decode (snd (sprl r count)) = (decode r << count) mod 2^n + -> decode (fst (sprl r count)) = (decode r << count) mod 2^n; + decode_snd_spread_left_immediate : forall r count, + 0 <= count < n + -> decode (snd (sprl r count)) = (decode r << count) >> n; + }. + Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate) + (pf : forall r count, 0 <= count < n + -> decode (fst (sprl r count)) = (decode r << count) mod 2^n + /\ decode (snd (sprl r count)) = (decode r << count) >> n) + := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); + decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. + Record mask_keep_low := { mkl :> W -> imm -> W }. Class is_mask_keep_low (mkl : mask_keep_low) := @@ -81,6 +95,11 @@ Section InstructionGallery. decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) }. + Definition Build_is_add_with_carry' (adc : add_with_carry) + (pf : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n /\ decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n)) + := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); + decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. + Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. Class is_sub_with_carry (subc:W->W->bool->bool*W) := @@ -89,6 +108,11 @@ Section InstructionGallery. decode_snd_sub_with_carry : forall x y c, decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n }. + Definition Build_is_sub_with_carry' (subc : sub_with_carry) + (pf : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) W -> W -> W }. Class is_mul (mul : multiply) := @@ -118,12 +142,15 @@ Section InstructionGallery. Class is_add_modulo (addm : add_modulo) := decode_add_modulo : forall x y modulus, - decode (addm x y modulus) = (decode x + decode y) mod (decode modulus). + decode (addm x y modulus) = (if (decode x + decode y) Z) + : @decode n W {| decode := dec |} = dec. +Proof. reflexivity. Qed. + +Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} + (isinhabited : W) + : 0 <= n. +Proof. + pose proof (decode_range isinhabited). + assert (0 < 2^n) by omega. + destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ]. + assert (2^n = 0) by auto using Z.pow_neg_r. + omega. +Qed. + +Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo @decode_proj using bounded_solver_tac : push_decode. + +Ltac push_decode_step := + first [ rewrite !decode_proj + | erewrite !decode_load_immediate by bounded_solver_tac + | erewrite !decode_shift_right_doubleword by bounded_solver_tac + | erewrite !decode_shift_left_immediate by bounded_solver_tac + | erewrite !decode_shift_right_immediate by bounded_solver_tac + | erewrite !decode_fst_spread_left_immediate by bounded_solver_tac + | erewrite !decode_snd_spread_left_immediate by bounded_solver_tac + | erewrite !decode_mask_keep_low by bounded_solver_tac + | erewrite !bit_fst_add_with_carry by bounded_solver_tac + | erewrite !decode_snd_add_with_carry by bounded_solver_tac + | erewrite !fst_sub_with_carry by bounded_solver_tac + | erewrite !decode_snd_sub_with_carry by bounded_solver_tac + | erewrite !decode_mul by bounded_solver_tac + | erewrite !decode_mul_low_low by bounded_solver_tac + | erewrite !decode_mul_high_low by bounded_solver_tac + | erewrite !decode_mul_high_high by bounded_solver_tac + | erewrite !decode_select_conditional by bounded_solver_tac + | erewrite !decode_add_modulo by bounded_solver_tac ]. +Ltac pull_decode_step := + first [ erewrite <- !decode_load_immediate by bounded_solver_tac + | erewrite <- !decode_shift_right_doubleword by bounded_solver_tac + | erewrite <- !decode_shift_left_immediate by bounded_solver_tac + | erewrite <- !decode_shift_right_immediate by bounded_solver_tac + | erewrite <- !decode_fst_spread_left_immediate by bounded_solver_tac + | erewrite <- !decode_snd_spread_left_immediate by bounded_solver_tac + | erewrite <- !decode_mask_keep_low by bounded_solver_tac + | erewrite <- !bit_fst_add_with_carry by bounded_solver_tac + | erewrite <- !decode_snd_add_with_carry by bounded_solver_tac + | erewrite <- !fst_sub_with_carry by bounded_solver_tac + | erewrite <- !decode_snd_sub_with_carry by bounded_solver_tac + | erewrite <- !decode_mul by bounded_solver_tac + | erewrite <- !decode_mul_low_low by bounded_solver_tac + | erewrite <- !decode_mul_high_low by bounded_solver_tac + | erewrite <- !decode_mul_high_high by bounded_solver_tac + | erewrite <- !decode_select_conditional by bounded_solver_tac + | erewrite <- !decode_add_modulo by bounded_solver_tac ]. +Ltac push_decode := repeat push_decode_step. +Ltac pull_decode := repeat pull_decode_step. + +(* We take special care to handle the case where the decoder is + syntactically different but the decoded expression is judgmentally + the same; we don't want to split apart variables that should be the + same. *) +Ltac set_decode_step check := + match goal with + | [ |- context G[@Interface.decode ?n ?W ?dr ?w] ] + => check w; + first [ match goal with + | [ d := @Interface.decode _ _ _ w |- _ ] + => change (@Interface.decode n W dr w) with d + end + | generalize (@decode_range n W dr _ w); + let d := fresh "d" in + set (d := @Interface.decode n W dr w); + intro ] + end. +Ltac set_decode check := repeat set_decode_step check. +Ltac clearbody_decode := + repeat match goal with + | [ H := @Interface.decode _ _ _ _ |- _ ] => clearbody H + end. +Ltac generalize_decode_by check := set_decode check; clearbody_decode. +Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac). +Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w). Module fancy_machine. Local Notation imm := Z (only parsing). @@ -227,6 +303,7 @@ Module fancy_machine. ldi :> load_immediate W; shrd :> shift_right_doubleword_immediate W; shl :> shift_left_immediate W; + shr :> shift_right_immediate W; mkl :> mask_keep_low W; adc :> add_with_carry W; subc :> sub_with_carry W; @@ -243,6 +320,7 @@ Module fancy_machine. load_immediate :> is_load_immediate ldi; shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd; shift_left_immediate :> is_shift_left_immediate shl; + shift_right_immediate :> is_shift_right_immediate shr; mask_keep_low :> is_mask_keep_low mkl; add_with_carry :> is_add_with_carry adc; sub_with_carry :> is_sub_with_carry subc; -- cgit v1.2.3 From f9f4aa9629e1e9ad82095d9b6600d1645351873c Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 23 Aug 2016 16:27:15 -0700 Subject: Weaken the condition on smaller_bound --- src/BoundedArithmetic/ArchitectureToZLike.v | 4 +++- src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 7 ++++++- src/BoundedArithmetic/Interface.v | 8 +++++++- 3 files changed, 16 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index e30fcfd09..cd221c10d 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -21,7 +21,9 @@ Section fancy_machine_p256_montgomery_foundation. decode_small := decode; Mod_SmallBound v := fst v; DivBy_SmallBound v := snd v; - DivBy_SmallerBound v := shrd (snd v) (fst v) smaller_bound_exp; + DivBy_SmallerBound v := if smaller_bound_exp =? n + then snd v + else shrd (snd v) (fst v) smaller_bound_exp; Mul x y := mulhwll (W := tuple _ 2) (sprl x 0) (sprl y 0); CarryAdd x y := adc x y false; CarrySubSmall x y := subc x y false; diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v index b7cac2bb3..0d19a54bc 100644 --- a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -55,6 +55,7 @@ Section fancy_machine_p256_montgomery_foundation. | _ => progress push_decode | _ => progress autorewrite with push_Zpow in * | _ => progress Z.rewrite_mod_small + | _ => progress subst | [ |- fst ?x = (?a <=? ?b) :> bool ] => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); [ destruct (fst x), (a <=? b); intro; congruence | ] @@ -63,8 +64,11 @@ Section fancy_machine_p256_montgomery_foundation. | _ => progress autorewrite with Zshift_to_pow in * | _ => progress autorewrite with simpl_tuple_decoder in * | _ => progress autorewrite with zsimplify + | [ H : (_ =? _) = true |- _ ] => apply Z.eqb_eq in H + | [ H : (_ =? _) = false |- _ ] => apply Z.eqb_neq in H | [ |- _ / ?y = _ / ?y ] => apply f_equal2; omega | [ |- _ / _ = if _ then _ else _ ] => apply Z.div_between_0_if; auto with zarith omega + | [ |- context[if ?x =? ?y then _ else _] ] => destruct (x =? y) eqn:? end. Local Ltac post_t := repeat post_t_step. Local Ltac t := pre_t; post_t. @@ -73,7 +77,8 @@ Section fancy_machine_p256_montgomery_foundation. {arith : fancy_machine.arithmetic ops} (modulus_in_range : 0 <= modulus < 2^n) (smaller_bound_exp : Z) - (smaller_bound_smaller : 0 <= smaller_bound_exp < n) + (smaller_bound_smaller : 0 <= smaller_bound_exp <= n) + (n_pos : 0 < n) : ZLikeProperties (ZLikeOps_of_ArchitectureBoundedOps ops modulus smaller_bound_exp) := { large_valid v := True; medium_valid v := 0 <= decode_large v < 2^n * 2^smaller_bound_exp; diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index fe64cd37e..5baa036c4 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -214,6 +214,11 @@ Lemma decode_proj n W (dec : W -> Z) : @decode n W {| decode := dec |} = dec. Proof. reflexivity. Qed. +Lemma decode_if_bool n W (decode : decoder n W) (b : bool) x y + : decode (if b then x else y) + = if b then decode x else decode y. +Proof. destruct b; reflexivity. Qed. + Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} (isinhabited : W) : 0 <= n. @@ -225,10 +230,11 @@ Proof. omega. Qed. -Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo @decode_proj using bounded_solver_tac : push_decode. +Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo @decode_proj @decode_if_bool using bounded_solver_tac : push_decode. Ltac push_decode_step := first [ rewrite !decode_proj + | rewrite !decode_if_bool | erewrite !decode_load_immediate by bounded_solver_tac | erewrite !decode_shift_right_doubleword by bounded_solver_tac | erewrite !decode_shift_left_immediate by bounded_solver_tac -- cgit v1.2.3 From ab8d4c062bb998e4746de774889d4e861b401382 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Wed, 24 Aug 2016 13:34:59 -0700 Subject: Coq 8.4 fixes --- src/BoundedArithmetic/DoubleBounded.v | 10 ++--- src/BoundedArithmetic/Interface.v | 69 +++++++++++++++++------------------ 2 files changed, 39 insertions(+), 40 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index a368b96a0..cb8a21495 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -59,7 +59,7 @@ Section generic_constructions. Context {n W} {decode : decoder n W} (adc : add_with_carry W). Global Instance ripple_carry_adc {k} : add_with_carry (tuple W k) - := {| Interface.adc := ripple_carry_tuple adc k |}. + := { adc := ripple_carry_tuple adc k }. End ripple_carry_adc. (* TODO: Would it made sense to make generic-width shift operations here? *) @@ -77,7 +77,7 @@ Section generic_constructions. (** Require a [decode] instance to aid typeclass search in resolving [n] *) Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W - := {| Interface.sprl := spread_left_from_shift |}. + := { sprl := spread_left_from_shift }. End spread_left. Section full_from_half. @@ -108,13 +108,13 @@ Section generic_constructions. typeclass inference of the [half_n] argument *) Global Instance mul_double_multiply_low_low {decode : decoder (2 * half_n) W} : multiply_low_low (tuple W 2) - := {| Interface.mulhwll a b := mul_double half_n (fst a) (fst b) |}. + := { mulhwll a b := mul_double half_n (fst a) (fst b) }. Global Instance mul_double_multiply_high_low {decode : decoder (2 * half_n) W} : multiply_high_low (tuple W 2) - := {| Interface.mulhwhl a b := mul_double half_n (snd a) (fst b) |}. + := { mulhwhl a b := mul_double half_n (snd a) (fst b) }. Global Instance mul_double_multiply_high_high {decode : decoder (2 * half_n) W} : multiply_high_high (tuple W 2) - := {| Interface.mulhwhh a b := mul_double half_n (snd a) (snd b) |}. + := { mulhwhh a b := mul_double half_n (snd a) (snd b) }. End instances. End full_from_half. End tuple2. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 5baa036c4..4d3f7d858 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -34,12 +34,14 @@ Section InstructionGallery. (Wdecoder : decoder n W). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) - Record load_immediate := { ldi :> imm -> W }. + Class load_immediate := { ldi : imm -> W }. + Global Coercion ldi : load_immediate >-> Funclass. Class is_load_immediate {ldi : load_immediate} := decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. - Record shift_right_doubleword_immediate := { shrd :> W -> W -> imm -> W }. + Class shift_right_doubleword_immediate := { shrd : W -> W -> imm -> W }. + Global Coercion shrd : shift_right_doubleword_immediate >-> Funclass. Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := decode_shift_right_doubleword : @@ -47,19 +49,22 @@ Section InstructionGallery. 0 <= count < n -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n. - Record shift_left_immediate := { shl :> W -> imm -> W }. + Class shift_left_immediate := { shl : W -> imm -> W }. + Global Coercion shl : shift_left_immediate >-> Funclass. Class is_shift_left_immediate (shl : shift_left_immediate) := decode_shift_left_immediate : forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. - Record shift_right_immediate := { shr :> W -> imm -> W }. + Class shift_right_immediate := { shr : W -> imm -> W }. + Global Coercion shr : shift_right_immediate >-> Funclass. Class is_shift_right_immediate (shr : shift_right_immediate) := decode_shift_right_immediate : forall r count, 0 <= count < n -> decode (shr r count) = (decode r >> count). - Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(low, high)] *) }. + Class spread_left_immediate := { sprl : W -> imm -> W * W (* [(low, high)] *) }. + Global Coercion sprl : spread_left_immediate >-> Funclass. Class is_spread_left_immediate (sprl : spread_left_immediate) := { @@ -68,7 +73,7 @@ Section InstructionGallery. -> decode (fst (sprl r count)) = (decode r << count) mod 2^n; decode_snd_spread_left_immediate : forall r count, 0 <= count < n - -> decode (snd (sprl r count)) = (decode r << count) >> n; + -> decode (snd (sprl r count)) = (decode r << count) >> n }. @@ -79,7 +84,8 @@ Section InstructionGallery. := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. - Record mask_keep_low := { mkl :> W -> imm -> W }. + Class mask_keep_low := { mkl :> W -> imm -> W }. + Global Coercion mkl : mask_keep_low >-> Funclass. Class is_mask_keep_low (mkl : mask_keep_low) := decode_mask_keep_low : forall r count, @@ -87,7 +93,8 @@ Section InstructionGallery. Local Notation bit b := (if b then 1 else 0). - Record add_with_carry := { adc :> W -> W -> bool -> bool * W }. + Class add_with_carry := { adc : W -> W -> bool -> bool * W }. + Global Coercion adc : add_with_carry >-> Funclass. Class is_add_with_carry (adc : add_with_carry) := { @@ -100,7 +107,8 @@ Section InstructionGallery. := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. - Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. + Class sub_with_carry := { subc : W -> W -> bool -> bool * W }. + Global Coercion subc : sub_with_carry >-> Funclass. Class is_sub_with_carry (subc:W->W->bool->bool*W) := { @@ -113,14 +121,18 @@ Section InstructionGallery. := {| fst_sub_with_carry x y c := proj1 (pf x y c); decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}. - Record multiply := { mul :> W -> W -> W }. + Class multiply := { mul : W -> W -> W }. + Global Coercion mul : multiply >-> Funclass. Class is_mul (mul : multiply) := decode_mul : forall x y, decode (mul x y) = (decode x * decode y) mod 2^n. - Record multiply_low_low := { mulhwll :> W -> W -> W }. - Record multiply_high_low := { mulhwhl :> W -> W -> W }. - Record multiply_high_high := { mulhwhh :> W -> W -> W }. + Class multiply_low_low := { mulhwll : W -> W -> W }. + Global Coercion mulhwll : multiply_low_low >-> Funclass. + Class multiply_high_low := { mulhwhl : W -> W -> W }. + Global Coercion mulhwhl : multiply_high_low >-> Funclass. + Class multiply_high_high := { mulhwhh : W -> W -> W }. + Global Coercion mulhwhh : multiply_high_high >-> Funclass. Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := decode_mul_low_low : @@ -132,13 +144,15 @@ Section InstructionGallery. decode_mul_high_high : forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. - Record select_conditional := { selc :> bool -> W -> W -> W }. + Class select_conditional := { selc : bool -> W -> W -> W }. + Global Coercion selc : select_conditional >-> Funclass. Class is_select_conditional (selc : select_conditional) := decode_select_conditional : forall b x y, decode (selc b x y) = if b then decode x else decode y. - Record add_modulo := { addm :> W -> W -> W (* modulus *) -> W }. + Class add_modulo := { addm : W -> W -> W (* modulus *) -> W }. + Global Coercion addm : add_modulo >-> Funclass. Class is_add_modulo (addm : add_modulo) := decode_add_modulo : forall x y modulus, @@ -176,21 +190,6 @@ Global Arguments mulhwhh {_ _} _ _. Global Arguments selc {_ _} _ _ _. Global Arguments addm {_ _} _ _ _. -Existing Class load_immediate. -Existing Class shift_right_doubleword_immediate. -Existing Class shift_left_immediate. -Existing Class shift_right_immediate. -Existing Class spread_left_immediate. -Existing Class mask_keep_low. -Existing Class add_with_carry. -Existing Class sub_with_carry. -Existing Class multiply. -Existing Class multiply_low_low. -Existing Class multiply_high_low. -Existing Class multiply_high_high. -Existing Class select_conditional. -Existing Class add_modulo. - Global Arguments is_decode {_ _} _. Global Arguments is_load_immediate {_ _ _} _. Global Arguments is_shift_right_doubleword_immediate {_ _ _} _. @@ -279,21 +278,21 @@ Ltac pull_decode := repeat pull_decode_step. same. *) Ltac set_decode_step check := match goal with - | [ |- context G[@Interface.decode ?n ?W ?dr ?w] ] + | [ |- context G[@decode ?n ?W ?dr ?w] ] => check w; first [ match goal with - | [ d := @Interface.decode _ _ _ w |- _ ] - => change (@Interface.decode n W dr w) with d + | [ d := @decode _ _ _ w |- _ ] + => change (@decode n W dr w) with d end | generalize (@decode_range n W dr _ w); let d := fresh "d" in - set (d := @Interface.decode n W dr w); + set (d := @decode n W dr w); intro ] end. Ltac set_decode check := repeat set_decode_step check. Ltac clearbody_decode := repeat match goal with - | [ H := @Interface.decode _ _ _ _ |- _ ] => clearbody H + | [ H := @decode _ _ _ _ |- _ ] => clearbody H end. Ltac generalize_decode_by check := set_decode check; clearbody_decode. Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac). -- cgit v1.2.3 From b90911a0b4483dc72f03d74779ef5339c252718c Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Wed, 24 Aug 2016 14:48:20 -0700 Subject: Clean up DoubleBounded --- src/BoundedArithmetic/DoubleBounded.v | 142 +++++++++++++++------------------- src/BoundedArithmetic/Interface.v | 15 +++- 2 files changed, 75 insertions(+), 82 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index cb8a21495..6173a8834 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -14,18 +14,15 @@ Local Open Scope type_scope. Local Coercion Z.of_nat : nat >-> Z. Local Notation eta x := (fst x, snd x). -Section generic_constructions. - Section decode. - Context {n W} {decode : decoder n W}. - Section with_k. - Context {k : nat}. - Let limb_widths := repeat n k. - (** The list is low to high; the tuple is low to high *) - Local Instance tuple_decoder : decoder (k * n) (tuple W k) - := { decode w := BaseSystem.decode (base_from_limb_widths limb_widths) (List.map decode (List.rev (Tuple.to_list _ w))) }. - End with_k. - End decode. +(** The list is low to high; the tuple is low to high *) +Definition tuple_decoder {n W} {decode : decoder n W} {k : nat} : decoder (k * n) (tuple W k) + := {| decode w := BaseSystem.decode (base_from_limb_widths (repeat n k)) + (List.map decode (List.rev (Tuple.to_list _ w))) |}. +Global Arguments tuple_decoder : simpl never. +Hint Resolve (fun n W decode => (@tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2))) : typeclass_instances. +Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. +Section ripple_carry_definitions. Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) (xs ys : list T) (carry : bool) : bool * list T := List.fold_right @@ -54,73 +51,58 @@ Section generic_constructions. | O => fun xs ys carry => (carry, tt) | S k' => ripple_carry_tuple' f k' end. - - Section ripple_carry_adc. - Context {n W} {decode : decoder n W} (adc : add_with_carry W). - - Global Instance ripple_carry_adc {k} : add_with_carry (tuple W k) - := { adc := ripple_carry_tuple adc k }. - End ripple_carry_adc. - - (* TODO: Would it made sense to make generic-width shift operations here? *) - - Section tuple2. - Section spread_left. - Context (n : Z) {W} - {ldi : load_immediate W} - {shl : shift_left_immediate W} - {shr : shift_right_immediate W}. - - Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2 - := (shl r count, if count =? 0 then ldi 0 else shr r (n - count)). - - (** Require a [decode] instance to aid typeclass search in - resolving [n] *) - Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W - := { sprl := spread_left_from_shift }. - End spread_left. - - Section full_from_half. - Context {W} - {mulhwll : multiply_low_low W} - {mulhwhl : multiply_high_low W} - {mulhwhh : multiply_high_high W} - {adc : add_with_carry W} - {shl : shift_left_immediate W} - {shr : shift_right_immediate W}. - - Section def. - Context (half_n : Z). - Definition mul_double (a b : W) : tuple W 2 - := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in - let tmp := mulhwhl a b in - let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in - let tmp := mulhwhl b a in - let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in - out. - End def. - - Section instances. - Context {half_n : Z} - {ldi : load_immediate W}. - - (** Require a dummy [decoder] for these instances to allow +End ripple_carry_definitions. + +Global Instance ripple_carry_adc + {W} (adc : add_with_carry W) {k} + : add_with_carry (tuple W k) + := { adc := ripple_carry_tuple adc k }. + +(** constructions on [tuple W 2] *) +Section tuple2. + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W}. + + Definition spread_left_from_shift (r : W) (count : Z) : tuple W 2 + := (shl r count, if count =? 0 then ldi 0 else shr r (n - count)). + + (** Require a [decoder] instance to aid typeclass search in + resolving [n] *) + Global Instance sprl_from_shift {decode : decoder n W} : spread_left_immediate W + := { sprl := spread_left_from_shift }. + End spread_left. + + Section full_from_half. + Context {half_n : Z} {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {ldi : load_immediate W}. + + Definition mul_double (a b : W) : tuple W 2 + := let out : tuple W 2 := (mulhwll a b, mulhwhh a b) in + let tmp := mulhwhl a b in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + let tmp := mulhwhl b a in + let '(_, out) := eta (ripple_carry_adc adc out (shl tmp half_n, shr tmp half_n) false) in + out. + + (** Require a dummy [decoder] for these instances to allow typeclass inference of the [half_n] argument *) - Global Instance mul_double_multiply_low_low {decode : decoder (2 * half_n) W} - : multiply_low_low (tuple W 2) - := { mulhwll a b := mul_double half_n (fst a) (fst b) }. - Global Instance mul_double_multiply_high_low {decode : decoder (2 * half_n) W} - : multiply_high_low (tuple W 2) - := { mulhwhl a b := mul_double half_n (snd a) (fst b) }. - Global Instance mul_double_multiply_high_high {decode : decoder (2 * half_n) W} - : multiply_high_high (tuple W 2) - := { mulhwhh a b := mul_double half_n (snd a) (snd b) }. - End instances. - End full_from_half. - End tuple2. -End generic_constructions. - -Global Arguments tuple_decoder : simpl never. - -Hint Resolve (fun n W decode => (@tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2))) : typeclass_instances. -Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. + Global Instance mul_double_multiply_low_low {decode : decoder (2 * half_n) W} + : multiply_low_low (tuple W 2) + := { mulhwll a b := mul_double (fst a) (fst b) }. + Global Instance mul_double_multiply_high_low {decode : decoder (2 * half_n) W} + : multiply_high_low (tuple W 2) + := { mulhwhl a b := mul_double (snd a) (fst b) }. + Global Instance mul_double_multiply_high_high {decode : decoder (2 * half_n) W} + : multiply_high_high (tuple W 2) + := { mulhwhh a b := mul_double (snd a) (snd b) }. + End full_from_half. +End tuple2. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 4d3f7d858..8681fecee 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -1,6 +1,7 @@ (*** Interface for bounded arithmetic *) Require Import Coq.ZArith.ZArith. Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. Require Import Crypto.Util.Notations. Local Open Scope Z_scope. @@ -63,7 +64,7 @@ Section InstructionGallery. decode_shift_right_immediate : forall r count, 0 <= count < n -> decode (shr r count) = (decode r >> count). - Class spread_left_immediate := { sprl : W -> imm -> W * W (* [(low, high)] *) }. + Class spread_left_immediate := { sprl : W -> imm -> tuple W 2 (* [(low, high)] *) }. Global Coercion sprl : spread_left_immediate >-> Funclass. Class is_spread_left_immediate (sprl : spread_left_immediate) := @@ -133,6 +134,8 @@ Section InstructionGallery. Global Coercion mulhwhl : multiply_high_low >-> Funclass. Class multiply_high_high := { mulhwhh : W -> W -> W }. Global Coercion mulhwhh : multiply_high_high >-> Funclass. + Class multiply_double := { muldw : W -> W -> tuple W 2 }. + Global Coercion muldw : multiply_double >-> Funclass. Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := decode_mul_low_low : @@ -143,6 +146,9 @@ Section InstructionGallery. Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := decode_mul_high_high : forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. + Class is_mul_double (muldw : multiply_double) := + decode_mul_double : + forall x y, (decode (fst (muldw x y)) + decode (snd (muldw x y)) << n = decode x * decode y)%Z. Class select_conditional := { selc : bool -> W -> W -> W }. Global Coercion selc : select_conditional >-> Funclass. @@ -173,6 +179,7 @@ Global Arguments multiply : clear implicits. Global Arguments multiply_low_low : clear implicits. Global Arguments multiply_high_low : clear implicits. Global Arguments multiply_high_high : clear implicits. +Global Arguments multiply_double : clear implicits. Global Arguments select_conditional : clear implicits. Global Arguments add_modulo : clear implicits. Global Arguments ldi {_ _} _. @@ -187,6 +194,7 @@ Global Arguments mul {_ _} _ _. Global Arguments mulhwll {_ _} _ _. Global Arguments mulhwhl {_ _} _ _. Global Arguments mulhwhh {_ _} _ _. +Global Arguments muldw {_ _} _ _. Global Arguments selc {_ _} _ _ _. Global Arguments addm {_ _} _ _ _. @@ -203,6 +211,7 @@ Global Arguments is_mul {_ _ _} _. Global Arguments is_mul_low_low {_ _ _} _ _. Global Arguments is_mul_high_low {_ _ _} _ _. Global Arguments is_mul_high_high {_ _ _} _ _. +Global Arguments is_mul_double {_ _ _} _. Global Arguments is_select_conditional {_ _ _} _. Global Arguments is_add_modulo {_ _ _} _. @@ -229,7 +238,7 @@ Proof. omega. Qed. -Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo @decode_proj @decode_if_bool using bounded_solver_tac : push_decode. +Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_mul_double @decode_select_conditional @decode_add_modulo @decode_proj @decode_if_bool using bounded_solver_tac : push_decode. Ltac push_decode_step := first [ rewrite !decode_proj @@ -249,6 +258,7 @@ Ltac push_decode_step := | erewrite !decode_mul_low_low by bounded_solver_tac | erewrite !decode_mul_high_low by bounded_solver_tac | erewrite !decode_mul_high_high by bounded_solver_tac + | erewrite !decode_mul_double by bounded_solver_tac | erewrite !decode_select_conditional by bounded_solver_tac | erewrite !decode_add_modulo by bounded_solver_tac ]. Ltac pull_decode_step := @@ -267,6 +277,7 @@ Ltac pull_decode_step := | erewrite <- !decode_mul_low_low by bounded_solver_tac | erewrite <- !decode_mul_high_low by bounded_solver_tac | erewrite <- !decode_mul_high_high by bounded_solver_tac + | erewrite <- !decode_mul_double by bounded_solver_tac | erewrite <- !decode_select_conditional by bounded_solver_tac | erewrite <- !decode_add_modulo by bounded_solver_tac ]. Ltac push_decode := repeat push_decode_step. -- cgit v1.2.3 From 7e44853ca592e077f5d4d110c6059f27d6e27f35 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Wed, 24 Aug 2016 14:52:34 -0700 Subject: More slight cleanups --- src/BoundedArithmetic/Interface.v | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 8681fecee..4139a91ce 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -147,8 +147,17 @@ Section InstructionGallery. decode_mul_high_high : forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. Class is_mul_double (muldw : multiply_double) := - decode_mul_double : - forall x y, (decode (fst (muldw x y)) + decode (snd (muldw x y)) << n = decode x * decode y)%Z. + { + decode_fst_mul_double : + forall x y, decode (fst (muldw x y)) = (decode x * decode y) mod 2^n; + decode_snd_mul_double : + forall x y, decode (snd (muldw x y)) = (decode x * decode y) >> n + }. + Definition Build_is_mul_double' (muldw : multiply_double) + (pf : forall x y, _ /\ _) + := {| decode_fst_mul_double x y := proj1 (pf x y); + decode_snd_mul_double x y := proj2 (pf x y) |}. + Class select_conditional := { selc : bool -> W -> W -> W }. Global Coercion selc : select_conditional >-> Funclass. @@ -238,7 +247,7 @@ Proof. omega. Qed. -Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_mul_double @decode_select_conditional @decode_add_modulo @decode_proj @decode_if_bool using bounded_solver_tac : push_decode. +Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_fst_mul_double @decode_snd_mul_double @decode_select_conditional @decode_add_modulo @decode_proj @decode_if_bool using bounded_solver_tac : push_decode. Ltac push_decode_step := first [ rewrite !decode_proj @@ -258,7 +267,8 @@ Ltac push_decode_step := | erewrite !decode_mul_low_low by bounded_solver_tac | erewrite !decode_mul_high_low by bounded_solver_tac | erewrite !decode_mul_high_high by bounded_solver_tac - | erewrite !decode_mul_double by bounded_solver_tac + | erewrite !decode_fst_mul_double by bounded_solver_tac + | erewrite !decode_snd_mul_double by bounded_solver_tac | erewrite !decode_select_conditional by bounded_solver_tac | erewrite !decode_add_modulo by bounded_solver_tac ]. Ltac pull_decode_step := @@ -277,7 +287,8 @@ Ltac pull_decode_step := | erewrite <- !decode_mul_low_low by bounded_solver_tac | erewrite <- !decode_mul_high_low by bounded_solver_tac | erewrite <- !decode_mul_high_high by bounded_solver_tac - | erewrite <- !decode_mul_double by bounded_solver_tac + | erewrite <- !decode_fst_mul_double by bounded_solver_tac + | erewrite <- !decode_snd_mul_double by bounded_solver_tac | erewrite <- !decode_select_conditional by bounded_solver_tac | erewrite <- !decode_add_modulo by bounded_solver_tac ]. Ltac push_decode := repeat push_decode_step. -- cgit v1.2.3 From 34d53cc72df1a3c31838e0cc7e06f0cf8959d628 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Wed, 24 Aug 2016 19:46:19 -0700 Subject: Rework bounded proofs Now the rewrite strategy no longer relies on projections of anything other than [decode], and the conversion to ZLike is simpler. Modulo some annoyingly delicate arithmetic around things like [2^n * 2^n = 2^(2*n)] and whether to factor [(decode (fst x) + decode (snd x) >> n) >> b] as [decode x >> n] or as [shrd (fst x) (snd x) n], the proofs bascially go by pulling/pushing decodes. --- _CoqProject | 1 + src/BoundedArithmetic/ArchitectureToZLike.v | 2 +- src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 25 +- src/BoundedArithmetic/DoubleBounded.v | 29 +- src/BoundedArithmetic/DoubleBoundedProofs.v | 725 +++++++++++++--------- src/BoundedArithmetic/Interface.v | 234 +++---- src/BoundedArithmetic/InterfaceProofs.v | 202 ++++++ 7 files changed, 743 insertions(+), 475 deletions(-) create mode 100644 src/BoundedArithmetic/InterfaceProofs.v (limited to 'src') diff --git a/_CoqProject b/_CoqProject index 4f8bf08a3..9ed483860 100644 --- a/_CoqProject +++ b/_CoqProject @@ -27,6 +27,7 @@ src/BoundedArithmetic/ArchitectureToZLikeProofs.v src/BoundedArithmetic/DoubleBounded.v src/BoundedArithmetic/DoubleBoundedProofs.v src/BoundedArithmetic/Interface.v +src/BoundedArithmetic/InterfaceProofs.v src/CompleteEdwardsCurve/CompleteEdwardsCurveTheorems.v src/CompleteEdwardsCurve/ExtendedCoordinates.v src/CompleteEdwardsCurve/Pre.v diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index cd221c10d..3388ece78 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -24,7 +24,7 @@ Section fancy_machine_p256_montgomery_foundation. DivBy_SmallerBound v := if smaller_bound_exp =? n then snd v else shrd (snd v) (fst v) smaller_bound_exp; - Mul x y := mulhwll (W := tuple _ 2) (sprl x 0) (sprl y 0); + Mul x y := muldw x y; CarryAdd x y := adc x y false; CarrySubSmall x y := subc x y false; ConditionalSubtract b x := let v := selc b (ldi modulus) (ldi 0) in snd (subc x v false); diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v index 0d19a54bc..3060e17bb 100644 --- a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -1,6 +1,7 @@ (*** Proving ℤ-Like via Architecture *) Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.InterfaceProofs. Require Import Crypto.BoundedArithmetic.DoubleBounded. Require Import Crypto.BoundedArithmetic.DoubleBoundedProofs. Require Import Crypto.BoundedArithmetic.ArchitectureToZLike. @@ -42,6 +43,7 @@ Section fancy_machine_p256_montgomery_foundation. | _ => unique assert (0 <= 2 * n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] | _ => unique assert (0 <= n_over_two) by solve [ eauto using decode_exponent_nonnegative with typeclass_instances | omega ] | _ => unique assert (0 <= 2 * (2 * n_over_two)) by (eauto using decode_exponent_nonnegative with typeclass_instances) + | [ H : 0 <= ?x < _ |- _ ] => unique pose proof (proj1 H); unique pose proof (proj2 H) end. Local Ltac pre_t := repeat first [ tauto @@ -50,25 +52,19 @@ Section fancy_machine_p256_montgomery_foundation. | saturate_context_step ]. Local Ltac post_t_step := match goal with - | _ => tauto - | _ => progress autorewrite with zsimplify_const in * - | _ => progress push_decode - | _ => progress autorewrite with push_Zpow in * - | _ => progress Z.rewrite_mod_small - | _ => progress subst + | _ => reflexivity + | _ => progress autorewrite with zsimplify_const | [ |- fst ?x = (?a <=? ?b) :> bool ] => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); [ destruct (fst x), (a <=? b); intro; congruence | ] - | [ |- appcontext[let (a, b) := ?x in _] ] - => rewrite (surjective_pairing x); simplify_projections - | _ => progress autorewrite with Zshift_to_pow in * - | _ => progress autorewrite with simpl_tuple_decoder in * - | _ => progress autorewrite with zsimplify - | [ H : (_ =? _) = true |- _ ] => apply Z.eqb_eq in H + | [ H : (_ =? _) = true |- _ ] => apply Z.eqb_eq in H; subst | [ H : (_ =? _) = false |- _ ] => apply Z.eqb_neq in H - | [ |- _ / ?y = _ / ?y ] => apply f_equal2; omega - | [ |- _ / _ = if _ then _ else _ ] => apply Z.div_between_0_if; auto with zarith omega + | _ => autorewrite with push_Zpow in *; solve [ reflexivity | assumption ] + | _ => autorewrite with pull_Zpow in *; pull_decode; reflexivity + | _ => progress push_decode + | _ => rewrite (Z.add_comm (_ << _) _); progress pull_decode | [ |- context[if ?x =? ?y then _ else _] ] => destruct (x =? y) eqn:? + | _ => autorewrite with Zshift_to_pow; Z.rewrite_mod_small; reflexivity end. Local Ltac post_t := repeat post_t_step. Local Ltac t := pre_t; post_t. @@ -97,6 +93,7 @@ Section fancy_machine_p256_montgomery_foundation. { abstract t. } { abstract t. } { abstract t. } +Hint Resolve Z.div_pos : zarith. { abstract t. } { abstract t. } { abstract t. } diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 6173a8834..55e46aa2b 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -1,6 +1,7 @@ (*** Implementing Large Bounded Arithmetic via pairs *) Require Import Coq.ZArith.ZArith. Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.InterfaceProofs. Require Import Crypto.ModularArithmetic.Pow2Base. Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ListUtil. @@ -19,7 +20,6 @@ Definition tuple_decoder {n W} {decode : decoder n W} {k : nat} : decoder (k * n := {| decode w := BaseSystem.decode (base_from_limb_widths (repeat n k)) (List.map decode (List.rev (Tuple.to_list _ w))) |}. Global Arguments tuple_decoder : simpl never. -Hint Resolve (fun n W decode => (@tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2))) : typeclass_instances. Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. Section ripple_carry_definitions. @@ -75,7 +75,7 @@ Section tuple2. := { sprl := spread_left_from_shift }. End spread_left. - Section full_from_half. + Section double_from_half. Context {half_n : Z} {W} {mulhwll : multiply_low_low W} {mulhwhl : multiply_high_low W} @@ -95,14 +95,19 @@ Section tuple2. (** Require a dummy [decoder] for these instances to allow typeclass inference of the [half_n] argument *) - Global Instance mul_double_multiply_low_low {decode : decoder (2 * half_n) W} - : multiply_low_low (tuple W 2) - := { mulhwll a b := mul_double (fst a) (fst b) }. - Global Instance mul_double_multiply_high_low {decode : decoder (2 * half_n) W} - : multiply_high_low (tuple W 2) - := { mulhwhl a b := mul_double (snd a) (fst b) }. - Global Instance mul_double_multiply_high_high {decode : decoder (2 * half_n) W} - : multiply_high_high (tuple W 2) - := { mulhwhh a b := mul_double (snd a) (snd b) }. - End full_from_half. + Global Instance mul_double_multiply {decode : decoder (2 * half_n) W} : multiply_double W + := { muldw a b := mul_double a b }. + End double_from_half. + + Global Instance mul_double_multiply_low_low {W} {muldw : multiply_double W} + : multiply_low_low (tuple W 2) + := { mulhwll a b := muldw (fst a) (fst b) }. + Global Instance mul_double_multiply_high_low {W} {muldw : multiply_double W} + : multiply_high_low (tuple W 2) + := { mulhwhl a b := muldw (snd a) (fst b) }. + Global Instance mul_double_multiply_high_high {W} {muldw : multiply_double W} + : multiply_high_high (tuple W 2) + := { mulhwhh a b := muldw (snd a) (snd b) }. End tuple2. + +Global Arguments mul_double half_n {_ _ _ _ _ _ _} _ _. diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v index d878a1373..b69232076 100644 --- a/src/BoundedArithmetic/DoubleBoundedProofs.v +++ b/src/BoundedArithmetic/DoubleBoundedProofs.v @@ -1,6 +1,7 @@ (*** Proofs About Large Bounded Arithmetic via pairs *) Require Import Coq.ZArith.ZArith Coq.Lists.List Coq.micromega.Psatz. Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.InterfaceProofs. Require Import Crypto.BaseSystem. Require Import Crypto.BaseSystemProofs. Require Import Crypto.ModularArithmetic.Pow2Base. @@ -22,333 +23,449 @@ Local Coercion Z.of_nat : nat >-> Z. Local Coercion Pos.to_nat : positive >-> nat. Local Notation eta x := (fst x, snd x). -Section generic_constructions. - Section decode. - Context {n W} {decode : decoder n W}. - Section with_k. - Context {k : nat}. - Local Notation limb_widths := (repeat n k). - - Lemma decode_bounded {isdecode : is_decode decode} w - : 0 <= n -> bounded limb_widths (map decode (rev (to_list k w))). - Proof. - intro. - eapply bounded_uniform; try solve [ eauto using repeat_spec ]. - { distr_length. } - { intros z H'. - apply in_map_iff in H'. - destruct H' as [? [? H'] ]; subst; apply decode_range. } - Qed. - - (** TODO: Clean up this proof *) - Global Instance tuple_is_decode {isdecode : is_decode decode} - : is_decode (tuple_decoder (k := k)). - Proof. - unfold tuple_decoder; hnf; simpl. - intro w. - destruct (zerop k); [ subst | ]. - { unfold BaseSystem.decode, BaseSystem.decode'; simpl; omega. } - assert (0 <= n) - by (destruct k as [ | [|] ]; [ omega | | destruct w ]; - eauto using decode_exponent_nonnegative). - replace (2^(k * n)) with (upper_bound limb_widths) - by (erewrite upper_bound_uniform by eauto using repeat_spec; distr_length). - apply decode_upper_bound; auto using decode_bounded. - { intros ? H'. - apply repeat_spec in H'; omega. } - { distr_length. } - Qed. - End with_k. - - Local Arguments base_from_limb_widths : simpl never. - Local Arguments repeat : simpl never. - Local Arguments Z.mul !_ !_. - Lemma tuple_decoder_S {k} w : 0 <= n -> (tuple_decoder (k := S (S k)) w = tuple_decoder (k := S k) (fst w) + (decode (snd w) << (S k * n)))%Z. +Local Infix "==" := rewrite_eq. +Local Infix "=~>" := rewrite_left_to_right_eq. +Local Infix "<~=" := rewrite_right_to_left_eq. + +Section decode. + Context {n W} {decode : decoder n W}. + Section with_k. + Context {k : nat}. + Local Notation limb_widths := (repeat n k). + + Lemma decode_bounded {isdecode : is_decode decode} w + : 0 <= n -> bounded limb_widths (map decode (rev (to_list k w))). Proof. - intro Hn. - destruct w as [? w]; simpl. - replace (decode w) with (decode w * 1 + 0)%Z by omega. - rewrite map_app, map_cons, map_nil. - erewrite decode_shift_uniform_app by (eauto using repeat_spec; distr_length). - distr_length. - autorewrite with push_skipn natsimplify push_firstn. - reflexivity. + intro. + eapply bounded_uniform; try solve [ eauto using repeat_spec ]. + { distr_length. } + { intros z H'. + apply in_map_iff in H'. + destruct H' as [? [? H'] ]; subst; apply decode_range. } Qed. - Lemma tuple_decoder_O w : tuple_decoder (k := 1) w = decode w. + + (** TODO: Clean up this proof *) + Global Instance tuple_is_decode {isdecode : is_decode decode} + : is_decode (tuple_decoder (k := k)). Proof. - unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat. - simpl. - omega. + unfold tuple_decoder; hnf; simpl. + intro w. + destruct (zerop k); [ subst | ]. + { unfold BaseSystem.decode, BaseSystem.decode'; simpl; omega. } + assert (0 <= n) + by (destruct k as [ | [|] ]; [ omega | | destruct w ]; + eauto using decode_exponent_nonnegative). + replace (2^(k * n)) with (upper_bound limb_widths) + by (erewrite upper_bound_uniform by eauto using repeat_spec; distr_length). + apply decode_upper_bound; auto using decode_bounded. + { intros ? H'. + apply repeat_spec in H'; omega. } + { distr_length. } Qed. - Lemma tuple_decoder_m1 w : tuple_decoder (k := 0) w = 0. - Proof. reflexivity. Qed. - End decode. - Local Arguments tuple_decoder : simpl never. - Local Opaque tuple_decoder. - Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. + End with_k. + + Local Arguments base_from_limb_widths : simpl never. + Local Arguments repeat : simpl never. + Local Arguments Z.mul !_ !_. + Lemma tuple_decoder_S {k} w : 0 <= n -> (tuple_decoder (k := S (S k)) w = tuple_decoder (k := S k) (fst w) + (decode (snd w) << (S k * n)))%Z. + Proof. + intro Hn. + destruct w as [? w]; simpl. + replace (decode w) with (decode w * 1 + 0)%Z by omega. + rewrite map_app, map_cons, map_nil. + erewrite decode_shift_uniform_app by (eauto using repeat_spec; distr_length). + distr_length. + autorewrite with push_skipn natsimplify push_firstn. + reflexivity. + Qed. + Global Instance tuple_decoder_O w : tuple_decoder (k := 1) w =~> decode w. + Proof. + unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat. + simpl; hnf. + omega. + Qed. + Lemma tuple_decoder_O_ind_prod + (P : forall n, decoder n W -> Type) + (P_ext : forall n (a b : decoder n W), (forall x, a x = b x) -> P _ a -> P _ b) + : (P _ (tuple_decoder (k := 1)) -> P _ decode) + * (P _ decode -> P _ (tuple_decoder (k := 1))). + Proof. + unfold tuple_decoder, BaseSystem.decode, BaseSystem.decode', accumulate, base_from_limb_widths, repeat. + simpl; hnf. + rewrite Z.mul_1_l. + split; apply P_ext; simpl; intro; autorewrite with zsimplify_const; reflexivity. + Qed. + + Global Instance tuple_decoder_m1 w : tuple_decoder (k := 0) w =~> 0. + Proof. reflexivity. Qed. + + Global Instance tuple_decoder_2' w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z. + Proof. + intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. + reflexivity. + Qed. + Global Instance tuple_decoder_2 w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z. + Proof. + intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. +End decode. + +Local Arguments tuple_decoder : simpl never. +Local Opaque tuple_decoder. + +Lemma is_add_with_carry_1tuple {n W decode adc} + (H : @is_add_with_carry n W decode adc) + : @is_add_with_carry (1 * n) W (@tuple_decoder n W decode 1) adc. +Proof. + apply tuple_decoder_O_ind_prod; try assumption. + intros ??? ext [H0 H1]; apply Build_is_add_with_carry'. + intros x y c; specialize (H0 x y c); specialize (H1 x y c). + rewrite <- !ext; split; assumption. +Qed. + +Hint Extern 1 (@is_add_with_carry _ _ (@tuple_decoder ?n ?W ?decode 1) ?adc) +=> apply (@is_add_with_carry_1tuple n W decode adc) : typeclass_instances. + +Hint Resolve (fun n W decode pf => (@tuple_is_decode n W decode 2 pf : @is_decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2))) : typeclass_instances. +Hint Extern 3 (@is_decode _ (tuple ?W ?k) _) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode pf => (@tuple_is_decode n W decode k pf : @is_decode (kv * n) (tuple W k) (@tuple_decoder n W decode k : decoder (kv * n)%Z (tuple W k)))) : typeclass_instances. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. +Hint Rewrite Z.mul_1_l : simpl_tuple_decoder. +Hint Rewrite + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) + (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) + using solve [ auto with zarith ] + : simpl_tuple_decoder. + +Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. + +Global Instance tuple_decoder_mod {n W} {decode : decoder n W} {k} {isdecode : is_decode decode} (w : tuple W (S (S k))) + : tuple_decoder (k := S k) (fst w) <~= tuple_decoder w mod 2^(S k * n). +Proof. + pose proof (snd w). + assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= (S k) * n) by nia. + assert (0 <= tuple_decoder (k := S k) (fst w) < 2^(S k * n)) by apply decode_range. + autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify. + reflexivity. +Qed. + +Global Instance tuple_decoder_div {n W} {decode : decoder n W} {k} {isdecode : is_decode decode} (w : tuple W (S (S k))) + : decode (snd w) <~= tuple_decoder w / 2^(S k * n). +Proof. + pose proof (snd w). + assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= (S k) * n) by nia. + assert (0 <= k * n) by nia. + assert (0 < 2^n) by auto with zarith. + assert (0 <= tuple_decoder (k := S k) (fst w) < 2^(S k * n)) by apply decode_range. + autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify. + reflexivity. +Qed. + +Global Instance tuple2_decoder_mod {n W} {decode : decoder n W} {isdecode : is_decode decode} (w : tuple W 2) + : decode (fst w) <~= tuple_decoder w mod 2^n. +Proof. + generalize (@tuple_decoder_mod n W decode 0 isdecode w). + autorewrite with simpl_tuple_decoder; trivial. +Qed. + +Global Instance tuple2_decoder_div {n W} {decode : decoder n W} {isdecode : is_decode decode} (w : tuple W 2) + : decode (snd w) <~= tuple_decoder w / 2^n. +Proof. + generalize (@tuple_decoder_div n W decode 0 isdecode w). + autorewrite with simpl_tuple_decoder; trivial. +Qed. + +Lemma decode_is_spread_left_immediate_iff + {n W} + {decode : decoder n W} + {sprl : spread_left_immediate W} + {isdecode : is_decode decode} + : is_spread_left_immediate sprl + <-> (forall r count, + 0 <= count < n + -> tuple_decoder (sprl r count) = decode r << count). +Proof. + rewrite is_spread_left_immediate_alt by assumption. + split; intros H r count Hc; specialize (H r count Hc); revert H; + pose proof (decode_range r); + assert (0 < 2^count < 2^n) by auto with zarith; + autorewrite with simpl_tuple_decoder; + simpl; intro H'; rewrite H'; + autorewrite with Zshift_to_pow; + Z.rewrite_mod_small; reflexivity. +Qed. - Hint Extern 1 (decoder _ (tuple ?W 2)) => apply (fun n decode => @tuple_decoder n W decode 2 : decoder (2 * n) (tuple W 2)) : typeclass_instances. +Global Instance decode_is_spread_left_immediate + {n W} + {decode : decoder n W} + {sprl : spread_left_immediate W} + {isdecode : is_decode decode} + {issprl : is_spread_left_immediate sprl} + : forall r count, + 0 <= count < n + -> tuple_decoder (sprl r count) == decode r << count + := proj1 decode_is_spread_left_immediate_iff _. - Lemma ripple_carry_tuple_SS {T} f k xss yss carry - : @ripple_carry_tuple T f (S (S k)) xss yss carry +Lemma decode_mul_double_iff + {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + : is_mul_double muldw + <-> (forall x y, tuple_decoder (muldw x y) = (decode x * decode y)%Z). +Proof. + rewrite is_mul_double_alt by assumption. + split; intros H x y; specialize (H x y); revert H; + pose proof (decode_range x); pose proof (decode_range y); + assert (0 <= decode x * decode y < 2^n * 2^n) by nia; + assert (0 <= n) by eauto using decode_exponent_nonnegative; + autorewrite with simpl_tuple_decoder; + simpl; intro H'; rewrite H'; + Z.rewrite_mod_small; reflexivity. +Qed. + +Global Instance decode_mul_double + {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + {ismuldw : is_mul_double muldw} + : forall x y, tuple_decoder (muldw x y) == (decode x * decode y)%Z + := proj1 decode_mul_double_iff _. + +Lemma ripple_carry_tuple_SS {T} f k xss yss carry + : @ripple_carry_tuple T f (S (S k)) xss yss carry + = let '(xs, x) := eta xss in + let '(ys, y) := eta yss in + let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in + let '(carry, z) := eta (f x y carry) in + (carry, (zs, z)). +Proof. reflexivity. Qed. + +Lemma carry_is_good (n z0 z1 k : Z) + : 0 <= n -> + 0 <= k -> + (z1 + z0 >> k) >> n = (z0 + z1 << k) >> (k + n) /\ + (z0 mod 2 ^ k + ((z1 + z0 >> k) mod 2 ^ n) << k)%Z = (z0 + z1 << k) mod (2 ^ k * 2 ^ n). +Proof. + intros. + assert (0 < 2 ^ n) by auto with zarith. + assert (0 < 2 ^ k) by auto with zarith. + assert (0 < 2^n * 2^k) by nia. + autorewrite with Zshift_to_pow push_Zpow. + rewrite <- (Zmod_small ((z0 mod _) + _) (2^k * 2^n)) by (Z.div_mod_to_quot_rem; nia). + rewrite <- !Z.mul_mod_distr_r by lia. + rewrite !(Z.mul_comm (2^k)); pull_Zmod. + split; [ | apply f_equal2 ]; + Z.div_mod_to_quot_rem; nia. +Qed. + +Definition carry_is_good_carry n z0 z1 k H0 H1 := proj1 (@carry_is_good n z0 z1 k H0 H1). +Definition carry_is_good_value n z0 z1 k H0 H1 := proj2 (@carry_is_good n z0 z1 k H0 H1). + +Section ripple_carry_adc. + Context {n W} {decode : decoder n W} (adc : add_with_carry W). + + Lemma ripple_carry_adc_SS k xss yss carry + : ripple_carry_adc (k := S (S k)) adc xss yss carry = let '(xs, x) := eta xss in let '(ys, y) := eta yss in - let '(carry, zs) := eta (@ripple_carry_tuple _ f (S k) xs ys carry) in - let '(carry, z) := eta (f x y carry) in + let '(carry, zs) := eta (ripple_carry_adc (k := S k) adc xs ys carry) in + let '(carry, z) := eta (adc x y carry) in (carry, (zs, z)). - Proof. reflexivity. Qed. + Proof. apply ripple_carry_tuple_SS. Qed. - Lemma carry_is_good (n z0 z1 k : Z) - : 0 <= n -> - 0 <= k -> - (z1 + z0 >> k) >> n = (z0 + z1 << k) >> (k + n) /\ - (z0 mod 2 ^ k + ((z1 + z0 >> k) mod 2 ^ n) << k)%Z = (z0 + z1 << k) mod (2 ^ k * 2 ^ n). + Local Opaque Z.of_nat. + Global Instance ripple_carry_is_add_with_carry {k} + {isdecode : is_decode decode} + {is_adc : is_add_with_carry adc} + : is_add_with_carry (ripple_carry_adc (k := k) adc). Proof. - intros. - assert (0 < 2 ^ n) by auto with zarith. - assert (0 < 2 ^ k) by auto with zarith. - assert (0 < 2^n * 2^k) by nia. - autorewrite with Zshift_to_pow push_Zpow. - rewrite <- (Zmod_small ((z0 mod _) + _) (2^k * 2^n)) by (Z.div_mod_to_quot_rem; nia). - rewrite <- !Z.mul_mod_distr_r by lia. - rewrite !(Z.mul_comm (2^k)); pull_Zmod. - split; [ | apply f_equal2 ]; - Z.div_mod_to_quot_rem; nia. + destruct k as [|k]. + { constructor; simpl; intros; autorewrite with zsimplify; reflexivity. } + { induction k as [|k IHk]. + { cbv [ripple_carry_adc ripple_carry_tuple to_list]. + constructor; simpl @fst; simpl @snd; intros; + simpl; pull_decode; reflexivity. } + { apply Build_is_add_with_carry'; intros x y c. + assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative). + assert (2^n <> 0) by auto with zarith. + assert (0 <= S k * n) by nia. + rewrite !tuple_decoder_S, !ripple_carry_adc_SS by assumption. + simplify_projections; push_decode; generalize_decode. + erewrite carry_is_good_carry, carry_is_good_value by lia. + autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow. + split; apply f_equal2; nia. } } Qed. - Definition carry_is_good_carry n z0 z1 k H0 H1 := proj1 (@carry_is_good n z0 z1 k H0 H1). - Definition carry_is_good_value n z0 z1 k H0 H1 := proj2 (@carry_is_good n z0 z1 k H0 H1). +End ripple_carry_adc. + +Hint Extern 2 (@is_add_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_adc _ ?adc _)) +=> apply (@ripple_carry_is_add_with_carry n W decode adc k) : typeclass_instances. - Section ripple_carry_adc. - Context {n W} {decode : decoder n W} (adc : add_with_carry W). +Section tuple2. + Local Arguments Z.pow !_ !_. + Local Arguments Z.mul !_ !_. - Lemma ripple_carry_adc_SS k xss yss carry - : ripple_carry_adc (k := S (S k)) adc xss yss carry - = let '(xs, x) := eta xss in - let '(ys, y) := eta yss in - let '(carry, zs) := eta (ripple_carry_adc (k := S k) adc xs ys carry) in - let '(carry, z) := eta (adc x y carry) in - (carry, (zs, z)). - Proof. apply ripple_carry_tuple_SS. Qed. + Section spread_left. + Context (n : Z) {W} + {ldi : load_immediate W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {decode : decoder n W} + {isdecode : is_decode decode} + {isldi : is_load_immediate ldi} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr}. - Local Existing Instance tuple_decoder. + Lemma spread_left_from_shift_correct + r count + (H : 0 < count < n) + : (decode (shl r count) + decode (shr r (n - count)) << n = decode r << count mod (2^n*2^n))%Z. + Proof. + assert (0 <= count < n) by lia. + assert (0 <= n - count < n) by lia. + assert (0 < 2^(n-count)) by auto with zarith. + assert (2^count < 2^n) by auto with zarith. + pose proof (decode_range r). + assert (0 <= decode r * 2 ^ count < 2 ^ n * 2^n) by auto with zarith. + push_decode; autorewrite with Zshift_to_pow zsimplify. + replace (decode r / 2^(n-count) * 2^n)%Z with ((decode r / 2^(n-count) * 2^(n-count)) * 2^count)%Z + by (rewrite <- Z.mul_assoc; autorewrite with pull_Zpow zsimplify; reflexivity). + rewrite Z.mul_div_eq' by lia. + autorewrite with push_Zmul zsimplify. + rewrite <- Z.mul_mod_distr_r_full, Z.add_sub_assoc. + repeat autorewrite with pull_Zpow zsimplify in *. + reflexivity. + Qed. - Global Instance ripple_carry_is_add_with_carry {k} - {isdecode : is_decode decode} - {is_adc : is_add_with_carry adc} - : is_add_with_carry (ripple_carry_adc (k := k) adc). + Global Instance is_spread_left_from_shift + : is_spread_left_immediate (sprl_from_shift n). Proof. - destruct k as [|k]. - { constructor; simpl; intros; autorewrite with zsimplify; reflexivity. } - { induction k as [|k IHk]. - { cbv [ripple_carry_adc ripple_carry_tuple to_list]. - constructor; simpl @fst; simpl @snd; intros; - autorewrite with simpl_tuple_decoder; - push_decode; - autorewrite with zsimplify; reflexivity. } - { apply Build_is_add_with_carry'; intros x y c. - assert (0 <= n) by (destruct x; eauto using decode_exponent_nonnegative). - assert (2^n <> 0) by auto with zarith. - assert (0 <= S k * n) by nia. - rewrite !tuple_decoder_S, !ripple_carry_adc_SS by assumption. - simplify_projections; push_decode; generalize_decode. - erewrite carry_is_good_carry, carry_is_good_value by lia. - autorewrite with pull_Zpow push_Zof_nat zsimplify Zshift_to_pow. - split; apply f_equal2; nia. } } + apply is_spread_left_immediate_alt. + intros r count; intros. + pose proof (decode_range r). + assert (0 < 2^n) by auto with zarith. + assert (decode r < 2^n * 2^n) by (generalize dependent (decode r); intros; nia). + autorewrite with simpl_tuple_decoder. + destruct (Z_zerop count). + { subst; autorewrite with Zshift_to_pow zsimplify. + simpl; push_decode. + autorewrite with push_Zpow zsimplify. + reflexivity. } + simpl. + rewrite <- spread_left_from_shift_correct by lia. + autorewrite with zsimplify Zpow_to_shift. + reflexivity. Qed. + End spread_left. - End ripple_carry_adc. - - Section tuple2. - Section spread_left_correct. - Context {n W} {decode : decoder n W} {sprl : spread_left_immediate W} - {isdecode : is_decode decode}. - Lemma is_spread_left_immediate_alt - : is_spread_left_immediate sprl - <-> (forall r count, 0 <= count < n -> tuple_decoder (k := 2) (sprl r count) = (decode r << count) mod 2^(2*n)). - Proof. - split; intro H; [ | apply Build_is_spread_left_immediate' ]; - intros r count Hc; - [ | specialize (H r count Hc); revert H ]; - pose proof (decode_range r); - assert (0 < 2^n) by auto with zarith; - assert (0 <= 2^count < 2^n) by auto with zarith; - assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia); - autorewrite with simpl_tuple_decoder; push_decode; - autorewrite with Zshift_to_pow zsimplify push_Zpow. - { reflexivity. } - { intro H'; rewrite <- H'. - autorewrite with zsimplify; split; reflexivity. } - Qed. - End spread_left_correct. - - Local Arguments Z.pow !_ !_. - Local Arguments Z.mul !_ !_. + Local Opaque ripple_carry_adc. + Section full_from_half. + Context {W} + {mulhwll : multiply_low_low W} + {mulhwhl : multiply_high_low W} + {mulhwhh : multiply_high_high W} + {adc : add_with_carry W} + {shl : shift_left_immediate W} + {shr : shift_right_immediate W} + {half_n : Z} + {ldi : load_immediate W} + {decode : decoder (2 * half_n) W} + {ismulhwll : is_mul_low_low half_n mulhwll} + {ismulhwhl : is_mul_high_low half_n mulhwhl} + {ismulhwhh : is_mul_high_high half_n mulhwhh} + {isadc : is_add_with_carry adc} + {isshl : is_shift_left_immediate shl} + {isshr : is_shift_right_immediate shr} + {isldi : is_load_immediate ldi} + {isdecode : is_decode decode}. - Section spread_left. - Context (n : Z) {W} - {ldi : load_immediate W} - {shl : shift_left_immediate W} - {shr : shift_right_immediate W} - {decode : decoder n W} - {isdecode : is_decode decode} - {isldi : is_load_immediate ldi} - {isshl : is_shift_left_immediate shl} - {isshr : is_shift_right_immediate shr}. - - Lemma spread_left_from_shift_correct - r count - (H : 0 < count < n) - : (decode (shl r count) + decode (shr r (n - count)) << n = decode r << count mod 2^(2*n))%Z. - Proof. - assert (0 <= n - count < n) by lia. - assert (0 < 2^(n-count)) by auto with zarith. - assert (2^count < 2^n) by auto with zarith. - pose proof (decode_range r). - assert (0 <= decode r * 2 ^ count < 2 ^ n * 2^n) by auto with zarith. - simpl. - push_decode; autorewrite with Zshift_to_pow zsimplify. - replace (decode r / 2^(n-count) * 2^n)%Z with ((decode r / 2^(n-count) * 2^(n-count)) * 2^count)%Z - by (rewrite <- Z.mul_assoc; autorewrite with pull_Zpow zsimplify; reflexivity). - rewrite Z.mul_div_eq' by lia. - autorewrite with push_Zmul zsimplify. - rewrite <- Z.mul_mod_distr_r_full, Z.add_sub_assoc. - repeat autorewrite with pull_Zpow zsimplify in *. - reflexivity. - Qed. - - Global Instance is_spread_left_from_shift - : is_spread_left_immediate (sprl_from_shift n). - Proof. - apply is_spread_left_immediate_alt. - intros r count; intros. - pose proof (decode_range r). - assert (0 < 2^n) by auto with zarith. - assert (decode r < 2^n * 2^n) by (generalize dependent (decode r); intros; nia). - autorewrite with simpl_tuple_decoder. - destruct (Z_zerop count). - { subst; autorewrite with Zshift_to_pow zsimplify. - simpl; push_decode. - autorewrite with push_Zpow zsimplify. - reflexivity. } - simpl. - rewrite <- spread_left_from_shift_correct by lia. - autorewrite with zsimplify Zpow_to_shift. - reflexivity. - Qed. - End spread_left. - - Local Opaque ripple_carry_adc. - Section full_from_half. - Context {W} - {mulhwll : multiply_low_low W} - {mulhwhl : multiply_high_low W} - {mulhwhh : multiply_high_high W} - {adc : add_with_carry W} - {shl : shift_left_immediate W} - {shr : shift_right_immediate W} - {half_n : Z} - {ldi : load_immediate W} - {decode : decoder (2 * half_n) W} - {ismulhwll : is_mul_low_low half_n mulhwll} - {ismulhwhl : is_mul_high_low half_n mulhwhl} - {ismulhwhh : is_mul_high_high half_n mulhwhh} - {isadc : is_add_with_carry adc} - {isshl : is_shift_left_immediate shl} - {isshr : is_shift_right_immediate shr} - {isldi : is_load_immediate ldi} - {isdecode : is_decode decode}. - - Local Arguments Z.mul !_ !_. - Lemma spread_left_from_shift_half_correct - r - : (decode (shl r half_n) + decode (shr r half_n) * (2^half_n * 2^half_n) - = (decode r * 2^half_n) mod (2^half_n*2^half_n*2^half_n*2^half_n))%Z. - Proof. - destruct (0 (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) - (fun n (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) - using solve [ auto with zarith ] - : simpl_tuple_decoder. - Local Ltac t := - hnf; intros [??] [??]; - assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative; - assert (0 <= half_n) by omega; - simpl @Interface.mulhwhh; simpl @Interface.mulhwhl; simpl @Interface.mulhwll; - rewrite decode_mul_double_mod; push_decode; autorewrite with simpl_tuple_decoder; - simpl; - push_decode; generalize_decode_var; - autorewrite with Zshift_to_pow zsimplify; - autorewrite with push_Zpow in *; Z.rewrite_mod_small; - try reflexivity. - - Global Instance mul_double_is_multiply_low_low : is_mul_low_low (2 * half_n) mul_double_multiply_low_low. - Proof. t. Qed. - Global Instance mul_double_is_multiply_high_low : is_mul_high_low (2 * half_n) mul_double_multiply_high_low. - Proof. t. Qed. - Global Instance mul_double_is_multiply_high_high : is_mul_high_high (2 * half_n) mul_double_multiply_high_high. - Proof. t. Qed. - End full_from_half. - End tuple2. -End generic_constructions. + destruct (Z_zerop half_n). + { subst; simpl in *. + autorewrite with zsimplify. + nia. } + assert (half_n < 0) by lia. + assert (2^half_n = 0) by auto with zarith. + assert (0 < 0) by nia; omega. } + Qed. -Hint Resolve (fun n W decode pf => (@tuple_is_decode n W decode 2 pf : @is_decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2))) : typeclass_instances. -Hint Extern 3 (@is_decode _ (tuple ?W ?k) _) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode pf => (@tuple_is_decode n W decode k pf : @is_decode (kv * n) (tuple W k) (@tuple_decoder n W decode k : decode (kv * n) W))) : typeclass_instances. + Lemma decode_mul_double_mod x y + : (tuple_decoder (mul_double half_n x y) = (decode x * decode y) mod (2^(2 * half_n) * 2^(2*half_n)))%Z. + Proof. + assert (0 <= 2 * half_n) by eauto using decode_exponent_nonnegative. + assert (0 <= half_n) by omega. + unfold mul_double. + push_decode; autorewrite with simpl_tuple_decoder; simplify_projections. + autorewrite with zsimplify Zshift_to_pow push_Zpow. + rewrite !spread_left_from_shift_half_correct. + push_decode. + generalize_decode_var. + simpl in *. + autorewrite with push_Zpow in *. + repeat autorewrite with Zshift_to_pow zsimplify push_Zpow. + rewrite <- !(Z.mul_mod_distr_r_full _ _ (_^_ * _^_)), ?Z.mul_assoc. + Z.rewrite_mod_small. + push_Zmod; pull_Zmod. + apply f_equal2; [ | reflexivity ]. + Z.div_mod_to_quot_rem; nia. + Qed. -Hint Extern 2 (@is_add_with_carry _ (tuple ?W ?k) (@tuple_decoder ?n _ ?decode _) (@ripple_carry_adc _ ?adc _)) - => apply (@ripple_carry_is_add_with_carry n W decode adc k) : typeclass_instances. + Lemma decode_mul_double_function x y + : tuple_decoder (mul_double half_n x y) = (decode x * decode y)%Z. + Proof. + rewrite decode_mul_double_mod; generalize_decode_var. + simpl in *; Z.rewrite_mod_small; reflexivity. + Qed. -Hint Rewrite @tuple_decoder_S @tuple_decoder_O @tuple_decoder_m1 using solve [ auto with zarith ] : simpl_tuple_decoder. -Hint Rewrite - (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (tuple W 2) (@tuple_decoder n W decode 2) w = _)) - (fun n W (decode : decoder n W) w pf => (@tuple_decoder_S n W decode 0 w pf : @Interface.decode (2 * n) (W * W) (@tuple_decoder n W decode 2) w = _)) - using solve [ auto with zarith ] - : simpl_tuple_decoder. + Global Instance mul_double_is_multiply_double : is_mul_double mul_double_multiply. + Proof. + apply decode_mul_double_iff; apply decode_mul_double_function. + Qed. + End full_from_half. + + Section half_from_full. + Context {n W} + {decode : decoder n W} + {muldw : multiply_double W} + {isdecode : is_decode decode} + {ismuldw : is_mul_double muldw}. + + Local Ltac t := + hnf; intros [??] [??]; + assert (0 <= n) by eauto using decode_exponent_nonnegative; + assert (0 < 2^n) by auto with zarith; + assert (forall x y, 0 <= x < 2^n -> 0 <= y < 2^n -> 0 <= x * y < 2^n * 2^n) by auto with zarith; + simpl @Interface.mulhwhh; simpl @Interface.mulhwhl; simpl @Interface.mulhwll; + rewrite decode_mul_double; autorewrite with simpl_tuple_decoder Zshift_to_pow zsimplify push_Zpow; + Z.rewrite_mod_small; + try reflexivity. + + Global Instance mul_double_is_multiply_low_low : is_mul_low_low n mul_double_multiply_low_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_low : is_mul_high_low n mul_double_multiply_high_low. + Proof. t. Qed. + Global Instance mul_double_is_multiply_high_high : is_mul_high_high n mul_double_multiply_high_high. + Proof. t. Qed. + End half_from_full. +End tuple2. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 4139a91ce..00528a053 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -7,20 +7,6 @@ Require Import Crypto.Util.Notations. Local Open Scope Z_scope. Local Open Scope type_scope. -Create HintDb push_decode discriminated. -Create HintDb pull_decode discriminated. -Hint Extern 1 => progress autorewrite with push_decode in * : push_decode. -Hint Extern 1 => progress autorewrite with pull_decode in * : pull_decode. - -(* TODO(from jgross): Try dropping the record wrappers. See - https://github.com/mit-plv/fiat-crypto/pull/52#discussion_r74627992 - and - https://github.com/mit-plv/fiat-crypto/pull/52#discussion_r74658417 - and - https://github.com/mit-plv/fiat-crypto/pull/52#issuecomment-239536847. - The wrappers are here to make [autorewrite] databases feasable and - fast, based on design patterns learned from past experience. There - might be better ways. *) Class decoder (n : Z) W := { decode : W -> Z }. Coercion decode : decoder >-> Funclass. @@ -29,52 +15,108 @@ Global Arguments decode {n W _} _. Class is_decode {n W} (decode : decoder n W) := decode_range : forall x, 0 <= decode x < 2^n. +Class rewrite_eq {A} (x y : A) + := by_rewrite : x = y. +Arguments by_rewrite {A} _ _ {_}. + +Class rewrite_right_to_left_eq {A} (x y : A) + := by_rewrite_right_to_left : rewrite_eq x y. +Arguments by_rewrite_right_to_left {A} _ _ {_}. +Global Instance unfold_rewrite_right_to_left_eq {A x y} (H : @rewrite_eq A x y) + : @rewrite_right_to_left_eq A x y := H. + +Class rewrite_left_to_right_eq {A} (x y : A) + := by_rewrite_left_to_right : rewrite_eq x y. +Arguments by_rewrite_left_to_right {A} _ _ {_}. +Global Instance unfold_rewrite_left_to_right_eq {A x y} (H : @rewrite_eq A x y) + : @rewrite_left_to_right_eq A x y := H. + +Class bounded_in_range_cls (x y z : Z) := is_bounded_in_range : x <= y < z. +Ltac bounded_solver_tac := + solve [ eassumption | typeclasses eauto | omega ]. +Hint Extern 0 (bounded_in_range_cls _ _ _) => unfold bounded_in_range_cls; bounded_solver_tac : typeclass_instances. +Global Arguments bounded_in_range_cls / . +Global Instance decode_range_bound {n W} {decode : decoder n W} {H : is_decode decode} + : forall x, bounded_in_range_cls 0 (decode x) (2^n) + := H. + +Class bounded_le_cls (x y : Z) := is_bounded_le : x <= y. +Hint Extern 0 (bounded_le_cls _ _) => unfold bounded_le_cls; bounded_solver_tac : typeclass_instances. +Global Arguments bounded_le_cls / . + +Ltac push_decode_step := + match goal with + | [ |- context[@decode ?n ?W ?decoder ?w] ] + => let lem := constr:(by_rewrite_left_to_right (A := Z) (@decode n W decoder w) _) in + rewrite (lem : @decode n W decoder w = _) + | [ |- context[match @fst ?A ?B ?x with true => 1 | false => 0 end] ] + => let lem := constr:(by_rewrite_left_to_right (A := Z) (match @fst A B x with true => 1 | false => 0 end) _) in + rewrite (lem : _ = _) + end. +Ltac push_decode := repeat push_decode_step. +Ltac pull_decode_step := + match goal with + | [ |- context[?E] ] + => first [ let lem := constr:(by_rewrite_right_to_left (A := Z) _ E) in + rewrite <- (lem : _ = E) + | let lem := constr:(by_rewrite_right_to_left (A := bool) _ E) in + rewrite <- (lem : _ = E) ] + end. +Ltac pull_decode := repeat pull_decode_step. + +(** This is required for typeclass resolution to be fast. *) +Typeclasses Opaque decode. + Section InstructionGallery. Context (n : Z) (* bit-width of width of [W] *) {W : Type} (* bounded type, [W] for word *) (Wdecoder : decoder n W). + Local Infix "==" := rewrite_eq. + Local Infix "=~>" := rewrite_left_to_right_eq. + Local Infix "<~=" := rewrite_right_to_left_eq. + Local Notation "x <= y < z" := (bounded_in_range_cls x y z). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) Class load_immediate := { ldi : imm -> W }. Global Coercion ldi : load_immediate >-> Funclass. Class is_load_immediate {ldi : load_immediate} := - decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. + decode_load_immediate :> forall x, 0 <= x < 2^n -> decode (ldi x) =~> x. Class shift_right_doubleword_immediate := { shrd : W -> W -> imm -> W }. Global Coercion shrd : shift_right_doubleword_immediate >-> Funclass. Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := - decode_shift_right_doubleword : + decode_shift_right_doubleword :> forall high low count, 0 <= count < n - -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n. + -> decode (shrd high low count) == (((decode high << n) + decode low) >> count) mod 2^n. Class shift_left_immediate := { shl : W -> imm -> W }. Global Coercion shl : shift_left_immediate >-> Funclass. Class is_shift_left_immediate (shl : shift_left_immediate) := - decode_shift_left_immediate : - forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. + decode_shift_left_immediate :> + forall r count, 0 <= count < n -> decode (shl r count) == (decode r << count) mod 2^n. Class shift_right_immediate := { shr : W -> imm -> W }. Global Coercion shr : shift_right_immediate >-> Funclass. Class is_shift_right_immediate (shr : shift_right_immediate) := - decode_shift_right_immediate : - forall r count, 0 <= count < n -> decode (shr r count) = (decode r >> count). + decode_shift_right_immediate :> + forall r count, 0 <= count < n -> decode (shr r count) == (decode r >> count). Class spread_left_immediate := { sprl : W -> imm -> tuple W 2 (* [(low, high)] *) }. Global Coercion sprl : spread_left_immediate >-> Funclass. Class is_spread_left_immediate (sprl : spread_left_immediate) := { - decode_fst_spread_left_immediate : forall r count, + decode_fst_spread_left_immediate :> forall r count, 0 <= count < n - -> decode (fst (sprl r count)) = (decode r << count) mod 2^n; - decode_snd_spread_left_immediate : forall r count, + -> decode (fst (sprl r count)) =~> (decode r << count) mod 2^n; + decode_snd_spread_left_immediate :> forall r count, 0 <= count < n - -> decode (snd (sprl r count)) = (decode r << count) >> n + -> decode (snd (sprl r count)) =~> (decode r << count) >> n }. @@ -89,8 +131,8 @@ Section InstructionGallery. Global Coercion mkl : mask_keep_low >-> Funclass. Class is_mask_keep_low (mkl : mask_keep_low) := - decode_mask_keep_low : forall r count, - 0 <= count < n -> decode (mkl r count) = decode r mod 2^count. + decode_mask_keep_low :> forall r count, + 0 <= count < n -> decode (mkl r count) == decode r mod 2^count. Local Notation bit b := (if b then 1 else 0). @@ -99,8 +141,8 @@ Section InstructionGallery. Class is_add_with_carry (adc : add_with_carry) := { - bit_fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; - decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) + bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) == (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) == (decode x + decode y + bit c) mod (2^n) }. Definition Build_is_add_with_carry' (adc : add_with_carry) @@ -113,8 +155,8 @@ Section InstructionGallery. Class is_sub_with_carry (subc:W->W->bool->bool*W) := { - fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) forall x y c, fst (subc x y c) == ((decode x - decode y - bit c) forall x y c, decode (snd (subc x y c)) == (decode x - decode y - bit c) mod 2^n }. Definition Build_is_sub_with_carry' (subc : sub_with_carry) @@ -126,7 +168,7 @@ Section InstructionGallery. Global Coercion mul : multiply >-> Funclass. Class is_mul (mul : multiply) := - decode_mul : forall x y, decode (mul x y) = (decode x * decode y) mod 2^n. + decode_mul :> forall x y, decode (mul x y) == (decode x * decode y)%Z. Class multiply_low_low := { mulhwll : W -> W -> W }. Global Coercion mulhwll : multiply_low_low >-> Funclass. @@ -138,20 +180,20 @@ Section InstructionGallery. Global Coercion muldw : multiply_double >-> Funclass. Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := - decode_mul_low_low : - forall x y, decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. + decode_mul_low_low :> + forall x y, decode (mulhwll x y) == ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := - decode_mul_high_low : - forall x y, decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n. + decode_mul_high_low :> + forall x y, decode (mulhwhl x y) == ((decode x >> w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := - decode_mul_high_high : - forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. + decode_mul_high_high :> + forall x y, decode (mulhwhh x y) == ((decode x >> w) * (decode y >> w)) mod 2^n. Class is_mul_double (muldw : multiply_double) := { - decode_fst_mul_double : - forall x y, decode (fst (muldw x y)) = (decode x * decode y) mod 2^n; - decode_snd_mul_double : - forall x y, decode (snd (muldw x y)) = (decode x * decode y) >> n + decode_fst_mul_double :> + forall x y, decode (fst (muldw x y)) =~> (decode x * decode y) mod 2^n; + decode_snd_mul_double :> + forall x y, decode (snd (muldw x y)) =~> (decode x * decode y) >> n }. Definition Build_is_mul_double' (muldw : multiply_double) (pf : forall x y, _ /\ _) @@ -163,17 +205,17 @@ Section InstructionGallery. Global Coercion selc : select_conditional >-> Funclass. Class is_select_conditional (selc : select_conditional) := - decode_select_conditional : forall b x y, - decode (selc b x y) = if b then decode x else decode y. + decode_select_conditional :> forall b x y, + decode (selc b x y) == if b then decode x else decode y. Class add_modulo := { addm : W -> W -> W (* modulus *) -> W }. Global Coercion addm : add_modulo >-> Funclass. Class is_add_modulo (addm : add_modulo) := - decode_add_modulo : forall x y modulus, - decode (addm x y modulus) = (if (decode x + decode y) forall x y modulus, + decode (addm x y modulus) == (if (decode x + decode y) Z) - : @decode n W {| decode := dec |} = dec. -Proof. reflexivity. Qed. - -Lemma decode_if_bool n W (decode : decoder n W) (b : bool) x y - : decode (if b then x else y) - = if b then decode x else decode y. -Proof. destruct b; reflexivity. Qed. - -Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} - (isinhabited : W) - : 0 <= n. -Proof. - pose proof (decode_range isinhabited). - assert (0 < 2^n) by omega. - destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ]. - assert (2^n = 0) by auto using Z.pow_neg_r. - omega. -Qed. - -Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_shift_right_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_fst_mul_double @decode_snd_mul_double @decode_select_conditional @decode_add_modulo @decode_proj @decode_if_bool using bounded_solver_tac : push_decode. - -Ltac push_decode_step := - first [ rewrite !decode_proj - | rewrite !decode_if_bool - | erewrite !decode_load_immediate by bounded_solver_tac - | erewrite !decode_shift_right_doubleword by bounded_solver_tac - | erewrite !decode_shift_left_immediate by bounded_solver_tac - | erewrite !decode_shift_right_immediate by bounded_solver_tac - | erewrite !decode_fst_spread_left_immediate by bounded_solver_tac - | erewrite !decode_snd_spread_left_immediate by bounded_solver_tac - | erewrite !decode_mask_keep_low by bounded_solver_tac - | erewrite !bit_fst_add_with_carry by bounded_solver_tac - | erewrite !decode_snd_add_with_carry by bounded_solver_tac - | erewrite !fst_sub_with_carry by bounded_solver_tac - | erewrite !decode_snd_sub_with_carry by bounded_solver_tac - | erewrite !decode_mul by bounded_solver_tac - | erewrite !decode_mul_low_low by bounded_solver_tac - | erewrite !decode_mul_high_low by bounded_solver_tac - | erewrite !decode_mul_high_high by bounded_solver_tac - | erewrite !decode_fst_mul_double by bounded_solver_tac - | erewrite !decode_snd_mul_double by bounded_solver_tac - | erewrite !decode_select_conditional by bounded_solver_tac - | erewrite !decode_add_modulo by bounded_solver_tac ]. -Ltac pull_decode_step := - first [ erewrite <- !decode_load_immediate by bounded_solver_tac - | erewrite <- !decode_shift_right_doubleword by bounded_solver_tac - | erewrite <- !decode_shift_left_immediate by bounded_solver_tac - | erewrite <- !decode_shift_right_immediate by bounded_solver_tac - | erewrite <- !decode_fst_spread_left_immediate by bounded_solver_tac - | erewrite <- !decode_snd_spread_left_immediate by bounded_solver_tac - | erewrite <- !decode_mask_keep_low by bounded_solver_tac - | erewrite <- !bit_fst_add_with_carry by bounded_solver_tac - | erewrite <- !decode_snd_add_with_carry by bounded_solver_tac - | erewrite <- !fst_sub_with_carry by bounded_solver_tac - | erewrite <- !decode_snd_sub_with_carry by bounded_solver_tac - | erewrite <- !decode_mul by bounded_solver_tac - | erewrite <- !decode_mul_low_low by bounded_solver_tac - | erewrite <- !decode_mul_high_low by bounded_solver_tac - | erewrite <- !decode_mul_high_high by bounded_solver_tac - | erewrite <- !decode_fst_mul_double by bounded_solver_tac - | erewrite <- !decode_snd_mul_double by bounded_solver_tac - | erewrite <- !decode_select_conditional by bounded_solver_tac - | erewrite <- !decode_add_modulo by bounded_solver_tac ]. -Ltac push_decode := repeat push_decode_step. -Ltac pull_decode := repeat pull_decode_step. - -(* We take special care to handle the case where the decoder is - syntactically different but the decoded expression is judgmentally - the same; we don't want to split apart variables that should be the - same. *) -Ltac set_decode_step check := - match goal with - | [ |- context G[@decode ?n ?W ?dr ?w] ] - => check w; - first [ match goal with - | [ d := @decode _ _ _ w |- _ ] - => change (@decode n W dr w) with d - end - | generalize (@decode_range n W dr _ w); - let d := fresh "d" in - set (d := @decode n W dr w); - intro ] - end. -Ltac set_decode check := repeat set_decode_step check. -Ltac clearbody_decode := - repeat match goal with - | [ H := @decode _ _ _ _ |- _ ] => clearbody H - end. -Ltac generalize_decode_by check := set_decode check; clearbody_decode. -Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac). -Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w). - Module fancy_machine. Local Notation imm := Z (only parsing). diff --git a/src/BoundedArithmetic/InterfaceProofs.v b/src/BoundedArithmetic/InterfaceProofs.v new file mode 100644 index 000000000..b8e20c607 --- /dev/null +++ b/src/BoundedArithmetic/InterfaceProofs.v @@ -0,0 +1,202 @@ +(** * Alternate forms for Interface for bounded arithmetic *) +Require Import Coq.ZArith.ZArith Coq.micromega.Psatz. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.Notations. + +Local Open Scope type_scope. +Local Open Scope Z_scope. + +Local Infix "==" := rewrite_eq. +Local Infix "=~>" := rewrite_left_to_right_eq. +Local Infix "<~=" := rewrite_right_to_left_eq. +Local Notation bit b := (if b then 1 else 0). + +Section InstructionGallery. + Context (n : Z) (* bit-width of width of [W] *) + {W : Type} (* bounded type, [W] for word *) + (Wdecoder : decoder n W). + Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) + + Lemma is_spread_left_immediate_alt + {sprl : spread_left_immediate W} + {isdecode : is_decode Wdecoder} + : is_spread_left_immediate sprl + <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n)). + Proof. + split; intro H; [ | apply Build_is_spread_left_immediate' ]; + intros r count Hc; + [ | specialize (H r count Hc); revert H ]; + unfold bounded_in_range_cls in *; + pose proof (decode_range r); + assert (0 < 2^n) by auto with zarith; + assert (0 <= 2^count < 2^n) by auto with zarith; + assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia); + rewrite ?decode_fst_spread_left_immediate, ?decode_snd_spread_left_immediate + by typeclasses eauto with typeclass_instances core; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. + + Lemma is_mul_double_alt + {muldw : multiply_double W} + {isdecode : is_decode Wdecoder} + : is_mul_double muldw + <-> (forall x y, decode (fst (muldw x y)) + decode (snd (muldw x y)) << n = (decode x * decode y) mod (2^n*2^n)). + Proof. + split; intro H; [ | apply Build_is_mul_double' ]; + intros x y; + [ | specialize (H x y); revert H ]; + pose proof (decode_range x); + pose proof (decode_range y); + assert (0 < 2^n) by auto with zarith; + assert (0 <= decode x * decode y < 2^n * 2^n) by nia; + (destruct (0 <=? n) eqn:?; Z.ltb_to_lt; + [ | assert (2^n = 0) by auto with zarith; exfalso; omega ]); + rewrite ?decode_fst_mul_double, ?decode_snd_mul_double + by typeclasses eauto with typeclass_instances core; + autorewrite with Zshift_to_pow zsimplify push_Zpow. + { reflexivity. } + { intro H'; rewrite <- H'. + autorewrite with zsimplify; split; reflexivity. } + Qed. +End InstructionGallery. + +Local Notation "x <= y < z" := (bounded_in_range_cls x y z). + +Global Arguments is_spread_left_immediate_alt {_ _ _ _ _}. +Global Arguments is_mul_double_alt {_ _ _ _ _}. + +Ltac bounded_solver_tac := + solve [ eassumption | typeclasses eauto | omega ]. + +Global Instance decode_proj n W (dec : W -> Z) + : @decode n W {| decode := dec |} =~> dec. +Proof. reflexivity. Qed. + +Global Instance decode_if_bool n W (decode : decoder n W) (b : bool) x y + : decode (if b then x else y) + =~> if b then decode x else decode y. +Proof. destruct b; reflexivity. Qed. + +Global Instance decode_mod_small {n W} {decode : decoder n W} {x b} + {H : bounded_in_range_cls 0 (decode x) b} + : decode x <~= decode x mod b. +Proof. + Z.rewrite_mod_small; reflexivity. +Qed. + +Global Instance decode_mod_range {n W decode} {H : @is_decode n W decode} x + : decode x <~= decode x mod 2^n. +Proof. exact _. Qed. + +Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} + (isinhabited : W) + : 0 <= n. +Proof. + pose proof (decode_range isinhabited). + assert (0 < 2^n) by omega. + destruct (Z_lt_ge_dec n 0) as [H'|]; [ | omega ]. + assert (2^n = 0) by auto using Z.pow_neg_r. + omega. +Qed. + +Section adc_subc. + Context {n W} + {decode : decoder n W} + {adc : add_with_carry W} + {subc : sub_with_carry W} + {isdecode : is_decode decode} + {isadc : is_add_with_carry adc} + {issubc : is_sub_with_carry subc}. + Global Instance bit_fst_add_with_carry_false + : forall x y, bit (fst (adc x y false)) == (decode x + decode y) >> n. + Proof. + intros; erewrite bit_fst_add_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance bit_fst_add_with_carry_true + : forall x y, bit (fst (adc x y true)) == (decode x + decode y + 1) >> n. + Proof. + intros; erewrite bit_fst_add_with_carry by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_add_with_carry_leb + : forall x y c, fst (adc x y c) <~= (2^n <=? (decode x + decode y + bit c)). + Proof. + intros x y c; hnf. + assert (0 <= n) by eauto using decode_exponent_nonnegative. + pose proof (decode_range x); pose proof (decode_range y). + assert (0 <= bit c <= 1) by (destruct c; omega). + lazymatch goal with + | [ |- fst ?x = (?a <=? ?b) :> bool ] + => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); + [ destruct (fst x), (a <=? b); intro; congruence | ] + end. + push_decode. + autorewrite with Zshift_to_pow. + rewrite Z.div_between_0_if by auto with zarith. + reflexivity. + Qed. + Global Instance fst_add_with_carry_false_leb + : forall x y, fst (adc x y false) <~= (2^n <=? (decode x + decode y)). + Proof. + intros; erewrite fst_add_with_carry_leb by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_add_with_carry_true_leb + : forall x y, fst (adc x y true) == (2^n <=? (decode x + decode y + 1)). + Proof. + intros; erewrite fst_add_with_carry_leb by assumption. + autorewrite with zsimplify_const; reflexivity. + Qed. + Global Instance fst_sub_with_carry_false + : forall x y, fst (subc x y false) == ((decode x - decode y) apply @fst_add_with_carry_false_leb : typeclass_instances. +Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1))) +=> apply @fst_add_with_carry_true_leb : typeclass_instances. +Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _))) +=> apply @fst_add_with_carry_leb : typeclass_instances. + + +(* We take special care to handle the case where the decoder is + syntactically different but the decoded expression is judgmentally + the same; we don't want to split apart variables that should be the + same. *) +Ltac set_decode_step check := + match goal with + | [ |- context G[@decode ?n ?W ?dr ?w] ] + => check w; + first [ match goal with + | [ d := @decode _ _ _ w |- _ ] + => change (@decode n W dr w) with d + end + | generalize (@decode_range n W dr _ w); + let d := fresh "d" in + set (d := @decode n W dr w); + intro ] + end. +Ltac set_decode check := repeat set_decode_step check. +Ltac clearbody_decode := + repeat match goal with + | [ H := @decode _ _ _ _ |- _ ] => clearbody H + end. +Ltac generalize_decode_by check := set_decode check; clearbody_decode. +Ltac generalize_decode := generalize_decode_by ltac:(fun w => idtac). +Ltac generalize_decode_var := generalize_decode_by ltac:(fun w => is_var w). -- cgit v1.2.3 From bf0d9280ebf806eef8ee3280f7976edb3282ae6e Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Thu, 25 Aug 2016 11:48:05 -0700 Subject: Integrate suggestions from Andres --- _CoqProject | 1 + src/BoundedArithmetic/ArchitectureToZLikeProofs.v | 1 - src/BoundedArithmetic/DoubleBounded.v | 11 --- src/BoundedArithmetic/DoubleBoundedProofs.v | 17 ++-- src/BoundedArithmetic/Interface.v | 109 +++++++++------------- src/BoundedArithmetic/InterfaceProofs.v | 59 ++++++++---- src/Util/AutoRewrite.v | 56 +++++++++++ src/Util/Notations.v | 1 + 8 files changed, 147 insertions(+), 108 deletions(-) create mode 100644 src/Util/AutoRewrite.v (limited to 'src') diff --git a/_CoqProject b/_CoqProject index 9ed483860..40804dcd9 100644 --- a/_CoqProject +++ b/_CoqProject @@ -74,6 +74,7 @@ src/Specific/GF25519.v src/Tactics/VerdiTactics.v src/Tactics/Algebra_syntax/Nsatz.v src/Util/AdditionChainExponentiation.v +src/Util/AutoRewrite.v src/Util/Bool.v src/Util/CaseUtil.v src/Util/Decidable.v diff --git a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v index 3060e17bb..804296374 100644 --- a/src/BoundedArithmetic/ArchitectureToZLikeProofs.v +++ b/src/BoundedArithmetic/ArchitectureToZLikeProofs.v @@ -93,7 +93,6 @@ Section fancy_machine_p256_montgomery_foundation. { abstract t. } { abstract t. } { abstract t. } -Hint Resolve Z.div_pos : zarith. { abstract t. } { abstract t. } { abstract t. } diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 55e46aa2b..b624c5082 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -7,7 +7,6 @@ Require Import Crypto.Util.Tuple. Require Import Crypto.Util.ListUtil. Require Import Crypto.Util.Notations. -Local Open Scope list_scope. Local Open Scope nat_scope. Local Open Scope Z_scope. Local Open Scope type_scope. @@ -23,16 +22,6 @@ Global Arguments tuple_decoder : simpl never. Hint Extern 3 (decoder _ (tuple ?W ?k)) => let kv := (eval simpl in (Z.of_nat k)) in apply (fun n decode => (@tuple_decoder n W decode k : decoder (kv * n) (tuple W k))) : typeclass_instances. Section ripple_carry_definitions. - Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) - (xs ys : list T) (carry : bool) : bool * list T - := List.fold_right - (fun x_y carry_zs => let '(x, y) := eta x_y in - let '(carry, zs) := eta carry_zs in - let '(carry, z) := eta (f x y carry) in - (carry, z :: zs)) - (carry, nil) - (List.combine xs ys). - (** tuple is high to low ([to_list] reverses) *) Fixpoint ripple_carry_tuple' {T} (f : T -> T -> bool -> bool * T) k : forall (xs ys : tuple' T k) (carry : bool), bool * tuple' T k diff --git a/src/BoundedArithmetic/DoubleBoundedProofs.v b/src/BoundedArithmetic/DoubleBoundedProofs.v index b69232076..53ac59d00 100644 --- a/src/BoundedArithmetic/DoubleBoundedProofs.v +++ b/src/BoundedArithmetic/DoubleBoundedProofs.v @@ -16,16 +16,15 @@ Require Import Crypto.Util.Notations. Import ListNotations. Local Open Scope list_scope. Local Open Scope nat_scope. -Local Open Scope Z_scope. Local Open Scope type_scope. +Local Open Scope Z_scope. Local Coercion Z.of_nat : nat >-> Z. Local Coercion Pos.to_nat : positive >-> nat. Local Notation eta x := (fst x, snd x). -Local Infix "==" := rewrite_eq. -Local Infix "=~>" := rewrite_left_to_right_eq. -Local Infix "<~=" := rewrite_right_to_left_eq. +Import BoundedRewriteNotations. +Local Open Scope Z_scope. Section decode. Context {n W} {decode : decoder n W}. @@ -99,12 +98,12 @@ Section decode. Global Instance tuple_decoder_m1 w : tuple_decoder (k := 0) w =~> 0. Proof. reflexivity. Qed. - Global Instance tuple_decoder_2' w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z. + Global Instance tuple_decoder_2' w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << (1%nat * n))%Z. Proof. intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. reflexivity. Qed. - Global Instance tuple_decoder_2 w : bounded_le_cls 0 n -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z. + Global Instance tuple_decoder_2 w : (0 <= n)%bounded_rewrite -> tuple_decoder (k := 2) w <~= (decode (fst w) + decode (snd w) << n)%Z. Proof. intros; rewrite !tuple_decoder_S, !tuple_decoder_O by assumption. autorewrite with zsimplify_const; reflexivity. @@ -205,8 +204,8 @@ Global Instance decode_is_spread_left_immediate {isdecode : is_decode decode} {issprl : is_spread_left_immediate sprl} : forall r count, - 0 <= count < n - -> tuple_decoder (sprl r count) == decode r << count + (0 <= count < n)%bounded_rewrite + -> tuple_decoder (sprl r count) <~=~> decode r << count := proj1 decode_is_spread_left_immediate_iff _. Lemma decode_mul_double_iff @@ -233,7 +232,7 @@ Global Instance decode_mul_double {muldw : multiply_double W} {isdecode : is_decode decode} {ismuldw : is_mul_double muldw} - : forall x y, tuple_decoder (muldw x y) == (decode x * decode y)%Z + : forall x y, tuple_decoder (muldw x y) <~=~> (decode x * decode y)%Z := proj1 decode_mul_double_iff _. Lemma ripple_carry_tuple_SS {T} f k xss yss carry diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index 00528a053..152c43cee 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -2,10 +2,11 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Tuple. +Require Import Crypto.Util.AutoRewrite. Require Import Crypto.Util.Notations. -Local Open Scope Z_scope. Local Open Scope type_scope. +Local Open Scope Z_scope. Class decoder (n : Z) W := { decode : W -> Z }. @@ -15,22 +16,6 @@ Global Arguments decode {n W _} _. Class is_decode {n W} (decode : decoder n W) := decode_range : forall x, 0 <= decode x < 2^n. -Class rewrite_eq {A} (x y : A) - := by_rewrite : x = y. -Arguments by_rewrite {A} _ _ {_}. - -Class rewrite_right_to_left_eq {A} (x y : A) - := by_rewrite_right_to_left : rewrite_eq x y. -Arguments by_rewrite_right_to_left {A} _ _ {_}. -Global Instance unfold_rewrite_right_to_left_eq {A x y} (H : @rewrite_eq A x y) - : @rewrite_right_to_left_eq A x y := H. - -Class rewrite_left_to_right_eq {A} (x y : A) - := by_rewrite_left_to_right : rewrite_eq x y. -Arguments by_rewrite_left_to_right {A} _ _ {_}. -Global Instance unfold_rewrite_left_to_right_eq {A x y} (H : @rewrite_eq A x y) - : @rewrite_left_to_right_eq A x y := H. - Class bounded_in_range_cls (x y z : Z) := is_bounded_in_range : x <= y < z. Ltac bounded_solver_tac := solve [ eassumption | typeclasses eauto | omega ]. @@ -44,26 +29,42 @@ Class bounded_le_cls (x y : Z) := is_bounded_le : x <= y. Hint Extern 0 (bounded_le_cls _ _) => unfold bounded_le_cls; bounded_solver_tac : typeclass_instances. Global Arguments bounded_le_cls / . +Inductive bounded_decode_pusher_tag := decode_tag. + Ltac push_decode_step := match goal with | [ |- context[@decode ?n ?W ?decoder ?w] ] - => let lem := constr:(by_rewrite_left_to_right (A := Z) (@decode n W decoder w) _) in - rewrite (lem : @decode n W decoder w = _) + => tc_rewrite (decode_tag) (@decode n W decoder w) -> | [ |- context[match @fst ?A ?B ?x with true => 1 | false => 0 end] ] - => let lem := constr:(by_rewrite_left_to_right (A := Z) (match @fst A B x with true => 1 | false => 0 end) _) in - rewrite (lem : _ = _) + => tc_rewrite (decode_tag) (match @fst A B x with true => 1 | false => 0 end) -> end. Ltac push_decode := repeat push_decode_step. Ltac pull_decode_step := match goal with | [ |- context[?E] ] - => first [ let lem := constr:(by_rewrite_right_to_left (A := Z) _ E) in - rewrite <- (lem : _ = E) - | let lem := constr:(by_rewrite_right_to_left (A := bool) _ E) in - rewrite <- (lem : _ = E) ] + => lazymatch type of E with + | Z => idtac + | bool => idtac + end; + tc_rewrite (decode_tag) <- E end. Ltac pull_decode := repeat pull_decode_step. +Delimit Scope bounded_rewrite_scope with bounded_rewrite. + +Infix "<~=~>" := (rewrite_eq decode_tag) : bounded_rewrite_scope. +Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : bounded_rewrite_scope. +Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : bounded_rewrite_scope. +Notation "x <= y" := (bounded_le_cls x y) : bounded_rewrite_scope. +Notation "x <= y < z" := (bounded_in_range_cls x y z) : bounded_rewrite_scope. + +Module Import BoundedRewriteNotations. + Infix "<~=~>" := (rewrite_eq decode_tag) : type_scope. + Infix "=~>" := (rewrite_left_to_right_eq decode_tag) : type_scope. + Infix "<~=" := (rewrite_right_to_left_eq decode_tag) : type_scope. + Open Scope bounded_rewrite_scope. +End BoundedRewriteNotations. + (** This is required for typeclass resolution to be fast. *) Typeclasses Opaque decode. @@ -71,10 +72,6 @@ Section InstructionGallery. Context (n : Z) (* bit-width of width of [W] *) {W : Type} (* bounded type, [W] for word *) (Wdecoder : decoder n W). - Local Infix "==" := rewrite_eq. - Local Infix "=~>" := rewrite_left_to_right_eq. - Local Infix "<~=" := rewrite_right_to_left_eq. - Local Notation "x <= y < z" := (bounded_in_range_cls x y z). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) Class load_immediate := { ldi : imm -> W }. @@ -90,21 +87,21 @@ Section InstructionGallery. decode_shift_right_doubleword :> forall high low count, 0 <= count < n - -> decode (shrd high low count) == (((decode high << n) + decode low) >> count) mod 2^n. + -> decode (shrd high low count) <~=~> (((decode high << n) + decode low) >> count) mod 2^n. Class shift_left_immediate := { shl : W -> imm -> W }. Global Coercion shl : shift_left_immediate >-> Funclass. Class is_shift_left_immediate (shl : shift_left_immediate) := decode_shift_left_immediate :> - forall r count, 0 <= count < n -> decode (shl r count) == (decode r << count) mod 2^n. + forall r count, 0 <= count < n -> decode (shl r count) <~=~> (decode r << count) mod 2^n. Class shift_right_immediate := { shr : W -> imm -> W }. Global Coercion shr : shift_right_immediate >-> Funclass. Class is_shift_right_immediate (shr : shift_right_immediate) := decode_shift_right_immediate :> - forall r count, 0 <= count < n -> decode (shr r count) == (decode r >> count). + forall r count, 0 <= count < n -> decode (shr r count) <~=~> (decode r >> count). Class spread_left_immediate := { sprl : W -> imm -> tuple W 2 (* [(low, high)] *) }. Global Coercion sprl : spread_left_immediate >-> Funclass. @@ -120,19 +117,12 @@ Section InstructionGallery. }. - Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate) - (pf : forall r count, 0 <= count < n - -> decode (fst (sprl r count)) = (decode r << count) mod 2^n - /\ decode (snd (sprl r count)) = (decode r << count) >> n) - := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); - decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. - Class mask_keep_low := { mkl :> W -> imm -> W }. Global Coercion mkl : mask_keep_low >-> Funclass. Class is_mask_keep_low (mkl : mask_keep_low) := decode_mask_keep_low :> forall r count, - 0 <= count < n -> decode (mkl r count) == decode r mod 2^count. + 0 <= count < n -> decode (mkl r count) <~=~> decode r mod 2^count. Local Notation bit b := (if b then 1 else 0). @@ -141,34 +131,24 @@ Section InstructionGallery. Class is_add_with_carry (adc : add_with_carry) := { - bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) == (decode x + decode y + bit c) >> n; - decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) == (decode x + decode y + bit c) mod (2^n) + bit_fst_add_with_carry :> forall x y c, bit (fst (adc x y c)) <~=~> (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry :> forall x y c, decode (snd (adc x y c)) <~=~> (decode x + decode y + bit c) mod (2^n) }. - Definition Build_is_add_with_carry' (adc : add_with_carry) - (pf : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n /\ decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n)) - := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); - decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. - Class sub_with_carry := { subc : W -> W -> bool -> bool * W }. Global Coercion subc : sub_with_carry >-> Funclass. Class is_sub_with_carry (subc:W->W->bool->bool*W) := { - fst_sub_with_carry :> forall x y c, fst (subc x y c) == ((decode x - decode y - bit c) forall x y c, decode (snd (subc x y c)) == (decode x - decode y - bit c) mod 2^n + fst_sub_with_carry :> forall x y c, fst (subc x y c) <~=~> ((decode x - decode y - bit c) forall x y c, decode (snd (subc x y c)) <~=~> (decode x - decode y - bit c) mod 2^n }. - Definition Build_is_sub_with_carry' (subc : sub_with_carry) - (pf : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) W -> W }. Global Coercion mul : multiply >-> Funclass. Class is_mul (mul : multiply) := - decode_mul :> forall x y, decode (mul x y) == (decode x * decode y)%Z. + decode_mul :> forall x y, decode (mul x y) <~=~> (decode x * decode y). Class multiply_low_low := { mulhwll : W -> W -> W }. Global Coercion mulhwll : multiply_low_low >-> Funclass. @@ -181,13 +161,13 @@ Section InstructionGallery. Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := decode_mul_low_low :> - forall x y, decode (mulhwll x y) == ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. + forall x y, decode (mulhwll x y) <~=~> ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := decode_mul_high_low :> - forall x y, decode (mulhwhl x y) == ((decode x >> w) * (decode y mod 2^w)) mod 2^n. + forall x y, decode (mulhwhl x y) <~=~> ((decode x >> w) * (decode y mod 2^w)) mod 2^n. Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := decode_mul_high_high :> - forall x y, decode (mulhwhh x y) == ((decode x >> w) * (decode y >> w)) mod 2^n. + forall x y, decode (mulhwhh x y) <~=~> ((decode x >> w) * (decode y >> w)) mod 2^n. Class is_mul_double (muldw : multiply_double) := { decode_fst_mul_double :> @@ -195,27 +175,22 @@ Section InstructionGallery. decode_snd_mul_double :> forall x y, decode (snd (muldw x y)) =~> (decode x * decode y) >> n }. - Definition Build_is_mul_double' (muldw : multiply_double) - (pf : forall x y, _ /\ _) - := {| decode_fst_mul_double x y := proj1 (pf x y); - decode_snd_mul_double x y := proj2 (pf x y) |}. - Class select_conditional := { selc : bool -> W -> W -> W }. Global Coercion selc : select_conditional >-> Funclass. Class is_select_conditional (selc : select_conditional) := decode_select_conditional :> forall b x y, - decode (selc b x y) == if b then decode x else decode y. + decode (selc b x y) <~=~> if b then decode x else decode y. Class add_modulo := { addm : W -> W -> W (* modulus *) -> W }. Global Coercion addm : add_modulo >-> Funclass. Class is_add_modulo (addm : add_modulo) := decode_add_modulo :> forall x y modulus, - decode (addm x y modulus) == (if (decode x + decode y) (if (decode x + decode y) " := rewrite_left_to_right_eq. -Local Infix "<~=" := rewrite_right_to_left_eq. +Import BoundedRewriteNotations. Local Notation bit b := (if b then 1 else 0). Section InstructionGallery. @@ -19,11 +18,33 @@ Section InstructionGallery. (Wdecoder : decoder n W). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) + Definition Build_is_spread_left_immediate' (sprl : spread_left_immediate W) + (pf : forall r count, 0 <= count < n + -> _ /\ _) + := {| decode_fst_spread_left_immediate r count H := proj1 (pf r count H); + decode_snd_spread_left_immediate r count H := proj2 (pf r count H) |}. + + Definition Build_is_add_with_carry' (adc : add_with_carry W) + (pf : forall x y c, _ /\ _) + := {| bit_fst_add_with_carry x y c := proj1 (pf x y c); + decode_snd_add_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_sub_with_carry' (subc : sub_with_carry W) + (pf : forall x y c, _ /\ _) + : is_sub_with_carry subc + := {| fst_sub_with_carry x y c := proj1 (pf x y c); + decode_snd_sub_with_carry x y c := proj2 (pf x y c) |}. + + Definition Build_is_mul_double' (muldw : multiply_double W) + (pf : forall x y, _ /\ _) + := {| decode_fst_mul_double x y := proj1 (pf x y); + decode_snd_mul_double x y := proj2 (pf x y) |}. + Lemma is_spread_left_immediate_alt {sprl : spread_left_immediate W} {isdecode : is_decode Wdecoder} : is_spread_left_immediate sprl - <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n)). + <-> (forall r count, 0 <= count < n -> decode (fst (sprl r count)) + decode (snd (sprl r count)) << n = (decode r << count) mod (2^n*2^n))%Z. Proof. split; intro H; [ | apply Build_is_spread_left_immediate' ]; intros r count Hc; @@ -31,8 +52,8 @@ Section InstructionGallery. unfold bounded_in_range_cls in *; pose proof (decode_range r); assert (0 < 2^n) by auto with zarith; - assert (0 <= 2^count < 2^n) by auto with zarith; - assert (0 <= decode r * 2^count < 2^n * 2^n) by (generalize dependent (decode r); intros; nia); + assert (0 <= 2^count < 2^n)%Z by auto with zarith; + assert (0 <= decode r * 2^count < 2^n * 2^n)%Z by (generalize dependent (decode r); intros; nia); rewrite ?decode_fst_spread_left_immediate, ?decode_snd_spread_left_immediate by typeclasses eauto with typeclass_instances core; autorewrite with Zshift_to_pow zsimplify push_Zpow. @@ -53,7 +74,7 @@ Section InstructionGallery. pose proof (decode_range x); pose proof (decode_range y); assert (0 < 2^n) by auto with zarith; - assert (0 <= decode x * decode y < 2^n * 2^n) by nia; + assert (0 <= decode x * decode y < 2^n * 2^n)%Z by nia; (destruct (0 <=? n) eqn:?; Z.ltb_to_lt; [ | assert (2^n = 0) by auto with zarith; exfalso; omega ]); rewrite ?decode_fst_mul_double, ?decode_snd_mul_double @@ -65,8 +86,6 @@ Section InstructionGallery. Qed. End InstructionGallery. -Local Notation "x <= y < z" := (bounded_in_range_cls x y z). - Global Arguments is_spread_left_immediate_alt {_ _ _ _ _}. Global Arguments is_mul_double_alt {_ _ _ _ _}. @@ -95,7 +114,7 @@ Proof. exact _. Qed. Lemma decode_exponent_nonnegative {n W} (decode : decoder n W) {isdecode : is_decode decode} (isinhabited : W) - : 0 <= n. + : (0 <= n)%Z. Proof. pose proof (decode_range isinhabited). assert (0 < 2^n) by omega. @@ -113,13 +132,13 @@ Section adc_subc. {isadc : is_add_with_carry adc} {issubc : is_sub_with_carry subc}. Global Instance bit_fst_add_with_carry_false - : forall x y, bit (fst (adc x y false)) == (decode x + decode y) >> n. + : forall x y, bit (fst (adc x y false)) <~=~> (decode x + decode y) >> n. Proof. intros; erewrite bit_fst_add_with_carry by assumption. autorewrite with zsimplify_const; reflexivity. Qed. Global Instance bit_fst_add_with_carry_true - : forall x y, bit (fst (adc x y true)) == (decode x + decode y + 1) >> n. + : forall x y, bit (fst (adc x y true)) <~=~> (decode x + decode y + 1) >> n. Proof. intros; erewrite bit_fst_add_with_carry by assumption. autorewrite with zsimplify_const; reflexivity. @@ -128,9 +147,9 @@ Section adc_subc. : forall x y c, fst (adc x y c) <~= (2^n <=? (decode x + decode y + bit c)). Proof. intros x y c; hnf. - assert (0 <= n) by eauto using decode_exponent_nonnegative. + assert (0 <= n)%Z by eauto using decode_exponent_nonnegative. pose proof (decode_range x); pose proof (decode_range y). - assert (0 <= bit c <= 1) by (destruct c; omega). + assert (0 <= bit c <= 1)%Z by (destruct c; omega). lazymatch goal with | [ |- fst ?x = (?a <=? ?b) :> bool ] => cut (((if fst x then 1 else 0) = (if a <=? b then 1 else 0))%Z); @@ -148,30 +167,30 @@ Section adc_subc. autorewrite with zsimplify_const; reflexivity. Qed. Global Instance fst_add_with_carry_true_leb - : forall x y, fst (adc x y true) == (2^n <=? (decode x + decode y + 1)). + : forall x y, fst (adc x y true) <~=~> (2^n <=? (decode x + decode y + 1)). Proof. intros; erewrite fst_add_with_carry_leb by assumption. autorewrite with zsimplify_const; reflexivity. Qed. Global Instance fst_sub_with_carry_false - : forall x y, fst (subc x y false) == ((decode x - decode y) ((decode x - decode y) ((decode x - decode y - 1) apply @fst_add_with_carry_false_leb : typeclass_instances. -Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1))) +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + 1))) => apply @fst_add_with_carry_true_leb : typeclass_instances. -Hint Extern 2 (rewrite_right_to_left_eq _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _))) +Hint Extern 2 (rewrite_right_to_left_eq decode_tag _ (_ <=? (@decode ?n ?W ?decoder ?x + @decode _ _ _ ?y + if ?c then _ else _))) => apply @fst_add_with_carry_leb : typeclass_instances. diff --git a/src/Util/AutoRewrite.v b/src/Util/AutoRewrite.v new file mode 100644 index 000000000..b5e276f20 --- /dev/null +++ b/src/Util/AutoRewrite.v @@ -0,0 +1,56 @@ +(** * Machinery for reimplementing some bits of [rewrite_strat] with open pattern databases *) +Require Import Crypto.Util.Notations. +(** We build classes for rewriting in each direction, and pick lemmas + by resolving on tags and the term to be rewritten. *) +(** Base class, for bidirectional rewriting. *) +Class rewrite_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite : x = y. +Arguments by_rewrite {tagT} tag {A} _ _ {_}. +Infix "<~=~>" := (rewrite_eq _) : type_scope. + +Class rewrite_right_to_left_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite_right_to_left : rewrite_eq tag x y. +Arguments by_rewrite_right_to_left {tagT} tag {A} _ _ {_}. +Global Instance unfold_rewrite_right_to_left_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y) + : @rewrite_right_to_left_eq tagT tag A x y := H. +Infix "<~=" := (rewrite_right_to_left_eq _) : type_scope. + +Class rewrite_left_to_right_eq {tagT} (tag : tagT) {A} (x y : A) + := by_rewrite_left_to_right : rewrite_eq tag x y. +Arguments by_rewrite_left_to_right {tagT} tag {A} _ _ {_}. +Global Instance unfold_rewrite_left_to_right_eq {tagT tag A x y} (H : @rewrite_eq tagT tag A x y) + : @rewrite_left_to_right_eq tagT tag A x y := H. +Infix "=~>" := (rewrite_left_to_right_eq _) : type_scope. + +Ltac typeclass_do_left_to_right tag from tac := + let lem := constr:(by_rewrite_left_to_right tag from _ : from = _) in tac lem. +Ltac typeclass_do_right_to_left tag from tac := + let lem := constr:(by_rewrite_right_to_left tag _ from : _ = from) in tac lem. + + +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" hyp_list(H) "|-" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in H |- *). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" "*" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" open_constr(from) "->" "in" "*" "|-" := + typeclass_do_left_to_right tag from ltac:(fun lem => rewrite -> lem in * |- ). + + +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" hyp_list(H) "|-" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in H |- *). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" "*" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- * ). +Tactic Notation "tc_rewrite" "(" open_constr(tag) ")" "<-" open_constr(from) "in" "*" "|-" := + typeclass_do_right_to_left tag from ltac:(fun lem => rewrite <- lem in * |- ). diff --git a/src/Util/Notations.v b/src/Util/Notations.v index 686443829..8795da399 100644 --- a/src/Util/Notations.v +++ b/src/Util/Notations.v @@ -27,6 +27,7 @@ Reserved Infix "~=" (at level 70). Reserved Infix "==" (at level 70, no associativity). Reserved Infix "=~>" (at level 70, no associativity). Reserved Infix "<~=" (at level 70, no associativity). +Reserved Infix "<~=~>" (at level 70, no associativity). Reserved Infix "≡" (at level 70, no associativity). Reserved Infix "≢" (at level 70, no associativity). Reserved Infix "≡_n" (at level 70, no associativity). -- cgit v1.2.3