aboutsummaryrefslogtreecommitdiff
path: root/src/BoundedArithmetic
diff options
context:
space:
mode:
authorGravatar Jason Gross <jagro@google.com>2016-08-12 11:45:08 -0700
committerGravatar Jason Gross <jagro@google.com>2016-08-23 16:01:45 -0700
commitdc295c74a191d2ad9ab56a4792391a4c68a42e5d (patch)
tree5002f7435b5978aaa64d62b6912dcb3f41d2c92d /src/BoundedArithmetic
parent07b18ae2cb1122f395bffdf706ad37248bc5d4dc (diff)
Rework interface to support rewriting database
Diffstat (limited to 'src/BoundedArithmetic')
-rw-r--r--src/BoundedArithmetic/ArchitectureToZLike.v35
-rw-r--r--src/BoundedArithmetic/DoubleBounded.v14
-rw-r--r--src/BoundedArithmetic/Interface.v240
3 files changed, 222 insertions, 67 deletions
diff --git a/src/BoundedArithmetic/ArchitectureToZLike.v b/src/BoundedArithmetic/ArchitectureToZLike.v
index ebb21bd4b..01387e969 100644
--- a/src/BoundedArithmetic/ArchitectureToZLike.v
+++ b/src/BoundedArithmetic/ArchitectureToZLike.v
@@ -13,8 +13,8 @@ Local Open Scope type_scope.
Local Coercion Z.of_nat : nat >-> Z.
Section fancy_machine_p256_montgomery_foundation.
- Context {n_over_two : size}.
- Local Notation n := (2 * n_over_two)%nat.
+ Context {n_over_two : Z}.
+ Local Notation n := (2 * n_over_two)%Z.
Context (ops : fancy_machine.instructions n) (modulus : Z).
Definition two_list_to_tuple {A B} (x : A * list B)
@@ -26,7 +26,7 @@ Section fancy_machine_p256_montgomery_foundation.
| (a, [b0; b1]) => (a, (b0, b1))
| _ => I
end.
-
+(*
(* make all machine-specific constructions here, preferrably as
thing wrappers around generic constructions *)
Local Instance DoubleArchitectureBoundedOps : ArchitectureBoundedOps (2 * n)%nat
@@ -103,24 +103,23 @@ Section fancy_machine_p256_montgomery_foundation.
End full_from_half.
Local Existing Instance DoubleArchitectureBoundedFullMulOpsOfHalfWidthMulOps.
-
+*)
Axiom admit : forall {T}, T.
- Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : size)
+ Global Instance ZLikeOps_of_ArchitectureBoundedOps (smaller_bound_exp : Z)
: ZLikeOps (2^n) (2^smaller_bound_exp) modulus :=
- { LargeT := @BoundedType (2 * n)%nat _;
- SmallT := @BoundedType n _;
- modulus_digits := encode modulus;
- decode_large := decode;
+ { LargeT := fancy_machine.W * fancy_machine.W;
+ SmallT := fancy_machine.W;
+ modulus_digits := ldi modulus;
+ decode_large := _;
decode_small := decode;
Mod_SmallBound v := snd v;
DivBy_SmallBound v := fst v;
- DivBy_SmallerBound v := ShiftRight smaller_bound_exp v;
- Mul x y := @Interface.Mul n _ _ x y;
- CarryAdd x y := Interface.CarryAdd false x y;
- CarrySubSmall x y := Interface.CarrySub false x y;
- ConditionalSubtract b x := admit;
- ConditionalSubtractModulus y := admit }.
-End ops.
-
-End with_single_ops. \ No newline at end of file
+ DivBy_SmallerBound v := shrd (fst v) (snd v) smaller_bound_exp;
+ Mul x y := _ (*mulhwll (ldi 0, x) (ldi 0, y)*);
+ CarryAdd x y := _ (*adc x y false*);
+ CarrySubSmall x y := subc x y false;
+ ConditionalSubtract b x := let v := selc b (ldi 0) (ldi modulus) in snd (subc x v false);
+ ConditionalSubtractModulus y := addm y (ldi 0) (ldi modulus) }.
+ Abort.
+End fancy_machine_p256_montgomery_foundation.
diff --git a/src/BoundedArithmetic/DoubleBounded.v b/src/BoundedArithmetic/DoubleBounded.v
index 0d9c8e860..c09006e2b 100644
--- a/src/BoundedArithmetic/DoubleBounded.v
+++ b/src/BoundedArithmetic/DoubleBounded.v
@@ -1,6 +1,8 @@
(*** Implementing Large Bounded Arithmetic via pairs *)
Require Import Coq.ZArith.ZArith Coq.Lists.List.
Require Import Crypto.BoundedArithmetic.Interface.
+Require Import Crypto.BaseSystem.
+Require Import Crypto.BaseSystemProofs.
Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.Notations.
@@ -15,7 +17,7 @@ Local Notation eta x := (fst x, snd x).
Section generic_constructions.
Definition ripple_carry {T} (f : T -> T -> bool -> bool * T)
- (carry : bool) (xs ys : list T) : bool * list T
+ (xs ys : list T) (carry : bool) : bool * list T
:= List.fold_right
(fun x_y carry_zs => let '(x, y) := eta x_y in
let '(carry, zs) := eta carry_zs in
@@ -24,6 +26,16 @@ Section generic_constructions.
(carry, nil)
(List.combine xs ys).
+ Section ripple_carry_adc.
+ Context {n W} {decode : decoder n W} (adc : add_with_carry W).
+
+ Global Instance ripple_carry_add_with_carry : add_with_carry (list W)
+ := {| Interface.adc := ripple_carry adc |}.
+ (*
+ Global Instance ripple_carry_is_add_with_carry {is_adc : is_add_with_carry adc}
+ : is_add_with_carry ripple_carry_add_with_carry.*)
+ End ripple_carry_adc.
+
(* TODO: Would it made sense to make generic-width shift operations here? *)
(* FUTURE: here go proofs about [ripple_carry] with [f] that satisfies [is_add_with_carry] *)
diff --git a/src/BoundedArithmetic/Interface.v b/src/BoundedArithmetic/Interface.v
index f7906ccb5..cda72967c 100644
--- a/src/BoundedArithmetic/Interface.v
+++ b/src/BoundedArithmetic/Interface.v
@@ -3,100 +3,244 @@ Require Import Coq.ZArith.ZArith.
Require Import Crypto.Util.ZUtil.
Require Import Crypto.Util.Notations.
-Local Open Scope nat_scope.
Local Open Scope Z_scope.
Local Open Scope type_scope.
-Definition size := nat.
+Create HintDb push_decode discriminated.
+Create HintDb pull_decode discriminated.
+Hint Extern 1 => progress autorewrite with push_decode in * : push_decode.
+Hint Extern 1 => progress autorewrite with pull_decode in * : pull_decode.
-Local Coercion Z.of_nat : nat >-> Z.
+Class decoder (n : Z) W :=
+ { decode : W -> Z }.
+Coercion decode : decoder >-> Funclass.
+Global Arguments decode {n W _} _.
+
+Class is_decode {n W} (decode : decoder n W) :=
+ decode_range : forall x, 0 <= decode x < 2^n.
Section InstructionGallery.
- Context (n:Z) (* bit-width of width of W *)
- {W : Type}(* previously [BoundedType], [W] for word *)
- (decode : W -> Z).
+ Context (n : Z) (* bit-width of width of [W] *)
+ {W : Type} (* bounded type, [W] for word *)
+ (Wdecoder : decoder n W).
Local Notation imm := Z (only parsing). (* immediate (compile-time) argument *)
- Class is_load_immediate (ldi:imm->W) :=
+ Record load_immediate := { ldi :> imm -> W }.
+
+ Class is_load_immediate {ldi : load_immediate} :=
decode_load_immediate : forall x, 0 <= x < 2^n -> decode (ldi x) = x.
- Class is_shift_right_doubleword_immediate (shrd:W->W->imm->W) :=
+ Record shift_right_doubleword_immediate := { shrd :> W -> W -> imm -> W }.
+
+ Class is_shift_right_doubleword_immediate (shrd : shift_right_doubleword_immediate) :=
decode_shift_right_doubleword :
forall high low count,
0 <= count < n
-> decode (shrd high low count) = (((decode high << n) + decode low) >> count) mod 2^n.
- Class is_shift_left_immediate (shl:W->imm->W) :=
+ Record shift_left_immediate := { shl :> W -> imm -> W }.
+
+ Class is_shift_left_immediate (shl : shift_left_immediate) :=
decode_shift_left_immediate :
forall r count, 0 <= count < n -> decode (shl r count) = (decode r << count) mod 2^n.
- Class is_spread_left_immediate (sprl:W->imm->W*W(*high, low*)) :=
+ Record spread_left_immediate := { sprl :> W -> imm -> W * W (* [(high, low)] *) }.
+
+ Class is_spread_left_immediate (sprl : spread_left_immediate) :=
{
- fst_spread_left_immediate : forall r count, 0 <= count < n ->
- decode (fst (sprl r count)) = (decode r << count) >> n;
- snd_spread_left_immediate : forall r count, 0 <= count < n ->
- decode (snd (sprl r count)) = (decode r << count) mod 2^n
+ decode_fst_spread_left_immediate : forall r count,
+ 0 <= count < n
+ -> decode (fst (sprl r count)) = (decode r << count) >> n;
+ decode_snd_spread_left_immediate : forall r count,
+ 0 <= count < n
+ -> decode (snd (sprl r count)) = (decode r << count) mod 2^n
}.
- Class is_mask_keep_low (mkl:W->imm->W) :=
+ Record mask_keep_low := { mkl :> W -> imm -> W }.
+
+ Class is_mask_keep_low (mkl : mask_keep_low) :=
decode_mask_keep_low : forall r count,
0 <= count < n -> decode (mkl r count) = decode r mod 2^count.
Local Notation bit b := (if b then 1 else 0).
- Class is_add_with_carry (adc:W->W->bool->bool*W) :=
+
+ Record add_with_carry := { adc :> W -> W -> bool -> bool * W }.
+
+ Class is_add_with_carry (adc : add_with_carry) :=
{
- fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n;
- snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n)
+ bit_fst_add_with_carry : forall x y c, bit (fst (adc x y c)) = (decode x + decode y + bit c) >> n;
+ decode_snd_add_with_carry : forall x y c, decode (snd (adc x y c)) = (decode x + decode y + bit c) mod (2^n)
}.
+ Record sub_with_carry := { subc :> W -> W -> bool -> bool * W }.
+
Class is_sub_with_carry (subc:W->W->bool->bool*W) :=
{
- fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) <? 0);
- snd_sub_with_carry : forall x y c, decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n
+ fst_sub_with_carry : forall x y c, fst (subc x y c) = ((decode x - decode y - bit c) <? 0);
+ decode_snd_sub_with_carry : forall x y c, decode (snd (subc x y c)) = (decode x - decode y - bit c) mod 2^n
}.
- Class is_mul (mul:W->W->W) :=
- decode_mul : forall x y, decode (mul x y) = decode x * decode y mod 2^n.
+ Record multiply := { mul :> W -> W -> W }.
+
+ Class is_mul (mul : multiply) :=
+ decode_mul : forall x y, decode (mul x y) = (decode x * decode y) mod 2^n.
+
+ Record multiply_low_low := { mulhwll :> W -> W -> W }.
+ Record multiply_high_low := { mulhwhl :> W -> W -> W }.
+ Record multiply_high_high := { mulhwhh :> W -> W -> W }.
- Class is_mul_low_low (w:Z) (mulhwll:W->W->W) :=
+ Class is_mul_low_low (w:Z) (mulhwll : multiply_low_low) :=
decode_mul_low_low :
forall x y, decode (mulhwll x y) = ((decode x mod 2^w) * (decode y mod 2^w)) mod 2^n.
- Class is_mul_high_low (w:Z) (mulhwhl:W->W->W) :=
+ Class is_mul_high_low (w:Z) (mulhwhl : multiply_high_low) :=
decode_mul_high_low :
forall x y, decode (mulhwhl x y) = ((decode x >> w) * (decode y mod 2^w)) mod 2^n.
- Class is_mul_high_high (w:Z) (mulhwhh:W->W->W) :=
+ Class is_mul_high_high (w:Z) (mulhwhh : multiply_high_high) :=
decode_mul_high_high :
forall x y, decode (mulhwhh x y) = ((decode x >> w) * (decode y >> w)) mod 2^n.
+
+ Record select_conditional := { selc :> bool -> W -> W -> W }.
+
+ Class is_select_conditional (selc : select_conditional) :=
+ decode_select_conditional : forall b x y,
+ decode (selc b x y) = if b then decode x else decode y.
+
+ Record add_modulo := { addm :> W -> W -> W (* modulus *) -> W }.
+
+ Class is_add_modulo (addm : add_modulo) :=
+ decode_add_modulo : forall x y modulus,
+ decode (addm x y modulus) = (decode x + decode y) mod (decode modulus).
End InstructionGallery.
+Global Arguments load_immediate : clear implicits.
+Global Arguments shift_right_doubleword_immediate : clear implicits.
+Global Arguments shift_left_immediate : clear implicits.
+Global Arguments spread_left_immediate : clear implicits.
+Global Arguments mask_keep_low : clear implicits.
+Global Arguments add_with_carry : clear implicits.
+Global Arguments sub_with_carry : clear implicits.
+Global Arguments multiply : clear implicits.
+Global Arguments multiply_low_low : clear implicits.
+Global Arguments multiply_high_low : clear implicits.
+Global Arguments multiply_high_high : clear implicits.
+Global Arguments select_conditional : clear implicits.
+Global Arguments add_modulo : clear implicits.
+Global Arguments ldi {_ _} _.
+Global Arguments shrd {_ _} _ _ _.
+Global Arguments shl {_ _} _ _.
+Global Arguments sprl {_ _} _ _.
+Global Arguments mkl {_ _} _ _.
+Global Arguments adc {_ _} _ _ _.
+Global Arguments subc {_ _} _ _ _.
+Global Arguments mul {_ _} _ _.
+Global Arguments mulhwll {_ _} _ _.
+Global Arguments mulhwhl {_ _} _ _.
+Global Arguments mulhwhh {_ _} _ _.
+Global Arguments selc {_ _} _ _ _.
+Global Arguments addm {_ _} _ _ _.
+
+Existing Class load_immediate.
+Existing Class shift_right_doubleword_immediate.
+Existing Class shift_left_immediate.
+Existing Class spread_left_immediate.
+Existing Class mask_keep_low.
+Existing Class add_with_carry.
+Existing Class sub_with_carry.
+Existing Class multiply.
+Existing Class multiply_low_low.
+Existing Class multiply_high_low.
+Existing Class multiply_high_high.
+Existing Class select_conditional.
+Existing Class add_modulo.
+
+Global Arguments is_decode {_ _} _.
+Global Arguments is_load_immediate {_ _ _} _.
+Global Arguments is_shift_right_doubleword_immediate {_ _ _} _.
+Global Arguments is_shift_left_immediate {_ _ _} _.
+Global Arguments is_spread_left_immediate {_ _ _} _.
+Global Arguments is_mask_keep_low {_ _ _} _.
+Global Arguments is_add_with_carry {_ _ _} _.
+Global Arguments is_sub_with_carry {_ _ _} _.
+Global Arguments is_mul {_ _ _} _.
+Global Arguments is_mul_low_low {_ _ _} _ _.
+Global Arguments is_mul_high_low {_ _ _} _ _.
+Global Arguments is_mul_high_high {_ _ _} _ _.
+Global Arguments is_select_conditional {_ _ _} _.
+Global Arguments is_add_modulo {_ _ _} _.
+
+Ltac bounded_sovlver_tac :=
+ solve [ eassumption | typeclasses eauto | omega | auto 6 using decode_range with typeclass_instances omega ].
+
+Hint Rewrite @decode_load_immediate @decode_shift_right_doubleword @decode_shift_left_immediate @decode_fst_spread_left_immediate @decode_snd_spread_left_immediate @decode_mask_keep_low @bit_fst_add_with_carry @decode_snd_add_with_carry @fst_sub_with_carry @decode_snd_sub_with_carry @decode_mul @decode_mul_low_low @decode_mul_high_low @decode_mul_high_high @decode_select_conditional @decode_add_modulo using bounded_sovlver_tac : push_decode.
+
+Ltac push_decode :=
+ repeat first [ erewrite !decode_load_immediate by bounded_sovlver_tac
+ | erewrite !decode_shift_right_doubleword by bounded_sovlver_tac
+ | erewrite !decode_shift_left_immediate by bounded_sovlver_tac
+ | erewrite !decode_fst_spread_left_immediate by bounded_sovlver_tac
+ | erewrite !decode_snd_spread_left_immediate by bounded_sovlver_tac
+ | erewrite !decode_mask_keep_low by bounded_sovlver_tac
+ | erewrite !bit_fst_add_with_carry by bounded_sovlver_tac
+ | erewrite !decode_snd_add_with_carry by bounded_sovlver_tac
+ | erewrite !fst_sub_with_carry by bounded_sovlver_tac
+ | erewrite !decode_snd_sub_with_carry by bounded_sovlver_tac
+ | erewrite !decode_mul by bounded_sovlver_tac
+ | erewrite !decode_mul_low_low by bounded_sovlver_tac
+ | erewrite !decode_mul_high_low by bounded_sovlver_tac
+ | erewrite !decode_mul_high_high by bounded_sovlver_tac
+ | erewrite !decode_select_conditional by bounded_sovlver_tac
+ | erewrite !decode_add_modulo by bounded_sovlver_tac ].
+Ltac pull_decode :=
+ repeat first [ erewrite <- !decode_load_immediate by bounded_sovlver_tac
+ | erewrite <- !decode_shift_right_doubleword by bounded_sovlver_tac
+ | erewrite <- !decode_shift_left_immediate by bounded_sovlver_tac
+ | erewrite <- !decode_fst_spread_left_immediate by bounded_sovlver_tac
+ | erewrite <- !decode_snd_spread_left_immediate by bounded_sovlver_tac
+ | erewrite <- !decode_mask_keep_low by bounded_sovlver_tac
+ | erewrite <- !bit_fst_add_with_carry by bounded_sovlver_tac
+ | erewrite <- !decode_snd_add_with_carry by bounded_sovlver_tac
+ | erewrite <- !fst_sub_with_carry by bounded_sovlver_tac
+ | erewrite <- !decode_snd_sub_with_carry by bounded_sovlver_tac
+ | erewrite <- !decode_mul by bounded_sovlver_tac
+ | erewrite <- !decode_mul_low_low by bounded_sovlver_tac
+ | erewrite <- !decode_mul_high_low by bounded_sovlver_tac
+ | erewrite <- !decode_mul_high_high by bounded_sovlver_tac
+ | erewrite <- !decode_select_conditional by bounded_sovlver_tac
+ | erewrite <- !decode_add_modulo by bounded_sovlver_tac ].
+
Module fancy_machine.
Local Notation imm := Z (only parsing).
- Class instructions (n:Z) :=
+
+ Class instructions (n : Z) :=
{
W : Type (* [n]-bit word *);
- ldi : imm -> W;
- shrd : W->W->imm -> W;
- sprl : W->imm -> W*W;
- mkl : W->imm -> W;
- adc : W->W->bool -> bool*W;
- subc : W->W->bool -> bool*W
+ decode :> decoder n W;
+ ldi :> load_immediate W;
+ shrd :> shift_right_doubleword_immediate W;
+ shl :> shift_left_immediate W;
+ mkl :> mask_keep_low W;
+ adc :> add_with_carry W;
+ subc :> sub_with_carry W;
+ mulhwll :> multiply_low_low W;
+ mulhwhl :> multiply_high_low W;
+ mulhwhh :> multiply_high_high W;
+ selc :> select_conditional W;
+ addm :> add_modulo W
}.
- Class arithmetic {n} (ops:instructions n) :=
+ Class arithmetic {n_over_two} (ops:instructions (2 * n_over_two)) :=
{
- decode : W -> Z;
- decode_range : forall x, 0 <= decode x < 2^n;
- load_immediate : is_load_immediate n decode ldi;
- shift_right_doubleword_immediate : is_shift_right_doubleword_immediate n decode shrd;
- spread_left_immediate : is_spread_left_immediate n decode sprl;
- mask_keep_low : is_mask_keep_low n decode mkl;
- add_with_carry : is_add_with_carry n decode adc;
- sub_with_carry : is_sub_with_carry n decode subc
+ decode_range :> is_decode decode;
+ load_immediate :> is_load_immediate ldi;
+ shift_right_doubleword_immediate :> is_shift_right_doubleword_immediate shrd;
+ shift_left_immediate :> is_shift_left_immediate shl;
+ mask_keep_low :> is_mask_keep_low mkl;
+ add_with_carry :> is_add_with_carry adc;
+ sub_with_carry :> is_sub_with_carry subc;
+ multiply_low_low :> is_mul_low_low n_over_two mulhwll;
+ multiply_high_low :> is_mul_high_low n_over_two mulhwhl;
+ multiply_high_high :> is_mul_high_high n_over_two mulhwhh;
+ select_conditional :> is_select_conditional selc;
+ add_modulo :> is_add_modulo addm
}.
- Global Existing Instance load_immediate.
- Global Existing Instance shift_right_doubleword_immediate.
- Global Existing Instance spread_left_immediate.
- Global Existing Instance mask_keep_low.
- Global Existing Instance add_with_carry.
- Global Existing Instance sub_with_carry.
-End fancy_machine. \ No newline at end of file
+End fancy_machine.