Monadic Caching Folds

One of my colleagues (Roland Zumkeller) posted some nifty functions to count the number of expressions in an AST for the DSL we work on. This led to an email and chat discussion that I have summarised in this post. Any errors are entirely mine.

Let’s start off with our target language. It’s easy to write an interpreter for it and also a function to count its nodes.

> {-# LANGUAGE
>     DeriveFunctor,
>     DeriveFoldable,
>     DeriveTraversable,
>     RankNTypes,
>     FlexibleContexts,
>     NoMonomorphismRestriction,
>     UndecidableInstances #-}
> 
> import Prelude hiding (mapM)
> import Data.Foldable
> import Data.Traversable
> import Data.Monoid
> import Control.Monad.Writer hiding (mapM)
> import Control.Monad.State hiding (mapM)
> import Control.Monad ((<=<))  -- explicit: newer mtl releases do not re-export Control.Monad
> import qualified Data.Map as Map
> 
> data Term = Plus Term Term
>           | Mult Term Term
>           | IntConst Int
>           deriving Show
> 
> interp :: Term -> Int
> interp (Plus x y)   = interp x + interp y
> interp (Mult x y)   = interp x * interp y
> interp (IntConst x) = x
> 
> -- Count every node: each Plus/Mult node counts itself plus its subterms.
> numTerms :: Term -> Int
> numTerms (Plus x y)   = 1 + numTerms x + numTerms y
> numTerms (Mult x y)   = 1 + numTerms x + numTerms y
> numTerms (IntConst _) = 1

Obviously there is a pattern here and we can abstract it by using a catamorphism.

> data TermF a = PlusF a a
>              | MultF a a
>              | IntConstF Int
>              deriving (Show, Ord, Eq, Functor, Foldable, Traversable)
> 
> newtype Mu f = In {in_ :: f (Mu f)}
> 
> instance Eq  (f (Mu f)) => Eq  (Mu f) where
>   In x == In y = x == y
> instance Ord (f (Mu f)) => Ord (Mu f) where
>   In x `compare` In y = x `compare` y
> 
> type Term' = Mu TermF
> 
> type Algebra f a = f a -> a
> 
> cata :: Functor f => Algebra f a -> (Mu f -> a)
> cata f (In x) = f (fmap (cata f) x)
> 
> countAlgebra :: Num a => Algebra TermF (Sum a)
> countAlgebra = \t -> fold t `mappend` (Sum 1)
> 
> countTerms = getSum . cata countAlgebra
> 
> testTerm :: Term'
> testTerm = let x = (In $ MultF (In $ IntConstF 2) (In $ IntConstF 3)) 
>                y = (In $ MultF (In $ IntConstF 5) (In $ IntConstF 7)) 
>            in In $ PlusF x y
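
To see that the same machinery also recovers the interpreter, here is a minimal sketch. The names `evalAlgebra`, `interp'` and `example` are assumptions introduced for illustration; the other definitions are repeated from above so that the block compiles on its own.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Definitions repeated from the post so this sketch is self-contained.
data TermF a = PlusF a a
             | MultF a a
             | IntConstF Int
             deriving Functor

newtype Mu f = In { in_ :: f (Mu f) }

type Algebra f a = f a -> a

cata :: Functor f => Algebra f a -> Mu f -> a
cata f (In x) = f (fmap (cata f) x)

-- Hypothetical name: one non-recursive step of evaluation.
-- The recursion itself is supplied by cata.
evalAlgebra :: Algebra TermF Int
evalAlgebra (PlusF x y)   = x + y
evalAlgebra (MultF x y)   = x * y
evalAlgebra (IntConstF n) = n

-- Hypothetical name: the interpreter as a catamorphism.
interp' :: Mu TermF -> Int
interp' = cata evalAlgebra

-- (2 * 3) + (5 * 7), the same shape as testTerm.
example :: Mu TermF
example = In (PlusF (In (MultF (In (IntConstF 2)) (In (IntConstF 3))))
                    (In (MultF (In (IntConstF 5)) (In (IntConstF 7)))))
```

Evaluating `interp' example` gives 41. The algebra only says what to do at one node once the children have already been evaluated.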

But perhaps we would like a monadic catamorphism; in other words we would like to work with algebras in the Kleisli category (of the given monad).

Perhaps Mendler-style folds will help. We can capture the naturality condition in the definition of a Mendler algebra by using a rank-n type.

> type MendlerAlgebra f c = forall a. (a -> c) -> f a -> c
> 
> mcata :: MendlerAlgebra f c -> (Mu f -> c)
> mcata phi = phi (mcata phi) . in_

Ordinary folds and Mendler-style folds are equivalent, as the two conversions below show. One of the advantages of Mendler-style folds is that they can be generalized to adjoint folds.

> toMCata :: Functor f => Algebra f a -> MendlerAlgebra f a
> toMCata phi = \gamma -> phi . fmap gamma
> 
> fromMCata :: MendlerAlgebra f a -> Algebra f a
> fromMCata phi = phi id
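
As a quick sanity check of this equivalence, the sketch below counts nodes three ways and expects the same answer each time. The names `countM`, `viaCata`, `viaConverted`, `viaMendler`, `viaBack` and `example` are assumptions made for this illustration; everything else is repeated from the post.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, RankNTypes #-}

import Data.Foldable (fold)
import Data.Monoid (Sum(..))

data TermF a = PlusF a a | MultF a a | IntConstF Int
  deriving (Functor, Foldable)

newtype Mu f = In { in_ :: f (Mu f) }

type Algebra f a = f a -> a
type MendlerAlgebra f c = forall a. (a -> c) -> f a -> c

cata :: Functor f => Algebra f a -> Mu f -> a
cata f (In x) = f (fmap (cata f) x)

mcata :: MendlerAlgebra f c -> Mu f -> c
mcata phi = phi (mcata phi) . in_

toMCata :: Functor f => Algebra f a -> MendlerAlgebra f a
toMCata phi = \gamma -> phi . fmap gamma

fromMCata :: MendlerAlgebra f a -> Algebra f a
fromMCata phi = phi id

countAlgebra :: Algebra TermF (Sum Int)
countAlgebra t = fold t <> Sum 1

-- Hypothetical name: the same count written directly in Mendler style.
-- The recursive call is passed in as `rec`; no Functor instance is used.
countM :: MendlerAlgebra TermF (Sum Int)
countM rec t = foldMap rec t <> Sum 1

-- (2 * 3) + (5 * 7)
example :: Mu TermF
example = In (PlusF (In (MultF (In (IntConstF 2)) (In (IntConstF 3))))
                    (In (MultF (In (IntConstF 5)) (In (IntConstF 7)))))

viaCata, viaConverted, viaMendler, viaBack :: Int
viaCata      = getSum (cata countAlgebra example)
viaConverted = getSum (mcata (toMCata countAlgebra) example)
viaMendler   = getSum (mcata countM example)
viaBack      = getSum (cata (fromMCata countM) example)
```

All four values are 7 for this term.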

And now I can create a "monadic" algebra over which I can fold (Mendler style).

> countTermsMAlgebra :: (MonadWriter (Sum t) m, Num t) =>
>                       MendlerAlgebra TermF (m ())
> countTermsMAlgebra h (PlusF x y)   = h x >> h y >> tell (Sum 1)
> countTermsMAlgebra h (MultF x y)   = h x >> h y >> tell (Sum 1)
> countTermsMAlgebra h (IntConstF _) = tell (Sum 1)
> 
> countTerms' :: Num a => Term' -> a
> countTerms' t = getSum $ execWriter $ mcata countTermsMAlgebra t

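Sadly, this is still not an algebra in the Kleisli category, as becomes obvious once we use the isomorphism to turn the Mendler-style algebra back into an ordinary one: the result has type TermF (m ()) -> m (), an ordinary algebra whose carrier happens to be m (), rather than a Kleisli arrow of type TermF a -> m a.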

> countTermsAlgebra ::  (MonadWriter (Sum t) m, Num t) =>
>                       Algebra TermF (m ())
> countTermsAlgebra = fromMCata countTermsMAlgebra

However, if we assume a little more about the base functor (namely that it is Traversable), then we can do this. It’s not clear (to me at least) what this requirement is in categorical terms.

> type MonadicAlgebra f m a = f a -> m a
> 
> cataM :: (Traversable f, Monad m) =>
>          MonadicAlgebra f m a -> (Mu f -> m a)
> cataM f = f <=< mapM (cataM f) . in_
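
To make the Kleisli flavour concrete, here is a small sketch that uses `Either String` as the monad. The algebra `boundedEval` and its `limit` parameter are hypothetical, invented only for this illustration: evaluation fails as soon as an intermediate result exceeds the limit.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

import Control.Monad ((<=<))

data TermF a = PlusF a a | MultF a a | IntConstF Int
  deriving (Functor, Foldable, Traversable)

newtype Mu f = In { in_ :: f (Mu f) }

type MonadicAlgebra f m a = f a -> m a

cataM :: (Traversable f, Monad m) => MonadicAlgebra f m a -> Mu f -> m a
cataM f = f <=< mapM (cataM f) . in_

-- Hypothetical algebra: evaluate one node, failing if the result
-- is larger in magnitude than `limit`.
boundedEval :: Int -> MonadicAlgebra TermF (Either String) Int
boundedEval limit t = check (step t)
  where
    step (PlusF x y)   = x + y
    step (MultF x y)   = x * y
    step (IntConstF n) = n
    check v
      | abs v > limit = Left ("result " ++ show v ++ " exceeds " ++ show limit)
      | otherwise     = Right v

-- (2 * 3) + (5 * 7)
example :: Mu TermF
example = In (PlusF (In (MultF (In (IntConstF 2)) (In (IntConstF 3))))
                    (In (MultF (In (IntConstF 5)) (In (IntConstF 7)))))
```

With a limit of 100, `cataM (boundedEval 100) example` is `Right 41`. With a limit of 10 it is `Left "result 35 exceeds 10"`: the children are traversed left to right, so the left product 6 passes and the right product 35 is the first result to fail.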

One can go further and memoise monadic catamorphisms.

> type CacheT a b = StateT (Map.Map a b)
> 
> memoise :: (Monad m, Ord a) =>
>            ((a -> CacheT a b m b) -> (a -> CacheT a b m b)) ->
>            (a -> CacheT a b m b)
> memoise f x = gets (Map.lookup x) >>= (`maybe` return)
>   (do y <- f (memoise f) x; modify (Map.insert x y); return y)
> 
> cataMemoM :: (Traversable f, Monad m, Ord (Mu f)) =>
>              MonadicAlgebra f m a -> (Mu f -> m a)
> cataMemoM f = (`evalStateT` Map.empty)
>   . memoise (\cataM_f -> lift . f <=< mapM cataM_f . in_)
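
Since `memoise` does not depend on folds at all, it may help to see it on a simpler open-recursive function first. The sketch below spells `memoise` out with do-notation (the intent is the same behaviour as the definition above) and applies it to Fibonacci. The names `fibOpen`, `fibMemo` and `fibCacheSize` are assumptions made for this example.

```haskell
import Control.Monad.State (StateT, evalState, execState, gets, modify)
import qualified Data.Map as Map

type CacheT a b = StateT (Map.Map a b)

-- memoise written out step by step: look in the cache first and only
-- run the wrapped function on a miss, storing the result afterwards.
memoise :: (Monad m, Ord a)
        => ((a -> CacheT a b m b) -> a -> CacheT a b m b)
        -> a -> CacheT a b m b
memoise f x = do
  cached <- gets (Map.lookup x)
  case cached of
    Just y  -> return y
    Nothing -> do
      y <- f (memoise f) x
      modify (Map.insert x y)
      return y

-- Hypothetical example: Fibonacci in open-recursive form. Recursive
-- calls go through `self`, which memoise ties back to the cached version.
fibOpen :: Monad m
        => (Integer -> CacheT Integer Integer m Integer)
        -> Integer -> CacheT Integer Integer m Integer
fibOpen _    0 = return 0
fibOpen _    1 = return 1
fibOpen self n = (+) <$> self (n - 1) <*> self (n - 2)

fibMemo :: Integer -> Integer
fibMemo n = evalState (memoise fibOpen n) Map.empty

-- Number of distinct arguments that ended up in the cache.
fibCacheSize :: Integer -> Int
fibCacheSize n = Map.size (execState (memoise fibOpen n) Map.empty)
```

Here `fibMemo 10` is 55 and `fibCacheSize 10` is 11, one cache entry for each argument from 0 to 10. In `cataMemoM` the argument is a term rather than a number, which is why `Ord (Mu f)` is required.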

And now we can count unique nodes in our DSL.

> countTermsMonadicAlgebra :: MonadicAlgebra TermF (Writer (Sum Integer)) ()
> countTermsMonadicAlgebra = const $ tell $ Sum 1
> 
> countTermsSharedM :: Mu TermF -> Integer
> countTermsSharedM = getSum .
>                     execWriter .
>                     cataMemoM countTermsMonadicAlgebra

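Caching is what prevents shared terms from being counted more than once: the monadic action (tell) is not repeated when a term that has already been seen is encountered again. Note that the cache is keyed using Ord (Mu f), so “the same term” means structurally equal, whether or not the two occurrences are physically shared in memory.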
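
As an independent cross-check (this is not the memoised fold itself, just a different way to arrive at the same number), we can collect the distinct subterms in a `Data.Set` and take its size. The names `subterms`, `countUnique`, `shared` and `overlapping` are assumptions made for this sketch.

```haskell
{-# LANGUAGE DeriveFoldable, FlexibleContexts, UndecidableInstances #-}

import qualified Data.Set as Set

data TermF a = PlusF a a | MultF a a | IntConstF Int
  deriving (Eq, Ord, Foldable)

newtype Mu f = In { in_ :: f (Mu f) }

instance Eq (f (Mu f)) => Eq (Mu f) where
  In x == In y = x == y
instance Ord (f (Mu f)) => Ord (Mu f) where
  In x `compare` In y = x `compare` y

-- Hypothetical helper: the set of all structurally distinct subterms,
-- including the term itself.
subterms :: Mu TermF -> Set.Set (Mu TermF)
subterms t@(In x) = Set.insert t (foldMap subterms x)

countUnique :: Mu TermF -> Int
countUnique = Set.size . subterms

-- (2 * 3) + (2 * 3), with the product written only once.
shared :: Mu TermF
shared = let x = In (MultF (In (IntConstF 2)) (In (IntConstF 3)))
         in In (PlusF x x)

-- (2 * 3) + (2 + 3): nothing is shared via let, but the leaves repeat.
overlapping :: Mu TermF
overlapping = In (PlusF (In (MultF (In (IntConstF 2)) (In (IntConstF 3))))
                        (In (PlusF (In (IntConstF 2)) (In (IntConstF 3)))))
```

Here `countUnique shared` is 4 and `countUnique overlapping` is 5: in the second term the leaves 2 and 3 each occur twice but are counted once, even though they were written out separately.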

Here’s an example of our program counting the unique nodes in an expression in our language.

> t = let x = In $ MultF (In $ IntConstF 2) (In $ IntConstF 3)
>     in In $ PlusF x x

*Main> countTermsSharedM t
4
*Main> countTerms t
7