One of my colleagues (Roland Zumkeller) posted some nifty functions to count the number of expressions in an AST for the DSL we work on. This led to an email and chat discussion that I have summarised in this post. Any errors are entirely mine.

Let’s start off with our target language. It’s easy to write an interpreter and also a function to count nodes in it.

> {-# LANGUAGE
>     DeriveFunctor,
>     DeriveFoldable,
>     DeriveTraversable,
>     RankNTypes,
>     FlexibleContexts,
>     NoMonomorphismRestriction,
>     UndecidableInstances #-}
>
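> import Prelude hiding (mapM)
> import Control.Monad ((<=<))
> import Control.Monad.State (StateT, evalStateT, gets, modify)
> import Control.Monad.Trans.Class (lift)
> import Control.Monad.Writer (MonadWriter, Writer, execWriter, tell)
> import Data.Foldable
> import Data.Traversable
> import Data.Monoid
> import qualified Data.Map as Map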
>
> data Term = Plus Term Term
>           | Mult Term Term
>           | IntConst Int
>           deriving Show
>
> interp :: Term -> Int
> interp (Plus x y)   = interp x + interp y
> interp (Mult x y)   = interp x * interp y
> interp (IntConst x) = x
>
> numTerms :: Term -> Int
> numTerms (Plus x y)   = 1 + numTerms x + numTerms y
> numTerms (Mult x y)   = 1 + numTerms x + numTerms y
> numTerms (IntConst _) = 1


Obviously there is a pattern here and we can abstract it by using a catamorphism.
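
Before generalising, the shared shape can be made visible with a hand-written fold over `Term`. This is only a standalone sketch, separate from the literate module in this post; the names `foldTerm`, `evalTerm`, `sizeTerm` and `example` are illustrative and do not come from the original discussion.

```haskell
-- Standalone sketch: a fold specialised to Term.
-- foldTerm, evalTerm, sizeTerm and example are hypothetical names.
data Term = Plus Term Term
          | Mult Term Term
          | IntConst Int
          deriving Show

-- One argument per constructor: each says how to combine the
-- results already computed for the subterms.
foldTerm :: (a -> a -> a) -> (a -> a -> a) -> (Int -> a) -> Term -> a
foldTerm plus mult con = go
  where
    go (Plus x y)   = plus (go x) (go y)
    go (Mult x y)   = mult (go x) (go y)
    go (IntConst n) = con n

-- The interpreter: use the arithmetic each constructor stands for.
evalTerm :: Term -> Int
evalTerm = foldTerm (+) (*) id

-- A node counter: every constructor contributes one node.
sizeTerm :: Term -> Int
sizeTerm = foldTerm (\l r -> 1 + l + r) (\l r -> 1 + l + r) (const 1)

example :: Term
example = Plus (Mult (IntConst 2) (IntConst 3)) (IntConst 4)
```

Both functions differ only in the three arguments passed to `foldTerm`; the catamorphism below packages those arguments as a single algebra and works for any functor, not just this one.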

> data TermF a = PlusF a a
>              | MultF a a
>              | IntConstF Int
>              deriving (Show, Ord, Eq, Functor, Foldable, Traversable)
>
> newtype Mu f = In {in_ :: f (Mu f)}
>
> instance Eq  (f (Mu f)) => Eq  (Mu f) where
>   In x == In y = x == y
> instance Ord (f (Mu f)) => Ord (Mu f) where
>   compare (In x) (In y) = compare x y
>
> type Term' = Mu TermF
>
> type Algebra f a = f a -> a
>
> cata :: Functor f => Algebra f a -> (Mu f -> a)
> cata f (In x) = f (fmap (cata f) x)
>
> countAlgebra :: Num a => Algebra TermF (Sum a)
> countAlgebra = \t -> fold t `mappend` Sum 1
>
> countTerms = getSum . cata countAlgebra
>
> testTerm :: Term'
> testTerm = let x = In $ MultF (In $ IntConstF 2) (In $ IntConstF 3)
>                y = In $ MultF (In $ IntConstF 5) (In $ IntConstF 7)
>            in In $ PlusF x y

But perhaps we would like a monadic catamorphism; in other words we would like to work with algebras in the Kleisli category (of the given monad). Perhaps Mendler-style folds will help. We can capture the naturality condition in the definition of a Mendler algebra by using a rank-n type.

> type MendlerAlgebra f c = forall a. (a -> c) -> f a -> c
>
> mcata :: MendlerAlgebra f c -> (Mu f -> c)
> mcata phi = phi (mcata phi) . in_

Ordinary folds and Mendler-style folds are equivalent. One of the advantages of using Mendler-style folds is that they can be generalized to adjoint folds.

> toMCata :: Functor f => Algebra f a -> MendlerAlgebra f a
> toMCata phi = \gamma -> phi . fmap gamma
>
> fromMCata :: MendlerAlgebra f a -> Algebra f a
> fromMCata phi = phi id

And now I can create a "monadic" algebra over which I can fold (Mendler style).

> countTermsMAlgebra :: (MonadWriter (Sum t) m, Num t) =>
>                       MendlerAlgebra TermF (m ())
> countTermsMAlgebra h (PlusF x y)   = h x >> h y >> tell (Sum 1)
> countTermsMAlgebra h (MultF x y)   = h x >> h y >> tell (Sum 1)
> countTermsMAlgebra h (IntConstF _) = tell (Sum 1)
>
> countTerms' :: Num a => Term' -> a
> countTerms' t = getSum $ execWriter $ mcata countTermsMAlgebra t

Sadly we are still not working with an algebra in the Kleisli category. This becomes obvious if we apply the isomorphism from Mendler-style algebras back to ordinary ones: the result is an ordinary algebra whose carrier happens to be the monadic value m (), not an arrow of the form f a -> m a.

> countTermsAlgebra :: (MonadWriter (Sum t) m, Num t) =>
>                      Algebra TermF (m ())
> countTermsAlgebra = fromMCata countTermsMAlgebra

However, if you make one more assumption about the algebra functor (that it is Traversable) then one can do this. It’s not clear (to me at least) what this requirement is in categorical terms.

> type MonadicAlgebra f m a = f a -> m a
>
> cataM :: (Traversable f, Monad m) =>
>          MonadicAlgebra f m a -> (Mu f -> m a)
> cataM f = f <=< mapM (cataM f) . in_

One can go further and memoise monadic catamorphisms.

> type CacheT a b = StateT (Map.Map a b)
>
> memoise :: (Monad m, Ord a) =>
>            ((a -> CacheT a b m b) -> (a -> CacheT a b m b)) ->
>            (a -> CacheT a b m b)
> memoise f x = gets (Map.lookup x) >>= maybe compute return
>   where
>     -- Cache miss: run the step once and remember its result.
>     compute = do y <- f (memoise f) x
>                  modify (Map.insert x y)
>                  return y
>
> cataMemoM :: (Traversable f, Monad m, Ord (Mu f)) =>
>              MonadicAlgebra f m a -> (Mu f -> m a)
> cataMemoM f = (`evalStateT` Map.empty)
>             . memoise (\cataM_f -> lift . f <=< mapM cataM_f . in_)

And now we can count unique nodes in our DSL.

> countTermsMonadicAlgebra :: MonadicAlgebra TermF (Writer (Sum Integer)) ()
> countTermsMonadicAlgebra = const $ tell $ Sum 1
>
> countTermsSharedM :: Mu TermF -> Integer
> countTermsSharedM = getSum .
>                     execWriter .
>                     cataMemoM countTermsMonadicAlgebra

You need caching to avoid counting the shared terms. The point is that the monadic action (tell) should not be repeated when the same term is encountered again. Note that "the same term" here means structurally equal under the Ord instance for Mu f, since that is what the cache is keyed on. Here’s an example of our program counting the unique nodes in an expression in our language.

> t = let x = In $ MultF (In $ IntConstF 2) (In $ IntConstF 3)
>     in In $ PlusF x x

*Main> countTermsSharedM t
4
*Main> countTerms t
7
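
To see the effect of the cache in isolation, the sketch below runs the same counting algebra through a plain monadic catamorphism and through a memoised one. It is a standalone module that restates the definitions it needs rather than importing this post, and it assumes the `mtl`, `transformers` and `containers` packages are available; `countNode`, `shared`, `countAll` and `countUnique` are illustrative names.

```haskell
{-# LANGUAGE DeriveTraversable, FlexibleContexts, UndecidableInstances #-}

import Control.Monad ((<=<))
import Control.Monad.State (StateT, evalStateT, gets, modify)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Writer (Writer, execWriter, tell)
import Data.Monoid (Sum (..))
import qualified Data.Map as Map

data TermF a = PlusF a a | MultF a a | IntConstF Int
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)

newtype Mu f = In { in_ :: f (Mu f) }

instance Eq (f (Mu f)) => Eq (Mu f) where
  In x == In y = x == y
instance Ord (f (Mu f)) => Ord (Mu f) where
  compare (In x) (In y) = compare x y

-- Plain monadic catamorphism: the algebra runs at every node visited.
cataM :: (Traversable f, Monad m) => (f a -> m a) -> Mu f -> m a
cataM alg = alg <=< mapM (cataM alg) . in_

type CacheT a b = StateT (Map.Map a b)

-- Look the argument up in the cache; only on a miss run the step
-- and store its result.
memoise :: (Monad m, Ord a)
        => ((a -> CacheT a b m b) -> (a -> CacheT a b m b))
        -> (a -> CacheT a b m b)
memoise f x = gets (Map.lookup x) >>= maybe compute return
  where
    compute = do
      y <- f (memoise f) x
      modify (Map.insert x y)
      return y

-- Memoised variant: the algebra runs once per structurally distinct subterm.
cataMemoM :: (Traversable f, Monad m, Ord (Mu f)) => (f a -> m a) -> Mu f -> m a
cataMemoM alg =
  flip evalStateT Map.empty
    . memoise (\self -> lift . alg <=< mapM self . in_)

-- Hypothetical counting algebra: emit one tick per algebra invocation.
countNode :: TermF () -> Writer (Sum Integer) ()
countNode _ = tell (Sum 1)

-- (2 * 3) + (2 * 3), with the product appearing twice.
shared :: Mu TermF
shared = let x = In (MultF (In (IntConstF 2)) (In (IntConstF 3)))
         in In (PlusF x x)

countAll, countUnique :: Mu TermF -> Integer
countAll    = getSum . execWriter . cataM countNode
countUnique = getSum . execWriter . cataMemoM countNode
```

On `shared`, `countAll` ticks for all 7 nodes while `countUnique` ticks only for the 4 distinct subterms, which matches the session above. The design choice worth noting is that the cache lives in a `StateT` layer wrapped around the caller's monad, so the algebra itself stays unaware of memoisation.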