aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jason Gross <jasongross9@gmail.com>2016-09-08 12:03:59 -0700
committerGravatar GitHub <noreply@github.com>2016-09-08 12:03:59 -0700
commita106b73720fc126023cbd0e0485271e2e118ee2d (patch)
tree05f1ba7d0e0dd692b23c5f95d37aa96c2253de15
parent3eab786d92b348c1dec33640eec3a02a5a86606b (diff)
parent35e7650fbc7ff87e945b7e5f7c06f27dd4bd119b (diff)
Merge pull request #63 from JasonGross/fancy-barrett-montgomery
Add Barrett and Montgomery for the 256-bit machine
-rw-r--r--_CoqProject3
-rw-r--r--src/Specific/FancyMachine256/Barrett.v128
-rw-r--r--src/Specific/FancyMachine256/Core.v356
-rw-r--r--src/Specific/FancyMachine256/Montgomery.v118
4 files changed, 605 insertions, 0 deletions
diff --git a/_CoqProject b/_CoqProject
index 0266fc122..6cbebb799 100644
--- a/_CoqProject
+++ b/_CoqProject
@@ -82,6 +82,9 @@ src/Spec/ModularWordEncoding.v
src/Spec/WeierstrassCurve.v
src/Specific/GF1305.v
src/Specific/GF25519.v
+src/Specific/FancyMachine256/Barrett.v
+src/Specific/FancyMachine256/Core.v
+src/Specific/FancyMachine256/Montgomery.v
src/Tactics/VerdiTactics.v
src/Tactics/Algebra_syntax/Nsatz.v
src/Util/AdditionChainExponentiation.v
diff --git a/src/Specific/FancyMachine256/Barrett.v b/src/Specific/FancyMachine256/Barrett.v
new file mode 100644
index 000000000..1683522e3
--- /dev/null
+++ b/src/Specific/FancyMachine256/Barrett.v
@@ -0,0 +1,128 @@
+Require Import Crypto.Specific.FancyMachine256.Core.
+Require Import Crypto.ModularArithmetic.BarrettReduction.ZBounded.
+Require Import Crypto.ModularArithmetic.BarrettReduction.ZHandbook.
+
+(** Useful for arithmetic in the field of integers modulo the order of the curve25519 base point *)
+Section expression.
+ Let b : Z := 2.
+ Let k : Z := 253.
+ Let offset : Z := 3.
+ Context (ops : fancy_machine.instructions (2 * 128)) (props : fancy_machine.arithmetic ops).
+ Context (m μ : Z)
+ (m_pos : 0 < m).
+ Let base_pos : 0 < b. reflexivity. Qed.
+ Context (k_good : m < b^k)
+ (μ_good : μ = b^(2*k) / m) (* [/] is [Z.div], which is truncated *).
+ Let offset_nonneg : 0 <= offset. unfold offset; omega. Qed.
+ Let k_big_enough : offset <= k. unfold offset, k; omega. Qed.
+ Context (m_small : 3 * m <= b^(k+offset))
+ (m_large : b^(k-offset) <= m + 1).
+ Context (H : 0 <= m < 2^256).
+ Let H' : 0 <= 250 <= 256. omega. Qed.
+ Let H'' : 0 < 250. omega. Qed.
+ Let props' := ZLikeProperties_of_ArchitectureBoundedOps ops m H 250 H' H''.
+ Let ops' := (ZLikeOps_of_ArchitectureBoundedOps ops m 250).
+ Local Existing Instances props' ops'.
+ Local Notation fst' := (@fst fancy_machine.W fancy_machine.W).
+ Local Notation snd' := (@snd fancy_machine.W fancy_machine.W).
+ Local Notation SmallT := (@ZBounded.SmallT (2 ^ 256) (2 ^ 250) m
+ (@ZLikeOps_of_ArchitectureBoundedOps 128 ops m _)).
+ Definition ldi' : load_immediate SmallT := _.
+ Let isldi : is_load_immediate ldi' := _.
+ Context (μ_range : 0 <= b ^ (2 * k) / m < b ^ (k + offset)).
+ Definition μ' : SmallT := ldi' μ.
+ Let μ'_eq : ZBounded.decode_small μ' = μ.
+ Proof.
+ unfold ZBounded.decode_small, ZLikeOps_of_ArchitectureBoundedOps, μ'.
+ apply (decode_load_immediate _ _).
+ rewrite μ_good; apply μ_range.
+ Qed.
+
+ Definition pre_f v
+ := (@barrett_reduce m b k μ offset m_pos base_pos μ_good offset_nonneg k_big_enough m_small m_large ops' props' μ' I μ'_eq (fst' v, snd' v)).
+
+ Local Arguments μ' / .
+ Local Arguments ldi' / .
+
+ Definition expression'
+ := Eval simpl in
+ (fun v => proj1_sig (pre_f v)).
+ Definition expression
+ := Eval cbv beta iota delta [expression' fst snd] in
+ fun v => let RegMod := fancy_machine.ldi m in
+ let RegMu := fancy_machine.ldi μ in
+ let RegZero := fancy_machine.ldi 0 in
+ expression' v.
+
+ Definition expression_eq v (H : 0 <= _ < _) : fancy_machine.decode (expression v) = _
+ := proj1 (proj2_sig (pre_f v) H).
+End expression.
+
+Section reflected.
+ Context (ops : fancy_machine.instructions (2 * 128)).
+ Definition rexpression : Syntax.Expr base_type (interp_base_type _) op (Arrow TZ (Arrow TZ (Arrow TW (Arrow TW (Tbase TW))))).
+ Proof.
+ let v := (eval cbv beta delta [expression] in (fun m μ x y => expression ops m μ (x, y))) in
+ let v := Reify v in
+ exact v.
+ Defined.
+
+ Definition rexpression_simple := Eval vm_compute in rexpression.
+
+ Context (m μ : Z)
+ (props : fancy_machine.arithmetic ops).
+
+ Let result (v : tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple m μ (fst v) (snd v).
+
+ Theorem sanity : result = expression ops m μ.
+ Proof.
+ reflexivity.
+ Qed.
+
+ Theorem correctness
+ (b : Z := 2)
+ (k : Z := 253)
+ (offset : Z := 3)
+ (H0 : 0 < m)
+ (H1 : μ = b^(2 * k) / m)
+ (H2 : 3 * m <= b^(k + offset))
+ (H3 : b^(k - offset) <= m + 1)
+ (H4 : 0 <= m < 2^(k + offset))
+ (H5 : 0 <= b^(2 * k) / m < b^(k + offset))
+ (v : tuple fancy_machine.W 2)
+ (H6 : 0 <= decode v < b^(2 * k))
+ : fancy_machine.decode (result v) = decode v mod m.
+ Proof.
+ rewrite sanity; destruct v.
+ apply expression_eq; assumption.
+ Qed.
+End reflected.
+
+Definition compiled_syntax
+ := Eval vm_compute in
+ (fun ops => AssembleSyntax ops (rexpression_simple _) (@RegMod :: @RegMuLow :: nil)%list).
+
+Print compiled_syntax.
+(* compiled_syntax =
+fun (_ : fancy_machine.instructions (2 * 128)) (var : base_type -> Type) =>
+λ x x0 : var TW,
+c.Rshi(x1, x0, x, 250),
+c.Mul128(x2, c.UpperHalf(x1), c.UpperHalf(RegMuLow)),
+c.Mul128(x3, c.UpperHalf(x1), c.LowerHalf(RegMuLow)),
+c.Mul128(x4, c.LowerHalf(x1), c.LowerHalf(RegMuLow)),
+c.Add(x6, x4, c.LeftShifted{x3, 128}),
+c.Addc(x8, x2, c.RightShifted{x3, 128}),
+c.Mul128(x9, c.UpperHalf(RegMuLow), c.LowerHalf(x1)),
+c.Add(_, x6, c.LeftShifted{x9, 128}),
+c.Addc(x13, x8, c.RightShifted{x9, 128}),
+c.Mul128(x14, c.LowerHalf(x13), c.LowerHalf(RegMod)),
+c.Mul128(x15, c.UpperHalf(x13), c.LowerHalf(RegMod)),
+c.Add(x17, x14, c.LeftShifted{x15, 128}),
+c.Mul128(x18, c.UpperHalf(RegMod), c.LowerHalf(x13)),
+c.Add(x20, x17, c.LeftShifted{x18, 128}),
+c.Sub(x22, x, x20),
+c.Addm(x23, x22, RegZero),
+c.Addm(x24, x23, RegZero),
+Return x24
+ : fancy_machine.instructions (2 * 128) -> forall var : base_type -> Type, syntax
+*)
diff --git a/src/Specific/FancyMachine256/Core.v b/src/Specific/FancyMachine256/Core.v
new file mode 100644
index 000000000..d11cfe6ad
--- /dev/null
+++ b/src/Specific/FancyMachine256/Core.v
@@ -0,0 +1,356 @@
+(** * A Fancy Machine with 256-bit registers *)
+Require Import Coq.Classes.RelationClasses Coq.Classes.Morphisms.
+Require Export Coq.ZArith.ZArith.
+Require Export Crypto.BoundedArithmetic.Interface.
+Require Export Crypto.BoundedArithmetic.ArchitectureToZLike.
+Require Export Crypto.BoundedArithmetic.ArchitectureToZLikeProofs.
+Require Export Crypto.Util.Tuple.
+Require Import Crypto.Util.Option Crypto.Util.Sigma Crypto.Util.Prod.
+Require Export Crypto.Reflection.Syntax.
+Require Import Crypto.Reflection.Linearize.
+Require Import Crypto.Reflection.CommonSubexpressionElimination.
+Require Export Crypto.Reflection.Reify.
+Require Export Crypto.Util.ZUtil.
+Require Export Crypto.Util.Notations.
+
+Open Scope Z_scope.
+Local Notation eta x := (fst x, snd x).
+Local Notation eta3 x := (eta (fst x), snd x).
+Local Notation eta3' x := (fst x, eta (snd x)).
+
+(** ** Reflective Assembly Syntax *)
+Section reflection.
+ Context (ops : fancy_machine.instructions (2 * 128)).
+ Local Set Boolean Equality Schemes.
+ Local Set Decidable Equality Schemes.
+ Inductive base_type := TZ | Tbool | TW.
+ Definition interp_base_type (v : base_type) : Type :=
+ match v with
+ | TZ => Z
+ | Tbool => bool
+ | TW => fancy_machine.W
+ end.
+ Local Notation tZ := (Tbase TZ).
+ Local Notation tbool := (Tbase Tbool).
+ Local Notation tW := (Tbase TW).
+ Local Open Scope ctype_scope.
+ Inductive op : flat_type base_type -> flat_type base_type -> Type :=
+ | OPldi : op tZ tW
+ | OPshrd : op (tW * tW * tZ) tW
+ | OPshl : op (tW * tZ) tW
+ | OPshr : op (tW * tZ) tW
+ | OPmkl : op (tW * tZ) tW
+ | OPadc : op (tW * tW * tbool) (tbool * tW)
+ | OPsubc : op (tW * tW * tbool) (tbool * tW)
+ | OPmulhwll : op (tW * tW) tW
+ | OPmulhwhl : op (tW * tW) tW
+ | OPmulhwhh : op (tW * tW) tW
+ | OPselc : op (tbool * tW * tW) tW
+ | OPaddm : op (tW * tW * tW) tW.
+
+ Definition interp_op src dst (f : op src dst)
+ : interp_flat_type_gen interp_base_type src -> interp_flat_type_gen interp_base_type dst
+ := match f in op s d return interp_flat_type_gen _ s -> interp_flat_type_gen _ d with
+ | OPldi => ldi
+ | OPshrd => fun xyz => let '(x, y, z) := eta3 xyz in shrd x y z
+ | OPshl => fun xy => let '(x, y) := eta xy in shl x y
+ | OPshr => fun xy => let '(x, y) := eta xy in shr x y
+ | OPmkl => fun xy => let '(x, y) := eta xy in mkl x y
+ | OPadc => fun xyz => let '(x, y, z) := eta3 xyz in adc x y z
+ | OPsubc => fun xyz => let '(x, y, z) := eta3 xyz in subc x y z
+ | OPmulhwll => fun xy => let '(x, y) := eta xy in mulhwll x y
+ | OPmulhwhl => fun xy => let '(x, y) := eta xy in mulhwhl x y
+ | OPmulhwhh => fun xy => let '(x, y) := eta xy in mulhwhh x y
+ | OPselc => fun xyz => let '(x, y, z) := eta3 xyz in selc x y z
+ | OPaddm => fun xyz => let '(x, y, z) := eta3 xyz in addm x y z
+ end.
+
+ Inductive SConstT := ZConst (_ : Z) | BoolConst (_ : bool) | INVALID_CONST.
+ Inductive op_code : Set :=
+ | SOPldi | SOPshrd | SOPshl | SOPshr | SOPmkl | SOPadc | SOPsubc
+ | SOPmulhwll | SOPmulhwhl | SOPmulhwhh | SOPselc | SOPaddm.
+
+ Definition symbolicify_const (t : base_type) : interp_base_type t -> SConstT
+ := match t with
+ | TZ => fun x => ZConst x
+ | Tbool => fun x => BoolConst x
+ | TW => fun x => INVALID_CONST
+ end.
+ Definition symbolicify_op s d (v : op s d) : op_code
+ := match v with
+ | OPldi => SOPldi
+ | OPshrd => SOPshrd
+ | OPshl => SOPshl
+ | OPshr => SOPshr
+ | OPmkl => SOPmkl
+ | OPadc => SOPadc
+ | OPsubc => SOPsubc
+ | OPmulhwll => SOPmulhwll
+ | OPmulhwhl => SOPmulhwhl
+ | OPmulhwhh => SOPmulhwhh
+ | OPselc => SOPselc
+ | OPaddm => SOPaddm
+ end.
+
+ Definition CSE {t} e := @CSE base_type SConstT op_code base_type_beq SConstT_beq op_code_beq internal_base_type_dec_bl interp_base_type op symbolicify_const symbolicify_op t e (fun _ => nil).
+End reflection.
+
+Ltac base_reify_op op op_head ::=
+ lazymatch op_head with
+ | @Interface.ldi => constr:(reify_op op op_head 1 OPldi)
+ | @Interface.shrd => constr:(reify_op op op_head 3 OPshrd)
+ | @Interface.shl => constr:(reify_op op op_head 2 OPshl)
+ | @Interface.shr => constr:(reify_op op op_head 2 OPshr)
+ | @Interface.mkl => constr:(reify_op op op_head 2 OPmkl)
+ | @Interface.adc => constr:(reify_op op op_head 3 OPadc)
+ | @Interface.subc => constr:(reify_op op op_head 3 OPsubc)
+ | @Interface.mulhwll => constr:(reify_op op op_head 2 OPmulhwll)
+ | @Interface.mulhwhl => constr:(reify_op op op_head 2 OPmulhwhl)
+ | @Interface.mulhwhh => constr:(reify_op op op_head 2 OPmulhwhh)
+ | @Interface.selc => constr:(reify_op op op_head 3 OPselc)
+ | @Interface.addm => constr:(reify_op op op_head 3 OPaddm)
+ end.
+Ltac base_reify_type T ::=
+ match T with
+ | Z => TZ
+ | bool => Tbool
+ | fancy_machine.W => TW
+ end.
+
+Ltac Reify' e := Reify.Reify' base_type (interp_base_type _) op e.
+Ltac Reify e :=
+ let v := Reify.Reify base_type (interp_base_type _) op e in
+ constr:(CSE _ (InlineConst (Linearize v))).
+(*Ltac Reify_rhs := Reify.Reify_rhs base_type (interp_base_type _) op (interp_op _).*)
+
+(** ** Raw Syntax Trees *)
+(** These are used solely for pretty-printing the expression tree in a
+ form that can be basically copy-pasted into other languages which
+ can be compiled for the Fancy Machine. Hypothetically, we could
+ add support for custom named identifiers, by carrying around
+ [string] identifiers and using them for pretty-printing... It
+ might also be possible to verify this layer, too, by adding a
+ partial interpretation function... *)
+Section syn.
+ Context {var : base_type -> Type}.
+ Inductive syntax :=
+ | RegPInv
+ | RegMod
+ | RegMuLow
+ | RegZero
+ | cConstZ : Z -> syntax
+ | cConstBool : bool -> syntax
+ | cLowerHalf : syntax -> syntax
+ | cUpperHalf : syntax -> syntax
+ | cLeftShifted : syntax -> Z -> syntax
+ | cRightShifted : syntax -> Z -> syntax
+ | cVar : var TW -> syntax
+ | cVarC : var Tbool -> syntax
+ | cBind : syntax -> (var TW -> syntax) -> syntax
+ | cBindCarry : syntax -> (var Tbool -> var TW -> syntax) -> syntax
+ | cMul128 : syntax -> syntax -> syntax
+ | cRshi : syntax -> syntax -> Z -> syntax
+ | cSelc : var Tbool -> syntax -> syntax -> syntax
+ | cAddc : var Tbool -> syntax -> syntax -> syntax
+ | cAddm : syntax -> syntax -> syntax
+ | cAdd : syntax -> syntax -> syntax
+ | cSub : syntax -> syntax -> syntax
+ | cPair : syntax -> syntax -> syntax
+ | cAbs {t} : (var t -> syntax) -> syntax
+ | cINVALID {T} (_ : T).
+End syn.
+
+Notation "'Return' x" := (cVar x) (at level 200).
+Notation "'c.Mul128' ( x , A , B ) , b" :=
+ (cBind (cMul128 A B) (fun x => b))
+ (at level 200, b at level 200, format "'c.Mul128' ( x , A , B ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , b" :=
+ (cBindCarry (cAdd A B) (fun _ x => b))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , b" :=
+ (cBindCarry (cAdd (cVar A) B) (fun _ x => b))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" :=
+ (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c A1 B1) (fun _ x1 => b)))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" :=
+ (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun _ x1 => b)))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" :=
+ (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c A1 B1) (fun _ x1 => b)))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , b" :=
+ (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun _ x1 => b)))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" :=
+ (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c A1 B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b))))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" :=
+ (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c A1 B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b))))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" :=
+ (cBindCarry (cAdd A B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b))))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" :=
+ (cBindCarry (cAdd (cVar A) B) (fun c x => cBindCarry (cAddc c (cVar A1) B1) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b))))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b").
+Notation "'c.Add' ( x , A , B ) , 'c.Addc' ( x1 , A1 , B1 ) , 'c.Selc' ( x2 , A2 , B2 ) , b" :=
+ (cBindCarry (cAdd (cVar A) (cVar B)) (fun c x => cBindCarry (cAddc c (cVar A1) (cVar B1)) (fun c1 x1 => cBind (cSelc c1 A2 B2) (fun x2 => b))))
+ (at level 200, b at level 200, format "'c.Add' ( x , A , B ) , '//' 'c.Addc' ( x1 , A1 , B1 ) , '//' 'c.Selc' ( x2 , A2 , B2 ) , '//' b").
+
+Notation "'c.Sub' ( x , A , B ) , b" :=
+ (cBindCarry (cSub A B) (fun _ x => b))
+ (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b").
+Notation "'c.Sub' ( x , A , B ) , b" :=
+ (cBindCarry (cSub (cVar A) B) (fun _ x => b))
+ (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b").
+Notation "'c.Sub' ( x , A , B ) , b" :=
+ (cBindCarry (cSub A (cVar B)) (fun _ x => b))
+ (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b").
+Notation "'c.Sub' ( x , A , B ) , b" :=
+ (cBindCarry (cSub (cVar A) (cVar B)) (fun _ x => b))
+ (at level 200, b at level 200, format "'c.Sub' ( x , A , B ) , '//' b").
+
+Notation "'c.Addm' ( x , A , B ) , b" :=
+ (cBind (cAddm A B) (fun x => b))
+ (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b").
+Notation "'c.Addm' ( x , A , B ) , b" :=
+ (cBind (cAddm A (cVar B)) (fun x => b))
+ (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b").
+Notation "'c.Addm' ( x , A , B ) , b" :=
+ (cBind (cAddm (cVar A) B) (fun x => b))
+ (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b").
+Notation "'c.Addm' ( x , A , B ) , b" :=
+ (cBind (cAddm (cVar A) (cVar B)) (fun x => b))
+ (at level 200, b at level 200, format "'c.Addm' ( x , A , B ) , '//' b").
+
+Notation "'c.Rshi' ( x , A , B , C ) , b" :=
+ (cBind (cRshi (cVar A) (cVar B) C) (fun x => b))
+ (at level 200, b at level 200, format "'c.Rshi' ( x , A , B , C ) , '//' b").
+
+Notation "'c.LowerHalf' ( x )" :=
+ (cLowerHalf x)
+ (at level 200, format "'c.LowerHalf' ( x )").
+Notation "'c.LowerHalf' ( x )" :=
+ (cLowerHalf (cVar x))
+ (at level 200, format "'c.LowerHalf' ( x )").
+Notation "'c.UpperHalf' ( x )" :=
+ (cUpperHalf x)
+ (at level 200, format "'c.UpperHalf' ( x )").
+Notation "'c.UpperHalf' ( x )" :=
+ (cUpperHalf (cVar x))
+ (at level 200, format "'c.UpperHalf' ( x )").
+Notation "'c.LeftShifted' { x , v }" :=
+ (cLeftShifted x v)
+ (at level 200, format "'c.LeftShifted' { x , v }").
+Notation "'c.LeftShifted' { x , v }" :=
+ (cLeftShifted (cVar x) v)
+ (at level 200, format "'c.LeftShifted' { x , v }").
+Notation "'c.RightShifted' { x , v }" :=
+ (cRightShifted x v)
+ (at level 200, format "'c.RightShifted' { x , v }").
+Notation "'c.RightShifted' { x , v }" :=
+ (cRightShifted (cVar x) v)
+ (at level 200, format "'c.RightShifted' { x , v }").
+Notation "'λ' x .. y , t" := (cAbs (fun x => .. (cAbs (fun y => t)) ..))
+ (at level 200, x binder, y binder, right associativity).
+
+Definition Syntax := forall var, @syntax var.
+
+(** Assemble a well-typed easily interpretable expression into a
+ syntax tree we can use for pretty-printing. *)
+Section assemble.
+ Context (ops : fancy_machine.instructions (2 * 128)).
+
+ Section with_var.
+ Context {var : base_type -> Type}.
+
+ Fixpoint assemble_syntax_const
+ {t}
+ : interp_flat_type_gen (interp_base_type _) t -> @syntax var
+ := match t return interp_flat_type_gen (interp_base_type _) t -> @syntax var with
+ | Tbase TZ => cConstZ
+ | Tbase Tbool => cConstBool
+ | Tbase t => fun _ => cINVALID t
+ | Prod A B => fun xy => cPair (@assemble_syntax_const A (fst xy))
+ (@assemble_syntax_const B (snd xy))
+ end.
+
+ Definition assemble_syntaxf_step
+ (assemble_syntaxf : forall {t} (v : @Syntax.exprf base_type (interp_base_type _) op (fun _ => @syntax var) t), @syntax var)
+ {t} (v : @Syntax.exprf base_type (interp_base_type _) op (fun _ => @syntax var) t) : @syntax var.
+ Proof.
+ refine match v return @syntax var with
+ | Syntax.Const t x => assemble_syntax_const x
+ | Syntax.Var _ x => x
+ | Syntax.Op t1 tR op args
+ => let v := @assemble_syntaxf t1 args in
+ (* handle both associativities of pairs in 3-ary
+ operators, in case we ever change the
+ associativity *)
+ match op, v with
+ | OPldi , cConstZ 0 => RegZero
+ | OPldi , cConstZ v => cINVALID v
+ | OPldi , RegZero => RegZero
+ | OPldi , RegMod => RegMod
+ | OPldi , RegMuLow => RegMuLow
+ | OPldi , RegPInv => RegPInv
+ | OPshrd , cPair x (cPair y (cConstZ n)) => cRshi x y n
+ | OPshrd , cPair (cPair x y) (cConstZ n) => cRshi x y n
+ | OPshl , cPair w (cConstZ n) => cLeftShifted w n
+ | OPshr , cPair w (cConstZ n) => cRightShifted w n
+ | OPmkl , _ => cINVALID op
+ | OPadc , cPair (cPair x y) (cVarC c) => cAddc c x y
+ | OPadc , cPair x (cPair y (cVarC c)) => cAddc c x y
+ | OPadc , cPair (cPair x y) (cConstBool false) => cAdd x y
+ | OPadc , cPair x (cPair y (cConstBool false)) => cAdd x y
+ | OPsubc , cPair (cPair x y) (cConstBool false) => cSub x y
+ | OPsubc , cPair x (cPair y (cConstBool false)) => cSub x y
+ | OPmulhwll, cPair x y => cMul128 (cLowerHalf x) (cLowerHalf y)
+ | OPmulhwhl, cPair x y => cMul128 (cUpperHalf x) (cLowerHalf y)
+ | OPmulhwhh, cPair x y => cMul128 (cUpperHalf x) (cUpperHalf y)
+ | OPselc , cPair (cVarC c) (cPair x y) => cSelc c x y
+ | OPselc , cPair (cPair (cVarC c) x) y => cSelc c x y
+ | OPaddm , cPair x (cPair y RegMod) => cAddm x y
+ | OPaddm , cPair (cPair x y) RegMod => cAddm x y
+ | _, _ => cINVALID op
+ end
+ | Syntax.Let tx ex _ eC
+ => let ex' := @assemble_syntaxf _ ex in
+ let eC' := fun x => @assemble_syntaxf _ (eC x) in
+ let special := match ex' with
+ | RegZero as ex'' | RegMuLow as ex'' | RegMod as ex'' | RegPInv as ex''
+ | cUpperHalf _ as ex'' | cLowerHalf _ as ex''
+ | cLeftShifted _ _ as ex''
+ | cRightShifted _ _ as ex''
+ => Some ex''
+ | _ => None
+ end in
+ match special, tx return (interp_flat_type_gen _ tx -> _) -> _ with
+ | Some x, Tbase _ => fun eC' => eC' x
+ | _, Tbase TW
+ => fun eC' => cBind ex' (fun x => eC' (cVar x))
+ | _, Prod (Tbase Tbool) (Tbase TW)
+ => fun eC' => cBindCarry ex' (fun c x => eC' (cVarC c, cVar x))
+ | _, _
+ => fun _ => cINVALID (fun x : Prop => x)
+ end eC'
+ | Syntax.Pair _ ex _ ey
+ => cPair (@assemble_syntaxf _ ex) (@assemble_syntaxf _ ey)
+ end.
+ Defined.
+
+ Fixpoint assemble_syntaxf {t} v {struct v} : @syntax var
+ := @assemble_syntaxf_step (@assemble_syntaxf) t v.
+ Fixpoint assemble_syntax {t} (v : @Syntax.expr base_type (interp_base_type _) op (fun _ => @syntax var) t) (args : list (@syntax var)) {struct v}
+ : @syntax var
+ := match v, args return @syntax var with
+ | Syntax.Return _ x, _ => assemble_syntaxf x
+ | Syntax.Abs _ _ f, nil => cAbs (fun x => @assemble_syntax _ (f (cVar x)) args)
+ | Syntax.Abs _ _ f, cons v vs => @assemble_syntax _ (f v) vs
+ end.
+ End with_var.
+
+ Definition AssembleSyntax {t} (v : Syntax.Expr _ _ _ t) (args : list Syntax) : Syntax
+ := fun var => @assemble_syntax var t (v _) (List.map (fun f => f var) args).
+End assemble.
diff --git a/src/Specific/FancyMachine256/Montgomery.v b/src/Specific/FancyMachine256/Montgomery.v
new file mode 100644
index 000000000..a9a50f773
--- /dev/null
+++ b/src/Specific/FancyMachine256/Montgomery.v
@@ -0,0 +1,118 @@
+Require Import Crypto.Specific.FancyMachine256.Core.
+Require Import Crypto.ModularArithmetic.Montgomery.ZBounded.
+Require Import Crypto.ModularArithmetic.Montgomery.ZProofs.
+
+Section expression.
+ Context (ops : fancy_machine.instructions (2 * 128)) (props : fancy_machine.arithmetic ops) (modulus : Z) (m' : Z) (Hm : modulus <> 0) (H : 0 <= modulus < 2^256) (Hm' : 0 <= m' < 2^256).
+ Let H' : 0 <= 256 <= 256. omega. Qed.
+ Let H'' : 0 < 256. omega. Qed.
+ Let props' := ZLikeProperties_of_ArchitectureBoundedOps ops modulus H 256 H' H''.
+ Let ops' := (ZLikeOps_of_ArchitectureBoundedOps ops modulus 256).
+ Local Notation fst' := (@fst fancy_machine.W fancy_machine.W).
+ Local Notation snd' := (@snd fancy_machine.W fancy_machine.W).
+ Definition ldi' : load_immediate
+ (@ZBounded.SmallT (2 ^ 256) (2 ^ 256) modulus
+ (@ZLikeOps_of_ArchitectureBoundedOps 128 ops modulus 256)) := _.
+ Let isldi : is_load_immediate ldi' := _.
+ Definition pre_f := (fun v => (reduce_via_partial (2^256) modulus (props := props') (ldi' m') I Hm (fst' v, snd' v))).
+ Definition f := (fun v => proj1_sig (pre_f v)).
+
+ Local Arguments proj1_sig _ _ !_ / .
+ Local Arguments ZBounded.CarryAdd / .
+ Local Arguments ZBounded.ConditionalSubtract / .
+ Local Arguments ZBounded.ConditionalSubtractModulus / .
+ Local Arguments ZLikeOps_of_ArchitectureBoundedOps / .
+ Local Arguments ZBounded.DivBy_SmallBound / .
+ Local Arguments f / .
+ Local Arguments pre_f / .
+ Local Arguments ldi' / .
+ Local Arguments reduce_via_partial / .
+
+ Definition expression'
+ := Eval simpl in f.
+ Definition expression
+ := Eval cbv beta delta [expression' fst snd] in
+ fun v => let RegMod := fancy_machine.ldi modulus in
+ let RegPInv := fancy_machine.ldi m' in
+ let RegZero := fancy_machine.ldi 0 in
+ expression' v.
+ Definition expression_eq v : fancy_machine.decode (expression v) = _
+ := proj1 (proj2_sig (pre_f v) I).
+ Definition expression_correct
+ R' HR0 HR1
+ v
+ Hv
+ : fancy_machine.decode (expression v) = _
+ := @ZBounded.reduce_via_partial_correct (2^256) modulus _ props' (ldi' m') I Hm R' HR0 HR1 v I Hv.
+End expression.
+
+Section reflected.
+ Context (ops : fancy_machine.instructions (2 * 128)).
+ Definition rexpression : Syntax.Expr base_type (interp_base_type _) op (Arrow TZ (Arrow TZ (Arrow TW (Arrow TW (Tbase TW))))).
+ Proof.
+ let v := (eval cbv beta delta [expression] in (fun modulus m' x y => expression ops modulus m' (x, y))) in
+ let v := Reify v in
+ exact v.
+ Defined.
+
+ Definition rexpression_simple := Eval vm_compute in rexpression.
+
+ Context (modulus m' : Z)
+ (props : fancy_machine.arithmetic ops).
+
+ Let result (v : tuple fancy_machine.W 2) := Syntax.Interp (interp_op _) rexpression_simple modulus m' (fst v) (snd v).
+
+ Theorem sanity : result = expression ops modulus m'.
+ Proof.
+ reflexivity.
+ Qed.
+
+ Local Infix "≡₂₅₆" := (Z.equiv_modulo (2^256)).
+ Local Infix "≡" := (Z.equiv_modulo modulus).
+
+ Theorem correctness
+ R' (* modular inverse of 2^256 *)
+ (H0 : modulus <> 0)
+ (H1 : 0 <= modulus < 2^256)
+ (H2 : 0 <= m' < 2^256)
+ (H3 : 2^256 * R' ≡ 1)
+ (H4 : modulus * m' ≡₂₅₆ -1)
+ (v : tuple fancy_machine.W 2)
+ (H5 : 0 <= decode v <= 2^256 * modulus)
+ : fancy_machine.decode (result v) = (decode v * R') mod modulus.
+ Proof.
+ replace m' with (fancy_machine.decode (fancy_machine.ldi m')) in H4
+ by (apply decode_load_immediate; trivial; exact _).
+ rewrite sanity; destruct v; apply expression_correct; assumption.
+ Qed.
+End reflected.
+
+Definition compiled_syntax
+ := Eval vm_compute in
+ (fun ops => AssembleSyntax ops (rexpression_simple _) (@RegMod :: @RegPInv :: nil)%list).
+
+Print compiled_syntax.
+(* compiled_syntax =
+fun (_ : fancy_machine.instructions (2 * 128)) (var : base_type -> Type) =>
+λ x x0 : var TW,
+c.Mul128(x1, c.LowerHalf(x), c.LowerHalf(RegPInv)),
+c.Mul128(x2, c.UpperHalf(x), c.LowerHalf(RegPInv)),
+c.Add(x4, x1, c.LeftShifted{x2, 128}),
+c.Mul128(x5, c.UpperHalf(RegPInv), c.LowerHalf(x)),
+c.Add(x7, x4, c.LeftShifted{x5, 128}),
+c.Mul128(x8, c.UpperHalf(x7), c.UpperHalf(RegMod)),
+c.Mul128(x9, c.UpperHalf(x7), c.LowerHalf(RegMod)),
+c.Mul128(x10, c.LowerHalf(x7), c.LowerHalf(RegMod)),
+c.Add(x12, x10, c.LeftShifted{x9, 128}),
+c.Addc(x14, x8, c.RightShifted{x9, 128}),
+c.Mul128(x15, c.UpperHalf(RegMod), c.LowerHalf(x7)),
+c.Add(x17, x12, c.LeftShifted{x15, 128}),
+c.Addc(x19, x14, c.RightShifted{x15, 128}),
+c.Add(_, x, x17),
+c.Addc(x23, x0, x19),
+c.Selc(x24, RegMod, RegZero),
+c.Sub(x26, x23, x24),
+c.Addm(x27, x26, RegZero),
+Return x27
+ : fancy_machine.instructions (2 * 128) -> forall var : base_type -> Type, syntax
+*)