From b5b1eebe2845b0e69d669b51cea9eeb67ea726e3 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Thu, 11 Aug 2016 15:41:00 -0700 Subject: Initial work on an architecture interface for ℤ/nℤ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This provides a cleaner interface for the bottom level implementation, as well as an implementation of multiplying 128x128 -> 256. --- src/BoundedArithmetic/ArchitectureToZLike.v | 38 ++++++++ src/BoundedArithmetic/DoubleBounded.v | 114 ++++++++++++++++++++++++ src/BoundedArithmetic/Interface.v | 131 ++++++++++++++++++++++++++++ 3 files changed, 283 insertions(+) create mode 100644 src/BoundedArithmetic/ArchitectureToZLike.v create mode 100644 src/BoundedArithmetic/DoubleBounded.v create mode 100644 src/BoundedArithmetic/Interface.v (limited to 'src') diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v new file mode 100644 index 000000000..6c92f342f --- /dev/null +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -0,0 +1,38 @@ +(*** Implementing ℤ-Like via Architecture *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BoundedArithmetic.DoubleBounded. +Require Import Crypto.ModularArithmetic.ZBounded. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. + +Local Existing Instances DoubleArchitectureBoundedOps DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps. + +Section ops. + Context {n_over_two : size}. + Local Notation n := (2 * n_over_two)%nat. + Context (ops : ArchitectureBoundedOps n) (mops : ArchitectureBoundedHalfWidthMulOps n) + (modulus : Z). + + Axiom admit : forall {T}, T. + + Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : size) + : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := + { LargeT := @BoundedType (2 * n)%nat _; + SmallT := @BoundedType n _; + modulus_digits := encode modulus; + decode_large := decode; + decode_small := decode; + Mod_SmallBound v := snd v; + DivBy_SmallBound v := fst v; + DivBy_SmallerBound v := ShiftRight smaller_bound_exp v; + Mul x y := @Interface.Mul n _ _ x y; + CarryAdd x y := Interface.CarryAdd false x y; + CarrySubSmall x y := Interface.CarrySub false x y; + ConditionalSubtract b x := admit; + ConditionalSubtractModulus y := admit }. +End ops. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v new file mode 100644 index 000000000..59d961d4a --- /dev/null +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -0,0 +1,114 @@ +(*** Implementing Large Bounded Arithmetic via pairs *) +Require Import Coq.ZArith.ZArith Coq.Lists.List. +Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Notations. + +Import ListNotations. +Local Open Scope list_scope. +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Local Coercion Z.of_nat : nat >-> Z. +Local Notation eta x := (fst x, snd x). + +Section with_single_ops. + Section single. + Context (n : size) {base_ops : ArchitectureBoundedOps n}. + + Definition ripple_carry {T} (f : bool -> T -> T -> bool * T) + (carry : bool) (xs ys : list T) : bool * list T + := List.fold_right + (fun x_y carry_zs => let '(x, y) := eta x_y in + let '(carry, zs) := eta carry_zs in + let '(carry, z) := eta (f carry x y) in + (carry, z :: zs)) + (carry, nil) + (List.combine xs ys). + + Definition two_list_to_tuple {A B} (x : A * list B) + := match x return match x with + | (a, [b0; b1]) => A * (B * B) + | _ => True + end + with + | (a, [b0; b1]) => (a, (b0, b1)) + | _ => I + end. + + Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat + := { BoundedType := BoundedType * BoundedType (* [(high, low)] *); + decode high_low := (decode (fst high_low) * 2^n + decode (snd high_low))%Z; + encode z := (encode (z / 2^n), encode (z mod 2^n)); + ShiftRight a high_low + := let '(high, low) := eta high_low in + if n <=? a then + (ShiftRight (a - n)%nat (encode 0, fst high), ShiftRight (a - n)%nat high) + else + (ShiftRight a (snd high, fst low), ShiftRight a low); + ShiftLeft a high_low + := let '(high, low) := eta high_low in + if 2 * n <=? a then + let '(high0, low) := eta (ShiftLeft (a - 2 * n)%nat low) in + let '(high_high, high1) := eta (ShiftLeft (a - 2 * n)%nat high) in + ((snd (CarryAdd false high0 high1), low), (encode 0, encode 0)) + else if n <=? a then + let '(high0, low) := eta (ShiftLeft (a - n)%nat low) in + let '(high_high, high1) := eta (ShiftLeft (a - n)%nat high) in + ((high_high, snd (CarryAdd false high0 high1)), (low, encode 0)) + else + let '(high0, low) := eta (ShiftLeft a low) in + let '(high_high, high1) := eta (ShiftLeft a high) in + ((encode 0, high_high), (snd (CarryAdd false high0 high1), low)); + Mod2Pow a high_low + := let '(high, low) := (fst high_low, snd high_low) in + (Mod2Pow (a - n)%nat high, Mod2Pow a low); + CarryAdd carry x_high_low y_high_low + := let '(xhigh, xlow) := eta x_high_low in + let '(yhigh, ylow) := eta y_high_low in + two_list_to_tuple (ripple_carry CarryAdd carry [xhigh; xlow] [yhigh; ylow]); + CarrySub carry x_high_low y_high_low + := let '(xhigh, xlow) := eta x_high_low in + let '(yhigh, ylow) := eta y_high_low in + two_list_to_tuple (ripple_carry CarrySub carry [xhigh; xlow] [yhigh; ylow]) }. + + Definition BoundedOfHalfBounded (x : @BoundedHalfType (2 * n)%nat _) : @BoundedType n _ + := match x with + | UpperHalf x => fst x + | LowerHalf x => snd x + end. + + Local Instance DoubleArchitectureBoundedHalfWidthMulOpsOfFullMulOps + {base_mops : ArchitectureBoundedFullMulOps n} + : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat := + { HalfWidthMul a b + := Mul (BoundedOfHalfBounded a) (BoundedOfHalfBounded b) }. + End single. + + Local Existing Instance DoubleArchitectureBoundedOps. + + Section full_from_half. + Context (n : size) {base_ops : ArchitectureBoundedOps (2 * n)%nat}. + + Local Infix "*" := HalfWidthMul. + + Local Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps + {base_mops : ArchitectureBoundedHalfWidthMulOps (2 * n)%nat} + : ArchitectureBoundedFullMulOps (2 * n)%nat := + { Mul a b + := let '(a1, a0) := (UpperHalf a, LowerHalf a) in + let '(b1, b0) := (UpperHalf b, LowerHalf b) in + let out := a0 * b0 in + let outHigh := a1 * b1 in + let tmp := a1 * b0 in + let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in + let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in + let tmp := a0 * b1 in + let '(carry, out) := eta (CarryAdd false out (snd (ShiftLeft n tmp))) in + let '(_, outHigh) := eta (CarryAdd carry outHigh (ShiftRight n (encode 0, tmp))) in + (outHigh, out) }. + End full_from_half. + + Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. +End with_single_ops. diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v new file mode 100644 index 000000000..e9ade7ad4 --- /dev/null +++ b/src/BoundedArithmetic/Interface.v @@ -0,0 +1,131 @@ +(*** Interface for bounded arithmetic *) +Require Import Coq.ZArith.ZArith. +Require Import Crypto.Util.ZUtil. +Require Import Crypto.Util.Notations. + +Local Open Scope nat_scope. +Local Open Scope Z_scope. +Local Open Scope type_scope. + +Definition size := nat. + +Local Coercion Z.of_nat : nat >-> Z. + +Class ArchitectureBoundedOps (n : size) := + { BoundedType : Type (* [n]-bit word *); + decode : BoundedType -> Z; + encode : Z -> BoundedType; + ShiftRight : forall a : size, BoundedType * BoundedType -> BoundedType; + (** given [(high, low)], constructs [(high << (n - a)) + (low >> + a)], i.e., shifts [high * 2ⁿ + low] down by [a] bits *) + ShiftLeft : forall a : size, BoundedType -> BoundedType * BoundedType; + (** given [x], constructs [(((x << a) / 2ⁿ) mod 2ⁿ, (x << a) mod + 2ⁿ], i.e., shifts [x] up by [a] bits, and takes the low [2n] + bits of the result *) + Mod2Pow : forall a : size, BoundedType -> BoundedType (* [mod 2ᵃ] *); + CarryAdd : forall (carry : bool) (x y : BoundedType), bool * BoundedType; + (** Ouputs [(x + y + if carry then 1 else 0) mod 2ⁿ], together + with a boolean that's [true] if the sum is ≥ 2ⁿ, and [false] + if there is no carry *) + CarrySub : forall (carry : bool) (x y : BoundedType), bool * BoundedType; + (** Ouputs [(x - y - if carry then 1 else 0) mod 2ⁿ], together + with a boolean that's [true] if the sum is negative, and [false] + if there is no borrow *) }. + +Inductive BoundedHalfType {n} {ops : ArchitectureBoundedOps n} := +| UpperHalf (_ : BoundedType) +| LowerHalf (_ : BoundedType). + +Definition UnderlyingBounded {n} {ops : ArchitectureBoundedOps n} (x : BoundedHalfType) + := match x with + | UpperHalf v => v + | LowerHalf v => v + end. + +Definition decode_half {n_over_two : size} {ops : ArchitectureBoundedOps (2 * n_over_two)%nat} (x : BoundedHalfType) : Z + := match x with + | UpperHalf v => decode v / 2^n_over_two + | LowerHalf v => (decode v) mod 2^n_over_two + end. + +Class ArchitectureBoundedFullMulOps n {ops : ArchitectureBoundedOps n} := + { Mul : BoundedType -> BoundedType -> BoundedType * BoundedType + (** Outputs [(high, low)] *) }. +Class ArchitectureBoundedHalfWidthMulOps n {ops : ArchitectureBoundedOps n} := + { HalfWidthMul : BoundedHalfType -> BoundedHalfType -> BoundedType }. + +Class ArchitectureBoundedProperties {n} (ops : ArchitectureBoundedOps n) := + { bounded_valid : BoundedType -> Prop; + decode_valid : forall v, + bounded_valid v + -> 0 <= decode v < 2^n; + encode_valid : forall z, + 0 <= z < 2^n + -> bounded_valid (encode z); + encode_correct : forall z, + 0 <= z < 2^n + -> decode (encode z) = z; + ShiftRight_valid : forall a high_low, + bounded_valid (fst high_low) -> bounded_valid (snd high_low) + -> bounded_valid (ShiftRight a high_low); + ShiftRight_correct : forall a high_low, + bounded_valid (fst high_low) -> bounded_valid (snd high_low) + -> decode (ShiftRight a high_low) = (decode (fst high_low) * 2^n + decode (snd high_low)) / 2^a; + ShiftLeft_fst_valid : forall a v, + bounded_valid v + -> bounded_valid (fst (ShiftLeft a v)); + ShiftLeft_snd_valid : forall a v, + bounded_valid v + -> bounded_valid (snd (ShiftLeft a v)); + ShiftLeft_fst_correct : forall a v, + bounded_valid v + -> decode (fst (ShiftLeft a v)) = (decode v * 2^a) mod 2^n; + ShiftLeft_snd_correct : forall a v, + bounded_valid v + -> decode (snd (ShiftLeft a v)) = ((decode v * 2^a) / 2^n) mod 2^n; + Mod2Pow_valid : forall a v, + bounded_valid v + -> bounded_valid (Mod2Pow a v); + Mod2Pow_correct : forall a v, + bounded_valid v + -> decode (Mod2Pow a v) = (decode v) mod 2^a; + CarryAdd_valid : forall c x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (snd (CarryAdd c x y)); + CarryAdd_fst_correct : forall c x y, + bounded_valid x -> bounded_valid y + -> fst (CarryAdd c x y) = (2^n <=? (decode x + decode y + if c then 1 else 0)); + CarryAdd_snd_correct : forall c x y, + bounded_valid x -> bounded_valid y + -> decode (snd (CarryAdd c x y)) = (decode x + decode y + if c then 1 else 0) mod 2^n; + CarrySub_valid : forall c x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (snd (CarrySub c x y)); + CarrySub_fst_correct : forall c x y, + bounded_valid x -> bounded_valid y + -> fst (CarrySub c x y) = ((decode x - decode y - if c then 1 else 0) bounded_valid y + -> decode (snd (CarrySub c x y)) = (decode x - decode y - if c then 1 else 0) mod 2^n }. + +Class ArchitectureBoundedFullMulProperties {n ops} (mops : @ArchitectureBoundedFullMulOps n ops) {props : ArchitectureBoundedProperties ops} := + { Mul_fst_valid : forall x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (fst (Mul x y)); + Mul_snd_valid : forall x y, + bounded_valid x -> bounded_valid y + -> bounded_valid (snd (Mul x y)); + Mul_high_correct : forall x y, + bounded_valid x -> bounded_valid y + -> decode (fst (Mul x y)) = (decode x * decode y) / 2^n; + Mul_low_correct : forall x y, + bounded_valid x -> bounded_valid y + -> decode (snd (Mul x y)) = (decode x * decode y) mod 2^n }. + +Class ArchitectureBoundedHalfWidthMulProperties {n_over_two ops} (mops : @ArchitectureBoundedHalfWidthMulOps (2 * n_over_two)%nat ops) {props : ArchitectureBoundedProperties ops} := + { HalfWidthMul_valid : forall x y, + bounded_valid (UnderlyingBounded x) -> bounded_valid (UnderlyingBounded y) + -> bounded_valid (HalfWidthMul x y); + HalfWidthMul_correct : forall x y, + bounded_valid (UnderlyingBounded x) -> bounded_valid (UnderlyingBounded y) + -> decode (HalfWidthMul x y) = (decode_half x * decode_half y)%Z }. -- cgit v1.2.3