// RUN: %dafny /compile:0 "%s" > "%t"
// RUN: %diff "%s.expect" "%t"

// Monads in Dafny. The abstract module Monad declares a type constructor M
// together with Return and Bind, and states the three monad laws as lemmas.
// Each concrete module below refines Monad by giving M, Return and Bind
// definitions and proving the laws. The file is only verified
// (/compile:0); its verifier output is compared against the .expect file
// named on the second RUN line.
abstract module Monad {
  type M<A>

  function method Return<A>(x: A): M<A>

  // The frame and precondition let Bind apply f to any value of type A:
  // it may read whatever f reads, and f must be defined everywhere.
  function method Bind<A,B>(m: M<A>, f: A -> M<B>): M<B>
    reads f.reads;
    requires forall a :: f.requires(a);

  // return x >>= f = f x
  lemma LeftIdentity<A,B>(x: A, f: A -> M<B>)
    requires forall a :: f.requires(a);
    ensures Bind(Return(x),f) == f(x);

  // m >>= return = m
  lemma RightIdentity<A>(m: M<A>)
    ensures Bind(m,Return) == m;

  // (m >>= f) >>= g = m >>= (x => f(x) >>= g)
  // The lambda on the right is x => Bind(f(x),g); its reads/requires
  // annotations combine those of f and g so the inner Bind is well-formed.
  lemma Associativity<A,B,C>(m: M<A>, f: A -> M<B>, g: B -> M<C>)
    requires forall a :: f.requires(a);
    requires forall b :: g.requires(b);
    ensures Bind(Bind(m,f),g) ==
            Bind(m,x reads f.reads(x) reads g.reads requires f.requires(x) requires forall b :: g.requires(b) => Bind(f(x),g));
}

// The identity monad: M<A> wraps exactly one value, and Bind unwraps it and
// applies f.
module Identity refines Monad {
  datatype M<A> = I(A)

  function method Return<A>(x: A): M<A>
  { I(x) }

  function method Bind<A,B>(m: M<A>, f: A -> M<B>): M<B>
  {
    var I(x) := m; f(x)
  }

  lemma LeftIdentity<A,B>(x: A, f: A -> M<B>)
  {
  }

  lemma RightIdentity<A>(m: M<A>)
  {
    assert Bind(m,Return) == m;
  }

  lemma Associativity<A,B,C>(m: M<A>, f: A -> M<B>, g: B -> M<C>)
  {
    assert Bind(Bind(m,f),g) ==
           Bind(m,x reads f.reads(x) reads g.reads requires f.requires(x) requires forall b :: g.requires(b) => Bind(f(x),g));
  }
}

// The Maybe (option) monad: Bind applies f to the payload of Just and
// propagates Nothing unchanged.
module Maybe refines Monad {
  datatype M<A> = Just(A) | Nothing

  function method Return<A>(x: A): M<A>
  { Just(x) }

  function method Bind<A,B>(m: M<A>, f: A -> M<B>): M<B>
  {
    match m
    case Nothing => Nothing
    case Just(x) => f(x)
  }

  lemma LeftIdentity<A,B>(x: A, f: A -> M<B>)
  {
  }

  lemma RightIdentity<A>(m: M<A>)
  {
    assert Bind(m,Return) == m;
  }

  lemma Associativity<A,B,C>(m: M<A>, f: A -> M<B>, g: B -> M<C>)
  {
    assert Bind(Bind(m,f),g) ==
           Bind(m,x reads f.reads(x) reads g.reads requires f.requires(x) requires forall b :: g.requires(b) => Bind(f(x),g));
  }
}

// The list monad: Return builds a singleton list, and Bind maps f over the
// list and then flattens the resulting list of lists (Join(Map(m,f))).
module List refines Monad {
  datatype M<A> = Cons(hd: A,tl: M<A>) | Nil

  function method Return<A>(x: A): M<A>
  { Cons(x,Nil) }

  // Appends ys after the elements of xs.
  function method Concat<A>(xs: M<A>, ys: M<A>): M<A>
  {
    match xs
    case Nil => ys
    case Cons(x,xs) => Cons(x,Concat(xs,ys))
  }

  // Flattens a list of lists by concatenating the inner lists in order.
  function method Join<A>(xss: M<M<A>>): M<A>
  {
    match xss
    case Nil => Nil
    case Cons(xs,xss) => Concat(xs,Join(xss))
  }

  // Applies f to every element, preserving order. Like Bind, it needs to
  // read f's frame and requires f to be total.
  function method Map<A,B>(xs: M<A>, f: A -> B): M<B>
    reads f.reads;
    requires forall a :: f.requires(a);
  {
    match xs
    case Nil => Nil
    case Cons(x,xs) => Cons(f(x),Map(xs,f))
  }

  function method Bind<A,B>(m: M<A>, f: A -> M<B>): M<B>
  {
    Join(Map(m,f))
  }

  lemma LeftIdentity<A,B>(x: A, f: A -> M<B>)
  {
    calc {
      Bind(Return(x),f);
    ==
      Join(Map(Cons(x,Nil),f));
    ==
      Join(Cons(f(x),Nil));
    ==
      Concat(f(x),Nil);
    == { assert forall xs : M<B> :: Concat(xs,Nil) == xs; }
      f(x);
    }
  }

  // Proved by structural recursion on m; the Cons case uses the induction
  // hypothesis RightIdentity(xs).
  lemma RightIdentity<A>(m: M<A>)
  {
    match m
    case Nil =>
      calc {
        Bind(Nil,Return);
      ==
        Join(Map(Nil,Return));
      ==
        Join(Nil);
      ==
        Nil;
      ==
        m;
      }
    case Cons(x,xs) =>
      calc {
        Bind(m,Return);
      ==
        Bind(Cons(x,xs),Return);
      ==
        Join(Map(Cons(x,xs),Return));
      ==
        Join(Cons(Return(x),Map(xs,Return)));
      ==
        Concat(Return(x),Join(Map(xs,Return)));
      == { RightIdentity(xs); }
        Concat(Return(x),xs);
      ==
        Concat(Cons(x,Nil),xs);
      ==
        Cons(x,xs);
      ==
        m;
      }
  }

  // Concat is associative. The empty body leaves the proof entirely to the
  // verifier (presumably via its automatic induction on xs).
  lemma ConcatAssociativity<A>(xs: M<A>, ys: M<A>, zs: M<A>)
    ensures Concat(Concat(xs,ys),zs) == Concat(xs,Concat(ys,zs));
  {}

  // Bind distributes over Concat. Proved by structural recursion on xs;
  // used by the Cons case of Associativity.
  lemma BindMorphism<A,B>(xs: M<A>, ys: M<A>, f: A -> M<B>)
    requires forall a :: f.requires(a);
    ensures Bind(Concat(xs,ys),f) == Concat(Bind(xs,f),Bind(ys,f));
  {
    match xs
    case Nil =>
      calc {
        Bind(Concat(Nil,ys),f);
      ==
        Bind(ys,f);
      ==
        Concat(Nil,Bind(ys,f));
      ==
        Concat(Bind(Nil,f),Bind(ys,f));
      }
    case Cons(z,zs) =>
      calc {
        Bind(Concat(xs,ys),f);
      ==
        Bind(Concat(Cons(z,zs),ys),f);
      ==
        Concat(f(z),Bind(Concat(zs,ys),f));
      == { BindMorphism(zs,ys,f); }
        Concat(f(z),Concat(Bind(zs,f),Bind(ys,f)));
      == { ConcatAssociativity(f(z),Bind(zs,f),Bind(ys,f)); }
        Concat(Concat(f(z),Join(Map(zs,f))),Bind(ys,f));
      ==
        Concat(Bind(Cons(z,zs),f),Bind(ys,f));
      ==
        Concat(Bind(xs,f),Bind(ys,f));
      }
  }

  // Proved by structural recursion on m. The Cons case splits the outer
  // Bind with BindMorphism and then applies the induction hypothesis to xs.
  lemma Associativity<A,B,C>(m: M<A>, f: A -> M<B>, g: B -> M<C>)
  {
    match m
    case Nil =>
      calc {
        Bind(Bind(m,f),g);
      ==
        Bind(Bind(Nil,f),g);
      ==
        Bind(Nil,g);
      ==
        Nil;
      ==
        Bind(Nil,x reads f.reads(x) reads g.reads requires f.requires(x) requires forall b :: g.requires(b) => Bind(f(x),g));
      ==
        Bind(m,x reads f.reads(x) reads g.reads requires f.requires(x) requires forall b :: g.requires(b) => Bind(f(x),g));
      }
    case Cons(x,xs) =>
      calc {
        Bind(Bind(m,f),g);
      ==
        Bind(Bind(Cons(x,xs),f),g);
      ==
        Bind(Concat(f(x),Bind(xs,f)),g);
      == { BindMorphism(f(x),Bind(xs,f),g); }
        Concat(Bind(f(x),g),Bind(Bind(xs,f),g));
      == { Associativity(xs,f,g); }
        Concat(Bind(f(x),g),Join(Map(xs,y reads f.reads(y) reads g.reads requires f.requires(y) requires forall b :: g.requires(b) => Bind(f(y),g))));
      ==
        Join(Cons(Bind(f(x),g),Map(xs,y reads f.reads(y) reads g.reads requires f.requires(y) requires forall b :: g.requires(b) => Bind(f(y),g))));
      ==
        Join(Map(Cons(x,xs),y reads f.reads(y) reads g.reads requires f.requires(y) requires forall b :: g.requires(b) => Bind(f(y),g)));
      ==
        Bind(Cons(x,xs),y reads f.reads(y) reads g.reads requires f.requires(y) requires forall b :: g.requires(b) => Bind(f(y),g));
      ==
        Bind(m,x reads f.reads(x) reads g.reads requires f.requires(x) requires forall b :: g.requires(b) => Bind(f(x),g));
      }
  }
}