(* src/Arithmetic/Partition.v (blob 2d2fb87fab9ba23322458bae5ebfa88e5b1048d3) *)
Require Import Coq.ZArith.ZArith.
Require Import Coq.Lists.List.
Require Import Coq.Structures.Orders.
Require Import Crypto.Arithmetic.Core.
Require Import Crypto.Util.ListUtil.
Require Import Crypto.Util.ZUtil.EquivModulo.
Require Import Crypto.Util.ZUtil.Modulo Crypto.Util.ZUtil.Div.

Require Import Crypto.Util.Notations.
Import ListNotations Weight. Local Open Scope Z_scope.

(** Splitting an integer [x] into a list of limbs according to a weight
    function [weight], together with an equivalent digit-by-digit
    (recursive) formulation and basic facts about both. *)
Section Partition.
  Context weight {wprops : @weight_properties weight}.

  (** [partition n x] is the list of [n] limbs whose [i]-th entry is
      [(x mod weight (S i)) / weight i], for [i] ranging over [0 .. n-1]. *)
  Definition partition n x :=
    map (fun i => (x mod weight (S i)) / weight i) (seq 0 n).

  (** Peels off the last limb: [partition (S n) x] is [partition n x]
      followed by limb number [n]. *)
  Lemma partition_step n x :
    partition (S n) x = partition n x ++ [(x mod weight (S n)) / weight n].
  Proof using Type.
    cbv [partition]. rewrite seq_snoc.
    autorewrite with natsimplify push_map. reflexivity.
  Qed.

  (** [partition n x] always has exactly [n] limbs. *)
  Lemma length_partition n x : length (partition n x) = n.
  Proof using Type. cbv [partition]; distr_length. Qed.
  (* Lets [distr_length] use [length_partition] in the remaining proofs of
     this section; the hint is registered again after [End Partition]. *)
  Hint Rewrite length_partition : distr_length.

  (** Evaluating the limbs of [partition n x] with [Positional.eval]
      recovers [x] reduced modulo [weight n]. *)
  Lemma eval_partition n x :
    Positional.eval weight n (partition n x) = x mod (weight n).
  Proof using wprops.
    induction n; intros.
    { cbn. rewrite (weight_0); auto with zarith. }
    { (* Write [x mod weight (S n)] as quotient and remainder by [weight n],
         then collapse the nested [mod] in the remainder back to
         [x mod weight n] so the induction hypothesis applies. *)
      rewrite (Z.div_mod (x mod weight (S n)) (weight n)) by auto with zarith.
      rewrite <-Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto with zarith).
      rewrite partition_step, Positional.eval_snoc with (n:=n) by distr_length.
      (* NOTE(review): [omega] is deprecated in recent Coq releases; [lia]
         is the usual replacement if this file is ported. *)
      omega. }
  Qed.

  (** [partition n] only depends on its argument modulo [weight n]. *)
  Lemma partition_Proper n :
    Proper (Z.equiv_modulo (weight n) ==> eq) (partition n).
  Proof using wprops.
    cbv [Proper Z.equiv_modulo respectful].
    intros x y Hxy; induction n; intros.
    { reflexivity. }
    { (* Agreement modulo [weight (S n)] implies agreement modulo
         [weight n]; [Hxyn] discharges the premise of the induction
         hypothesis for the first [n] limbs. *)
      assert (Hxyn : x mod weight n = y mod weight n).
      { erewrite (Znumtheory.Zmod_div_mod _ (weight (S n)) x), (Znumtheory.Zmod_div_mod _ (weight (S n)) y), Hxy
          by (try apply Z.mod_divide; auto with zarith);
          reflexivity. }
      rewrite !partition_step, IHn by eauto.
      (* For the last limb, expand both sides through quotient/remainder by
         [weight n] so that [Hxy] and [Hxyn] can be substituted. *)
      rewrite (Z.div_mod (x mod weight (S n)) (weight n)), (Z.div_mod (y mod weight (S n)) (weight n)) by auto with zarith.
      rewrite <-!Znumtheory.Zmod_div_mod by (try apply Z.mod_divide; auto with zarith).
      rewrite Hxy, Hxyn; reflexivity. }
  Qed.

  (** Variant of [partition_Proper] with [Z.equiv_modulo] unfolded. This is
      basically a shortcut for applying [partition_Proper] and then
      unfolding [Z.equiv_modulo] in the resulting side goal, so callers can
      supply a plain equation between the two residues. *)
  Lemma partition_eq_mod x y n :
    x mod weight n = y mod weight n ->
    partition n x = partition n y.
  Proof using wprops. apply partition_Proper. Qed.

  (** In-bounds indexing into [partition n x] returns the [i]-th limb
      [x mod weight (S i) / weight i]; the default [d] is never used. *)
  Lemma nth_default_partition d n x i :
    (i < n)%nat ->
    nth_default d (partition n x) i = x mod weight (S i) / weight i.
  Proof.
    cbv [partition]; intros.
    rewrite map_nth_default with (x:=0%nat) by distr_length.
    autorewrite with push_nth_default natsimplify. reflexivity.
  Qed.

  (** Digit-by-digit formulation: starting at position [i], emit
      [x mod (weight (S i) / weight i)] and continue at position [S i] with
      the quotient [x / (weight (S i) / weight i)]. Produces [n] limbs. *)
  Fixpoint recursive_partition n i x :=
    match n with
    | O => []
    | S n' => x mod (weight (S i) / weight i) :: recursive_partition n' (S i) (x / (weight (S i) / weight i))
    end.

  (** Generalized form of [recursive_partition_equiv]: the limbs for
      positions [j .. j+n-1], computed directly from [x], agree with
      [recursive_partition] started at position [j] on [x / weight j]. *)
  Lemma recursive_partition_equiv' n : forall x j,
      map (fun i => x mod weight (S i) / weight i) (seq j n) = recursive_partition n j (x / weight j).
  Proof using wprops.
    induction n; [reflexivity|].
    intros; cbn. rewrite IHn.
    pose proof (@weight_positive _ wprops j).
    pose proof (@weight_divides _ wprops j).
    f_equal;
      repeat match goal with
             | _ => rewrite Z.mod_pull_div by auto with zarith
             | _ => rewrite weight_multiples by auto with zarith
             | _ => progress autorewrite with zsimplify_fast zdiv_to_mod pull_Zdiv
             | _ => reflexivity
             end.
  Qed.

  (** [partition] coincides with [recursive_partition] started at
      position [0]. *)
  Lemma recursive_partition_equiv n x :
    partition n x = recursive_partition n 0%nat x.
  Proof using wprops.
    cbv [partition]. rewrite recursive_partition_equiv'.
    rewrite weight_0 by auto; autorewrite with zsimplify_fast.
    reflexivity.
  Qed.

  (** [recursive_partition n i x] always has exactly [n] limbs. *)
  Lemma length_recursive_partition n : forall i x,
      length (recursive_partition n i x) = n.
  Proof using Type.
    induction n; cbn [recursive_partition]; [reflexivity | ].
    intros; distr_length; auto.
  Qed.

  (** Truncating [partition m x] to its first [n] limbs, where [n <= m],
      gives [partition n x]. *)
  Lemma drop_high_to_length_partition n m x :
    (n <= m)%nat ->
    Positional.drop_high_to_length n (partition m x) = partition n x.
  Proof using Type.
    cbv [Positional.drop_high_to_length partition]; intros.
    autorewrite with push_firstn.
    rewrite Nat.min_l by omega.
    reflexivity.
  Qed.

  (** Partitioning zero yields [Positional.zeros n]. *)
  Lemma partition_0 n : partition n 0 = Positional.zeros n.
  Proof.
    cbv [partition].
    erewrite Positional.zeros_ext_map with (p:=seq 0 n) by distr_length.
    apply map_ext; intros.
    autorewrite with zsimplify; reflexivity.
  Qed.

End Partition.
(* Make the length lemmas available to [distr_length] for users of this
   file. [length_partition] was already registered inside the section; it is
   registered again here, presumably because that in-section hint does not
   outlive [End Partition]. *)
Hint Rewrite length_partition : distr_length.
Hint Rewrite length_recursive_partition : distr_length.
(* Add [eval_partition] to [push_eval]. Any side conditions it generates
   (e.g. the [weight_properties] argument it takes once the section is
   closed) must be fully solved by [auto] followed by [distr_length]. *)
Hint Rewrite eval_partition using (solve [auto; distr_length]) : push_eval.