// RUN: %dafny /compile:0 "%s" > "%t"
// RUN: %diff "%s.expect" "%t"
// Abstract description of a monad over the type constructor M. The three
// monad laws are stated as lemmas, so every refining module must prove them
// for its own Return/Bind.
abstract module Monad {
  type M<A>

  // Injects a plain value into the monad.
  function method Return<A>(x: A): M<A>

  // Sequences m with the continuation f. Bind may read whatever f reads, and
  // f must be total (its precondition holds for every argument).
  function method Bind<A,B>(m: M<A>, f:A -> M<B>):M<B>
    reads f.reads
    requires forall a :: f.requires(a)

  // return x >>= f = f x
  lemma LeftIdentity<A,B>(x : A, f : A -> M<B>)
    requires forall a :: f.requires(a)
    ensures Bind(Return(x),f) == f(x)

  // m >>= return = m
  lemma RightIdentity<A>(m : M<A>)
    ensures Bind(m,Return) == m

  // (m >>= f) >>= g = m >>= (x => f(x) >>= g)
  // The right-hand lambda carries the reads/requires clauses of f and g so
  // that it satisfies Bind's own reads/requires obligations.
  lemma Associativity<A,B,C>(m : M<A>, f:A -> M<B>, g: B -> M<C>)
    requires forall a :: f.requires(a)
    requires forall b :: g.requires(b)
    ensures Bind(Bind(m,f),g) ==
            Bind(m,x reads f.reads(x)
                     reads g.reads
                     requires f.requires(x)
                     requires forall b :: g.requires(b) => Bind(f(x),g))
}

// The identity monad. M<A> is a single-constructor wrapper around A, so Bind
// unwraps the value and applies the continuation to it.
module Identity refines Monad {
  datatype M<A> = I(A)

  function method Return<A>(x: A): M<A>
  { I(x) }

  function method Bind<A,B>(m: M<A>, f:A -> M<B>):M<B>
  {
    var I(x) := m; f(x)
  }

  // Empty body: the verifier discharges this from the definitions alone.
  lemma LeftIdentity<A,B>(x : A, f : A -> M<B>)
  {
  }

  lemma RightIdentity<A>(m : M<A>)
  {
    assert Bind(m,Return) == m;
  }

  lemma Associativity<A,B,C>(m : M<A>, f:A -> M<B>, g: B -> M<C>)
  {
    assert
      Bind(Bind(m,f),g) ==
      Bind(m,x reads f.reads(x)
               reads g.reads
               requires f.requires(x)
               requires forall b :: g.requires(b) => Bind(f(x),g));
  }
}

// The option monad. Bind on Nothing yields Nothing without calling the
// continuation; Bind on Just(x) applies the continuation to x.
module Maybe refines Monad {
  datatype M<A> = Just(A) | Nothing

  function method Return<A>(x: A): M<A>
  { Just(x) }

  function method Bind<A,B>(m: M<A>, f:A -> M<B>):M<B>
  {
    match m
    case Nothing => Nothing
    case Just(x) => f(x)
  }

  // Empty body: the verifier discharges this from the definitions alone.
  lemma LeftIdentity<A,B>(x : A, f : A -> M<B>)
  {
  }

  lemma RightIdentity<A>(m : M<A>)
  {
    assert Bind(m,Return) == m;
  }

  lemma Associativity<A,B,C>(m : M<A>, f:A -> M<B>, g: B -> M<C>)
  {
    assert
      Bind(Bind(m,f),g) ==
      Bind(m,x reads f.reads(x)
               reads g.reads
               requires f.requires(x)
               requires forall b :: g.requires(b) => Bind(f(x),g));
  }
}

// The list monad. Bind maps the continuation over the list and then flattens
// the resulting list of lists (Bind = Join after Map).
module List refines Monad {
  datatype M<A> = Cons(hd: A,tl: M<A>) | Nil

  // Singleton list.
  function method Return<A>(x: A): M<A>
  { Cons(x,Nil) }

  // Append: the elements of xs followed by the elements of ys.
  function method Concat<A>(xs: M<A>, ys: M<A>): M<A>
  {
    match xs
    case Nil => ys
    case Cons(x,xs) => Cons(x,Concat(xs,ys))
  }

  // Flattens a list of lists by concatenating the inner lists in order.
  function method Join<A>(xss: M<M<A>>) : M<A>
  {
    match xss
    case Nil => Nil
    case Cons(xs,xss) => Concat(xs,Join(xss))
  }

  // Applies f to every element, keeping the order. f must be total.
  function method Map<A,B>(xs: M<A>, f: A -> B):M<B>
    reads f.reads;
    requires forall a :: f.requires(a);
  {
    match xs
    case Nil => Nil
    case Cons(x,xs) => Cons(f(x),Map(xs,f))
  }

  function method Bind<A,B>(m: M<A>, f:A -> M<B>):M<B>
  {
    Join(Map(m,f))
  }

  // Unfold Bind on the singleton list. The last step needs Nil to be a
  // right unit of Concat, which is asserted for all lists of the result type.
  lemma LeftIdentity<A,B>(x : A, f : A -> M<B>)
  {
    calc {
      Bind(Return(x),f);
      == Join(Map(Cons(x,Nil),f));
      == Join(Cons(f(x),Nil));
      == Concat(f(x),Nil);
      == { assert forall xs : M<B> :: Concat(xs,Nil) == xs; }
         f(x);
    }
  }

  // Structural induction on m: the Cons case uses the lemma on the tail.
  lemma RightIdentity<A>(m : M<A>)
  {
    match m
    case Nil => calc {
      Bind(Nil,Return);
      == Join(Map(Nil,Return));
      == Join(Nil);
      == Nil;
      == m;
    }
    case Cons(x,xs) =>
      calc {
        Bind(m,Return);
        == Bind(Cons(x,xs),Return);
        == Join(Map(Cons(x,xs),Return));
        == Join(Cons(Return(x),Map(xs,Return)));
        == Concat(Return(x),Join(Map(xs,Return)));
        == { RightIdentity(xs); }
           Concat(Return(x),xs);
        == Concat(Cons(x,Nil),xs);
        == Cons(x,xs);
        == m;
      }
  }

  // Concat is associative. Empty body: left to the verifier (presumably via
  // its automatic induction on xs).
  lemma ConcatAssociativity<A>(xs : M<A>, ys : M<A>, zs: M<A>)
    ensures Concat(Concat(xs,ys),zs) == Concat(xs,Concat(ys,zs));
  {}

  // Bind distributes over Concat. Proved by structural induction on xs.
  lemma BindMorphism<A,B>(xs : M<A>, ys: M<A>, f : A -> M<B>)
    requires forall a :: f.requires(a);
    ensures Bind(Concat(xs,ys),f) == Concat(Bind(xs,f),Bind(ys,f));
  {
    match xs
    case Nil => calc {
      Bind(Concat(Nil,ys),f);
      == Bind(ys,f);
      == Concat(Nil,Bind(ys,f));
      == Concat(Bind(Nil,f),Bind(ys,f));
    }
    case Cons(z,zs) => calc {
      Bind(Concat(xs,ys),f);
      == Bind(Concat(Cons(z,zs),ys),f);
      == Concat(f(z),Bind(Concat(zs,ys),f));
      == { BindMorphism(zs,ys,f); }
         Concat(f(z),Concat(Bind(zs,f),Bind(ys,f)));
      == { ConcatAssociativity(f(z),Bind(zs,f),Bind(ys,f)); }
         Concat(Concat(f(z),Join(Map(zs,f))),Bind(ys,f));
      == Concat(Bind(Cons(z,zs),f),Bind(ys,f));
      == Concat(Bind(xs,f),Bind(ys,f));
    }
  }

  // Structural induction on m. The Cons case pushes g through the
  // concatenation with BindMorphism, then applies the lemma to the tail.
  lemma Associativity<A,B,C>(m : M<A>, f:A -> M<B>, g: B -> M<C>)
  {
    match m
    case Nil => calc {
      Bind(Bind(m,f),g);
      == Bind(Bind(Nil,f),g);
      == Bind(Nil,g);
      == Nil;
      == Bind(Nil,x reads f.reads(x)
                    reads g.reads
                    requires f.requires(x)
                    requires forall b :: g.requires(b) => Bind(f(x),g));
      == Bind(m,x reads f.reads(x)
                  reads g.reads
                  requires f.requires(x)
                  requires forall b :: g.requires(b) => Bind(f(x),g));
    }
    case Cons(x,xs) => calc {
      Bind(Bind(m,f),g);
      == Bind(Bind(Cons(x,xs),f),g);
      == Bind(Concat(f(x),Bind(xs,f)),g);
      == { BindMorphism(f(x),Bind(xs,f),g); }
         Concat(Bind(f(x),g),Bind(Bind(xs,f),g));
      == { Associativity(xs,f,g); }
         Concat(Bind(f(x),g),Join(Map(xs,y reads f.reads(y)
                                               reads g.reads
                                               requires f.requires(y)
                                               requires forall b :: g.requires(b) => Bind(f(y),g))));
      == Join(Cons(Bind(f(x),g),Map(xs,y reads f.reads(y)
                                          reads g.reads
                                          requires f.requires(y)
                                          requires forall b :: g.requires(b) => Bind(f(y),g))));
      == Join(Map(Cons(x,xs),y reads f.reads(y)
                                reads g.reads
                                requires f.requires(y)
                                requires forall b :: g.requires(b) => Bind(f(y),g)));
      == Bind(Cons(x,xs),y reads f.reads(y)
                           reads g.reads
                           requires f.requires(y)
                           requires forall b :: g.requires(b) => Bind(f(y),g));
      == Bind(m,x reads f.reads(x)
                  reads g.reads
                  requires f.requires(x)
                  requires forall b :: g.requires(b) => Bind(f(x),g));
    }
  }
}