blob: 3d92e7b4ea0bc8de95cc81af44c2b4421b22b34f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
(** * PHOAS Syntax for expression trees on ℤ *)
Require Import Coq.ZArith.ZArith.
Require Import Crypto.Util.Curry.
Require Import Crypto.Compilers.SmartMap.
Require Import Crypto.Compilers.Syntax.
Require Import Crypto.Compilers.TypeUtil.
Require Import Crypto.Util.Option.
Require Import Crypto.Util.ZUtil.Definitions.
Require Import Crypto.Util.IdfunWithAlt.
Require Import Crypto.Util.NatUtil. (* for nat_beq for equality schemes *)
Export Syntax.Notations.
Local Set Boolean Equality Schemes.
Local Set Decidable Equality Schemes.
(** Base types of the expression language: [TZ] (interpreted as the
    integers [Z]) and [TBool] (interpreted as [bool]). *)
Inductive base_type :=
| TZ
| TBool.
(* Shorthands for the flat types made of a single integer ([tZ]) or a
   single boolean ([tBool]) base type. *)
Local Notation tZ := (Tbase TZ).
Local Notation tBool := (Tbase TBool).
(** Denotation of a base type as a Coq [Set]: booleans are read as [bool]
    and integers as [Z]. *)
Definition interp_base_type : base_type -> Set
  := fun v => match v return Set with
              | TBool => bool
              | TZ => Z
              end.
(** Primitive operations of the language.  A constructor of type [op s d]
    consumes a value of flat type [s] and produces one of flat type [d];
    the meaning of every constructor is fixed by [interp_op].
    NOTE(review): judging from how [interp_op] uses [curry2]/[curry3] and
    applies unary [Z] functions directly, [tuple tZ n] appears to carry
    [n] integer arguments -- confirm against Crypto.Compilers.Syntax. *)
Inductive op : flat_type base_type -> flat_type base_type -> Set :=
(* Three integer inputs, two integer outputs (interpreted by
   [Z.add_get_carry] and [Z.mul_split_at_bitwidth] respectively). *)
| AddGetCarry : op (tuple tZ 3) (tuple tZ 2)
| MulSplitAtBitwidth : op (tuple tZ 3) (tuple tZ 2)
(* Three integer inputs, one integer output (interpreted by [Z.zselect]). *)
| Zselect : op (tuple tZ 3) (tuple tZ 1)
(* Binary and unary integer operations, each named after the [Z] function
   that interprets it in [interp_op]. *)
| Zmul : op (tuple tZ 2) (tuple tZ 1)
| Zadd : op (tuple tZ 2) (tuple tZ 1)
| Zopp : op (tuple tZ 1) (tuple tZ 1)
| Zshiftr : op (tuple tZ 2) (tuple tZ 1)
| Zshiftl : op (tuple tZ 2) (tuple tZ 1)
| Zland : op (tuple tZ 2) (tuple tZ 1)
| Zlor : op (tuple tZ 2) (tuple tZ 1)
| Zmodulo : op (tuple tZ 2) (tuple tZ 1)
| Zdiv : op (tuple tZ 2) (tuple tZ 1)
| Zlog2 : op (tuple tZ 1) (tuple tZ 1)
| Zpow : op (tuple tZ 2) (tuple tZ 1)
| Zones : op (tuple tZ 1) (tuple tZ 1)
(* Integer equality test with a boolean result (interpreted by [Z.eqb]). *)
| Zeqb : op (tuple tZ 2) (tuple tBool 1)
(* Literal constants.  They ignore their (zero-length) input; [Const]
   relies on [tuple tZ 0] being convertible to [Unit] for both of them. *)
| ConstZ (v : BinNums.Z) : op (tuple tZ 0) (tuple tZ 1)
| ConstBool (v : bool) : op (tuple tZ 0) (tuple tBool 1)
(* Boolean case analysis on [((b, t), f)]: yields [t] when [b] is true and
   [f] otherwise; [T] is the (implicit) type of both branches. *)
| BoolCase {T} : op (Prod (Prod tBool T) T) T.
(** Constant operator for a base type [t], chosen by [t]: on [TZ] it is
    [ConstZ] and on [TBool] it is [ConstBool].  The type of the literal
    argument is computed from [t] by [interp_base_type]. *)
Definition Const {t} : interp_base_type t -> op Unit (Tbase t)
  := match t with
     | TZ => ConstZ
     (* Both constructors are named exactly: a misspelled constructor
        (e.g. [Tbool]) would be parsed as a fresh pattern variable and act
        as a silent catch-all instead of a constructor pattern. *)
     | TBool => ConstBool
     end.
(** Denotational semantics of [op]: sends an operation [op s d] to a Coq
    function from the interpretation of its input flat type [s] to that of
    its output flat type [d].  Multi-argument integer operations reuse the
    corresponding [Z] functions through the [curry2]/[curry3] helpers
    (presumably adapting them to a tuple of arguments -- see
    Crypto.Util.Curry); unary ones use the [Z] function directly. *)
Definition interp_op {s d} (opv : op s d) : interp_flat_type interp_base_type s -> interp_flat_type interp_base_type d
:= match opv with
| AddGetCarry => curry3 Z.add_get_carry
| MulSplitAtBitwidth => curry3 Z.mul_split_at_bitwidth
| Zselect => curry3 Z.zselect
| Zmul => curry2 Z.mul
| Zadd => curry2 Z.add
| Zopp => Z.opp
| Zshiftr => curry2 Z.shiftr
| Zshiftl => curry2 Z.shiftl
| Zland => curry2 Z.land
| Zlor => curry2 Z.lor
| Zmodulo => curry2 Z.modulo
| Zdiv => curry2 Z.div
| Zlog2 => Z.log2
| Zpow => curry2 Z.pow
| Zones => Z.ones
| Zeqb => curry2 Z.eqb
(* Constants ignore their empty input and return the stored literal. *)
| ConstZ v => fun _ => v
| ConstBool v => fun _ => v
(* The input has shape ((b, t), f); select a branch on the boolean. *)
| BoolCase T => fun '(b, t, f) => if b then t else f
end.
(** [Expr] and [Interp] (presumably the generic versions from
    Crypto.Compilers.Syntax) specialized to this file's [base_type], [op]
    and [interp_op]. *)
Notation Expr := (Expr base_type op).
Notation Interp := (Interp interp_op).
|