From 2ead070c052fb5f506f188571e0e1bef900af9d4 Mon Sep 17 00:00:00 2001 From: Dan Rosén Date: Fri, 15 Aug 2014 11:50:02 -0700 Subject: Add Monads as a module example and implementation of some simple monads --- Test/hofs/Monads.dfy | 244 ++++++++++++++++++++++++++++++++++++++++++++ Test/hofs/Monads.dfy.expect | 2 + 2 files changed, 246 insertions(+) create mode 100644 Test/hofs/Monads.dfy create mode 100644 Test/hofs/Monads.dfy.expect (limited to 'Test/hofs') diff --git a/Test/hofs/Monads.dfy b/Test/hofs/Monads.dfy new file mode 100644 index 00000000..9e7c5460 --- /dev/null +++ b/Test/hofs/Monads.dfy @@ -0,0 +1,244 @@ +// RUN: %dafny /compile:0 "%s" > "%t" +// RUN: %diff "%s.expect" "%t" + +abstract module Monad { + type M; + + static function method Return(x: A): M + static function method Bind(m: M, f:A -> M):M + reads f.reads; + requires forall a :: f.requires(a); + + // return x >>= f = f x + static lemma LeftIdentity(x : A, f : A -> M) + requires forall a :: f.requires(a); + ensures Bind(Return(x),f) == f(x); + + // m >>= return = m + static lemma RightIdentity(m : M) + ensures Bind(m,Return) == m; + + // (m >>= f) >>= g = m >>= (x => f(x) >>= g) + static lemma Associativity(m : M, f:A -> M, g: B -> M) + requires forall a :: f.requires(a); + requires forall b :: g.requires(b); + ensures Bind(Bind(m,f),g) == + Bind(m,x reads f.reads(x) + reads g.reads + requires f.requires(x) + requires forall b :: g.requires(b) => Bind(f(x),g)); +} + +module Identity refines Monad { + datatype M = I(A); + + static function method Return(x: A): M + { I(x) } + + static function method Bind(m: M, f:A -> M):M + { + var I(x) := m; f(x) + } + + static lemma LeftIdentity(x : A, f : A -> M) + { + } + + static lemma RightIdentity(m : M) + { + assert Bind(m,Return) == m; + } + + static lemma Associativity(m : M, f:A -> M, g: B -> M) + { + assert + Bind(Bind(m,f),g) == + Bind(m,x reads f.reads(x) + reads g.reads + requires f.requires(x) + requires forall b :: g.requires(b) => Bind(f(x),g)); + } + +} + +module Maybe refines Monad { + datatype M = Just(A) | Nothing; + + static function method Return(x: A): M + { Just(x) } + + static function method Bind(m: M, f:A -> M):M + { + match m + case Nothing => Nothing + case Just(x) => f(x) + } + + static lemma LeftIdentity(x : A, f : A -> M) + { + } + + static lemma RightIdentity(m : M) + { + assert Bind(m,Return) == m; + } + + static lemma Associativity(m : M, f:A -> M, g: B -> M) + { + assert + Bind(Bind(m,f),g) == + Bind(m,x reads f.reads(x) + reads g.reads + requires f.requires(x) + requires forall b :: g.requires(b) => Bind(f(x),g)); + } + +} + +module List refines Monad { + datatype M = Cons(hd: A,tl: M) | Nil; + + static function method Return(x: A): M + { Cons(x,Nil) } + + static function method Concat(xs: M, ys: M): M + { + match xs + case Nil => ys + case Cons(x,xs) => Cons(x,Concat(xs,ys)) + } + + static function method Join(xss: M>) : M + { + match xss + case Nil => Nil + case Cons(xs,xss) => Concat(xs,Join(xss)) + } + + static function method Map(xs: M, f: A -> B):M + reads f.reads; + requires forall a :: f.requires(a); + { + match xs + case Nil => Nil + case Cons(x,xs) => Cons(f(x),Map(xs,f)) + } + + static function method Bind(m: M, f:A -> M):M + { + Join(Map(m,f)) + } + + static lemma LeftIdentity(x : A, f : A -> M) + { + calc { + Bind(Return(x),f); + == Join(Map(Cons(x,Nil),f)); + == Join(Cons(f(x),Nil)); + == Concat(f(x),Nil); + == { assert forall xs : M :: Concat(xs,Nil) == xs; } + f(x); + } + } + + static lemma RightIdentity(m : M) + { + match m + case Nil => calc { + Bind(Nil,Return); + == Join(Map(Nil,Return)); + == Join(Nil); + == Nil; + == m; + } + case Cons(x,xs) => + calc { + Bind(m,Return); + == Bind(Cons(x,xs),Return); + == Join(Map(Cons(x,xs),Return)); + == Join(Cons(Return(x),Map(xs,Return))); + == Concat(Return(x),Join(Map(xs,Return))); + == { RightIdentity(xs); } + Concat(Return(x),xs); + == Concat(Cons(x,Nil),xs); + == Cons(x,xs); + == m; + } + } + + static lemma ConcatAssociativity(xs : M, ys : M, zs: M) + ensures Concat(Concat(xs,ys),zs) == Concat(xs,Concat(ys,zs)); + {} + + static lemma BindMorphism(xs : M, ys: M, f : A -> M) + requires forall a :: f.requires(a); + ensures Bind(Concat(xs,ys),f) == Concat(Bind(xs,f),Bind(ys,f)); + { + match xs + case Nil => calc { + Bind(Concat(Nil,ys),f); + == Bind(ys,f); + == Concat(Nil,Bind(ys,f)); + == Concat(Bind(Nil,f),Bind(ys,f)); + } + case Cons(z,zs) => calc { + Bind(Concat(xs,ys),f); + == Bind(Concat(Cons(z,zs),ys),f); + == Concat(f(z),Bind(Concat(zs,ys),f)); + == { BindMorphism(zs,ys,f); } + Concat(f(z),Concat(Bind(zs,f),Bind(ys,f))); + == { ConcatAssociativity(f(z),Bind(zs,f),Bind(ys,f)); } + Concat(Concat(f(z),Join(Map(zs,f))),Bind(ys,f)); + == Concat(Bind(Cons(z,zs),f),Bind(ys,f)); + == Concat(Bind(xs,f),Bind(ys,f)); + } + } + + static lemma Associativity(m : M, f:A -> M, g: B -> M) + { + match m + case Nil => calc { + Bind(Bind(m,f),g); + == Bind(Bind(Nil,f),g); + == Bind(Nil,g); + == Nil; + == Bind(Nil,x reads f.reads(x) + reads g.reads + requires f.requires(x) + requires forall b :: g.requires(b) => Bind(f(x),g)); + == Bind(m,x reads f.reads(x) + reads g.reads + requires f.requires(x) + requires forall b :: g.requires(b) => Bind(f(x),g)); + } + case Cons(x,xs) => calc { + Bind(Bind(m,f),g); + == Bind(Bind(Cons(x,xs),f),g); + == Bind(Concat(f(x),Bind(xs,f)),g); + == { BindMorphism(f(x),Bind(xs,f),g); } + Concat(Bind(f(x),g),Bind(Bind(xs,f),g)); + == { Associativity(xs,f,g); } + Concat(Bind(f(x),g),Join(Map(xs,y reads f.reads(y) + reads g.reads + requires f.requires(y) + requires forall b :: g.requires(b) => Bind(f(y),g)))); + == Join(Cons(Bind(f(x),g),Map(xs,y reads f.reads(y) + reads g.reads + requires f.requires(y) + requires forall b :: g.requires(b) => Bind(f(y),g)))); + == Join(Map(Cons(x,xs),y reads f.reads(y) + reads g.reads + requires f.requires(y) + requires forall b :: g.requires(b) => Bind(f(y),g))); + == Bind(Cons(x,xs),y reads f.reads(y) + reads g.reads + requires f.requires(y) + requires forall b :: g.requires(b) => Bind(f(y),g)); + == Bind(m,x reads f.reads(x) + reads g.reads + requires f.requires(x) + requires forall b :: g.requires(b) => Bind(f(x),g)); + } + } +} + diff --git a/Test/hofs/Monads.dfy.expect b/Test/hofs/Monads.dfy.expect new file mode 100644 index 00000000..f5e3b3dc --- /dev/null +++ b/Test/hofs/Monads.dfy.expect @@ -0,0 +1,2 @@ + +Dafny program verifier finished with 36 verified, 0 errors -- cgit v1.2.3