diff options
author | Jason Gross <jagro@google.com> | 2016-08-12 11:45:08 -0700 |
---|---|---|
committer | Jason Gross <jagro@google.com> | 2016-08-23 16:01:45 -0700 |
commit | dc295c74a191d2ad9ab56a4792391a4c68a42e5d (patch) | |
tree | 5002f7435b5978aaa64d62b6912dcb3f41d2c92d /src/BoundedArithmetic | |
parent | 07b18ae2cb1122f395bffdf706ad37248bc5d4dc (diff) |
Rework interface to support rewriting database
Diffstat (limited to 'src/BoundedArithmetic')
-rw-r--r-- | src/BoundedArithmetic/ArchitectureToZLike.v | 35 | ||||
-rw-r--r-- | src/BoundedArithmetic/DoubleBounded.v | 14 | ||||
-rw-r--r-- | src/BoundedArithmetic/Interface.v | 240 |
3 files changed, 222 insertions, 67 deletions
diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v index ebb21bd4b..01387e969 100644 --- a/src/BoundedArithmetic/ArchitectureToZLike.v +++ b/src/BoundedArithmetic/ArchitectureToZLike.v @@ -13,8 +13,8 @@ Local Open Scope type_scope. Local Coercion Z.of_nat : nat >-> Z. Section fancy_machine_p256_montgomery_foundation. - Context {n_over_two : size}. - Local Notation n := (2 * n_over_two)%nat. + Context {n_over_two : Z}. + Local Notation n := (2 * n_over_two)%Z. Context (ops : fancy_machine.instructions n) (modulus : Z). Definition two_list_to_tuple {A B} (x : A * list B) @@ -26,7 +26,7 @@ Section fancy_machine_p256_montgomery_foundation. | (a, [b0; b1]) => (a, (b0, b1)) | _ => I end. - +(* (* make all machine-specific constructions here, preferrably as thing wrappers around generic constructions *) Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat @@ -103,24 +103,23 @@ Section fancy_machine_p256_montgomery_foundation. End full_from_half. Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps. - +*) Axiom admit : forall {T}, T. - Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : size) + Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z) : ZLikeOps (2^n) (2^smaller_bound_exp) modulus := - { LargeT := @BoundedType (2 * n)%nat _; - SmallT := @BoundedType n _; - modulus_digits := encode modulus; - decode_large := decode; + { LargeT := fancy_machine.W * fancy_machine.W; + SmallT := fancy_machine.W; + modulus_digits := ldi modulus; + decode_large := _; decode_small := decode; Mod_SmallBound v := snd v; DivBy_SmallBound v := fst v; - DivBy_SmallerBound v := ShiftRight smaller_bound_exp v; - Mul x y := @Interface.Mul n _ _ x y; - CarryAdd x y := Interface.CarryAdd false x y; - CarrySubSmall x y := Interface.CarrySub false x y; - ConditionalSubtract b x := admit; - ConditionalSubtractModulus y := admit }. -End ops. - -End with_single_ops.
\ No newline at end of file + DivBy_SmallerBound v := shrd (fst v) (snd v) smaller_bound_exp; + Mul x y := _ (*mulhwll (ldi 0, x) (ldi 0, y)*); + CarryAdd x y := _ (*adc x y false*); + CarrySubSmall x y := subc x y false; + ConditionalSubtract b x := let v := selc b (ldi 0) (ldi modulus) in snd (subc x v false); + ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }. + Abort. +End fancy_machine_p256_montgomery_foundation. diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v index 0d9c8e860..c09006e2b 100644 --- a/src/BoundedArithmetic/DoubleBounded.v +++ b/src/BoundedArithmetic/DoubleBounded.v @@ -1,6 +1,8 @@ (*** Implementing Large Bounded Arithmetic via pairs *) Require Import Coq.ZArith.ZArith Coq.Lists.List. Require Import Crypto.BoundedArithmetic.Interface. +Require Import Crypto.BaseSystem. +Require Import Crypto.BaseSystemProofs. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. @@ -15,7 +17,7 @@ Local Notation eta x := (fst x, snd x). Section generic_constructions. Definition ripple_carry {T} (f : T -> T -> bool -> bool * T) - (carry : bool) (xs ys : list T) : bool * list T + (xs ys : list T) (carry : bool) : bool * list T := List.fold_right (fun x_y carry_zs => let '(x, y) := eta x_y in let '(carry, zs) := eta carry_zs in @@ -24,6 +26,16 @@ Section generic_constructions. (carry, nil) (List.combine xs ys). + Section ripple_carry_adc. + Context {n W} {decode : decoder n W} (adc : add_with_carry W). + + Global Instance ripple_carry_add_with_carry : add_with_carry (list W) + := {| Interface.adc := ripple_carry adc |}. + (* + Global Instance ripple_carry_is_add_with_carry {is_adc : is_add_with_carry adc} + : is_add_with_carry ripple_carry_add_with_carry.*) + End ripple_carry_adc. + (* TODO: Would it made sense to make generic-width shift operations here? *) (* FUTURE: here go proofs about [ripple_carry] with [f] that satisfies [is_add_with_carry] *) diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v index f7906ccb5..cda72967c 100644 --- a/src/BoundedArithmetic/Interface.v +++ b/src/BoundedArithmetic/Interface.v @@ -3,100 +3,244 @@ Require Import Coq.ZArith.ZArith. Require Import Crypto.Util.ZUtil. Require Import Crypto.Util.Notations. -Local Open Scope nat_scope. Local Open Scope Z_scope. Local Open Scope type_scope. -Definition size := nat. +Create HintDb push_decode discriminated. +Create HintDb pull_decode discriminated. +Hint Extern 1 => progress autorewrite with push_decode in * : push_decode. +Hint Extern 1 => progress autorewrite with pull_decode in * : pull_decode. -Local Coercion Z.of_nat : nat >-> Z. +Class decoder (n : Z) W := + { decode : W -> Z }. +Coercion decode : decoder >-> Funclass. +Global Arguments decode {n W _} _. + +Class is_decode {n W} (decode : decoder n W) := + decode_range : forall x, 0 <= decode x < 2^n. Section InstructionGallery. - Context (n:Z) (* bit-width of width of W *) - {W : Type}(* previously [BoundedType], [W] for word *) - (decode : W -> Z). + Context (n : Z) (* bit-width of width of [W] *) + {W : Type} (* bounded type, [W] for word *) + (Wdecoder : decoder n W). Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *) - Class is_load_immediate (ldi:imm->W) := + Record load_immediate := { ldi :> imm -> W }. + + Class is_load_immediate {ldi : load_immediate} := decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x. - Class is_shift_right_doubleword_immediate (shrd:W->W->imm->W) := + Record shift_right_doubleword_immediate := { shrd :> W -> W -> imm -> W }. + + Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) := decode_shift_right_doubleword : forall high low count, 0 <= count < n -> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n. - Class is_shift_left_immediate (shl:W->imm->W) := + Record shift_left_immediate := { shl :> W -> imm -> W }. + + Class is_shift_left_immediate (shl : shift_left_immediate) := decode_shift_left_immediate : forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n. - Class is_spread_left_immediate (sprl:W->imm->W*W(*high, low*)) := + Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }. + + Class is_spread_left_immediate (sprl : spread_left_immediate) := { - fst_spread_left_immediate : forall r count, 0 <= count < n -> - decode (fst (sprl r count)) = (decode r << count) >> n; - snd_spread_left_immediate : forall r count, 0 <= count < n -> - decode (snd (sprl r count)) = (decode r << count) mod 2^n + decode_fst_spread_left_immediate : forall r count, + 0 <= count < n + -> decode (fst (sprl r count)) = (decode r << count) >> n; + decode_snd_spread_left_immediate : forall r count, + 0 <= count < n + -> decode (snd (sprl r count)) = (decode r << count) mod 2^n }. - Class is_mask_keep_low (mkl:W->imm->W) := + Record mask_keep_low := { mkl :> W -> imm -> W }. + + Class is_mask_keep_low (mkl : mask_keep_low) := decode_mask_keep_low : forall r count, 0 <= count < n -> decode (mkl r count) = decode r mod 2^count. Local Notation bit b := (if b then 1 else 0). - Class is_add_with_carry (adc:W->W->bool->bool*W) := + + Record add_with_carry := { adc :> W -> W -> bool -> bool * W }. + + Class is_add_with_carry (adc : add_with_carry) := { - fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; - snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) + bit_fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n; + decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n) }. + Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }. + Class is_sub_with_carry (subc:W->W->bool->bool*W) := { - fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) <? 0); - snd_sub_with_carry : forall x y c, decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n + fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) <? 0); + decode_snd_sub_with_carry : forall x y c, decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n }. - Class is_mul (mul:W->W->W) := - decode_mul : forall x y, decode (mul x y) = decode x * decode y mod 2^n. + Record multiply := { mul :> W -> W -> W }. + + Class is_mul (mul : multiply) := + decode_mul : forall x y, decode (mul x y) = (decode x * decode y) mod 2^n. + + Record multiply_low_low := { mulhwll :> W -> W -> W }. + Record multiply_high_low := { mulhwhl :> W -> W -> W }. + Record multiply_high_high := { mulhwhh :> W -> W -> W }. - Class is_mul_low_low (w:Z) (mulhwll:W->W->W) := + Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) := decode_mul_low_low : forall x y, decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n. - Class is_mul_high_low (w:Z) (mulhwhl:W->W->W) := + Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) := decode_mul_high_low : forall x y, decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n. - Class is_mul_high_high (w:Z) (mulhwhh:W->W->W) := + Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) := decode_mul_high_high : forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n. + + Record select_conditional := { selc :> bool -> W -> W -> W }. + + Class is_select_conditional (selc : select_conditional) := + decode_select_conditional : forall b x y, + decode (selc b x y) = if b then decode x else decode y. + + Record add_modulo := { addm :> W -> W -> W (* modulus *) -> W }. + + Class is_add_modulo (addm : add_modulo) := + decode_add_modulo : forall x y modulus, + decode (addm x y modulus) = (decode x + decode y) mod (decode modulus). End InstructionGallery. +Global Arguments load_immediate : clear implicits. +Global Arguments shift_right_doubleword_immediate : clear implicits. +Global Arguments shift_left_immediate : clear implicits. +Global Arguments spread_left_immediate : clear implicits. +Global Arguments mask_keep_low : clear implicits. +Global Arguments add_with_carry : clear implicits. +Global Arguments sub_with_carry : clear implicits. +Global Arguments multiply : clear implicits. +Global Arguments multiply_low_low : clear implicits. +Global Arguments multiply_high_low : clear implicits. +Global Arguments multiply_high_high : clear implicits. +Global Arguments select_conditional : clear implicits. +Global Arguments add_modulo : clear implicits. +Global Arguments ldi {_ _} _. +Global Arguments shrd {_ _} _ _ _. +Global Arguments shl {_ _} _ _. +Global Arguments sprl {_ _} _ _. +Global Arguments mkl {_ _} _ _. +Global Arguments adc {_ _} _ _ _. +Global Arguments subc {_ _} _ _ _. +Global Arguments mul {_ _} _ _. +Global Arguments mulhwll {_ _} _ _. +Global Arguments mulhwhl {_ _} _ _. +Global Arguments mulhwhh {_ _} _ _. +Global Arguments selc {_ _} _ _ _. +Global Arguments addm {_ _} _ _ _. + +Existing Class load_immediate. +Existing Class shift_right_doubleword_immediate. +Existing Class shift_left_immediate. +Existing Class spread_left_immediate. +Existing Class mask_keep_low. +Existing Class add_with_carry. +Existing Class sub_with_carry. +Existing Class multiply. +Existing Class multiply_low_low. +Existing Class multiply_high_low. +Existing Class multiply_high_high. +Existing Class select_conditional. +Existing Class add_modulo. + +Global Arguments is_decode {_ _} _. +Global Arguments is_load_immediate {_ _ _} _. +Global Arguments is_shift_right_doubleword_immediate {_ _ _} _. +Global Arguments is_shift_left_immediate {_ _ _} _. +Global Arguments is_spread_left_immediate {_ _ _} _. +Global Arguments is_mask_keep_low {_ _ _} _. +Global Arguments is_add_with_carry {_ _ _} _. +Global Arguments is_sub_with_carry {_ _ _} _. +Global Arguments is_mul {_ _ _} _. +Global Arguments is_mul_low_low {_ _ _} _ _. +Global Arguments is_mul_high_low {_ _ _} _ _. +Global Arguments is_mul_high_high {_ _ _} _ _. +Global Arguments is_select_conditional {_ _ _} _. +Global Arguments is_add_modulo {_ _ _} _. + +Ltac bounded_sovlver_tac := + solve [ eassumption | typeclasses eauto | omega | auto 6 using decode_range with typeclass_instances omega ]. + +Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo using bounded_sovlver_tac : push_decode. + +Ltac push_decode := + repeat first [ erewrite !decode_load_immediate by bounded_sovlver_tac + | erewrite !decode_shift_right_doubleword by bounded_sovlver_tac + | erewrite !decode_shift_left_immediate by bounded_sovlver_tac + | erewrite !decode_fst_spread_left_immediate by bounded_sovlver_tac + | erewrite !decode_snd_spread_left_immediate by bounded_sovlver_tac + | erewrite !decode_mask_keep_low by bounded_sovlver_tac + | erewrite !bit_fst_add_with_carry by bounded_sovlver_tac + | erewrite !decode_snd_add_with_carry by bounded_sovlver_tac + | erewrite !fst_sub_with_carry by bounded_sovlver_tac + | erewrite !decode_snd_sub_with_carry by bounded_sovlver_tac + | erewrite !decode_mul by bounded_sovlver_tac + | erewrite !decode_mul_low_low by bounded_sovlver_tac + | erewrite !decode_mul_high_low by bounded_sovlver_tac + | erewrite !decode_mul_high_high by bounded_sovlver_tac + | erewrite !decode_select_conditional by bounded_sovlver_tac + | erewrite !decode_add_modulo by bounded_sovlver_tac ]. +Ltac pull_decode := + repeat first [ erewrite <- !decode_load_immediate by bounded_sovlver_tac + | erewrite <- !decode_shift_right_doubleword by bounded_sovlver_tac + | erewrite <- !decode_shift_left_immediate by bounded_sovlver_tac + | erewrite <- !decode_fst_spread_left_immediate by bounded_sovlver_tac + | erewrite <- !decode_snd_spread_left_immediate by bounded_sovlver_tac + | erewrite <- !decode_mask_keep_low by bounded_sovlver_tac + | erewrite <- !bit_fst_add_with_carry by bounded_sovlver_tac + | erewrite <- !decode_snd_add_with_carry by bounded_sovlver_tac + | erewrite <- !fst_sub_with_carry by bounded_sovlver_tac + | erewrite <- !decode_snd_sub_with_carry by bounded_sovlver_tac + | erewrite <- !decode_mul by bounded_sovlver_tac + | erewrite <- !decode_mul_low_low by bounded_sovlver_tac + | erewrite <- !decode_mul_high_low by bounded_sovlver_tac + | erewrite <- !decode_mul_high_high by bounded_sovlver_tac + | erewrite <- !decode_select_conditional by bounded_sovlver_tac + | erewrite <- !decode_add_modulo by bounded_sovlver_tac ]. + Module fancy_machine. Local Notation imm := Z (only parsing). - Class instructions (n:Z) := + + Class instructions (n : Z) := { W : Type (* [n]-bit word *); - ldi : imm -> W; - shrd : W->W->imm -> W; - sprl : W->imm -> W*W; - mkl : W->imm -> W; - adc : W->W->bool -> bool*W; - subc : W->W->bool -> bool*W + decode :> decoder n W; + ldi :> load_immediate W; + shrd :> shift_right_doubleword_immediate W; + shl :> shift_left_immediate W; + mkl :> mask_keep_low W; + adc :> add_with_carry W; + subc :> sub_with_carry W; + mulhwll :> multiply_low_low W; + mulhwhl :> multiply_high_low W; + mulhwhh :> multiply_high_high W; + selc :> select_conditional W; + addm :> add_modulo W }. - Class arithmetic {n} (ops:instructions n) := + Class arithmetic {n_over_two} (ops:instructions (2 * n_over_two)) := { - decode : W -> Z; - decode_range : forall x, 0 <= decode x < 2^n; - load_immediate : is_load_immediate n decode ldi; - shift_right_doubleword_immediate : is_shift_right_doubleword_immediate n decode shrd; - spread_left_immediate : is_spread_left_immediate n decode sprl; - mask_keep_low : is_mask_keep_low n decode mkl; - add_with_carry : is_add_with_carry n decode adc; - sub_with_carry : is_sub_with_carry n decode subc + decode_range :> is_decode decode; + load_immediate :> is_load_immediate ldi; + shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd; + shift_left_immediate :> is_shift_left_immediate shl; + mask_keep_low :> is_mask_keep_low mkl; + add_with_carry :> is_add_with_carry adc; + sub_with_carry :> is_sub_with_carry subc; + multiply_low_low :> is_mul_low_low n_over_two mulhwll; + multiply_high_low :> is_mul_high_low n_over_two mulhwhl; + multiply_high_high :> is_mul_high_high n_over_two mulhwhh; + select_conditional :> is_select_conditional selc; + add_modulo :> is_add_modulo addm }. - Global Existing Instance load_immediate. - Global Existing Instance shift_right_doubleword_immediate. - Global Existing Instance spread_left_immediate. - Global Existing Instance mask_keep_low. - Global Existing Instance add_with_carry. - Global Existing Instance sub_with_carry. -End fancy_machine.
\ No newline at end of file +End fancy_machine. |