Folding Trees in PureScript

Posted on June 10, 2019 by Riccardo

Let’s say we wanted to perform two operations on a tree:

  • count the number of leaves
  • transform it to a list

In this post we will perform both by employing three different strategies:

  • recursive functions
  • using the Foldable typeclass
  • using the State Monad

The Tree Type

Recursive Functions

PureScript is a purely functional programming language and Tree a is a recursive type: recursive functions are a perfect fit.

The functions do what they are supposed to do. However, their shapes are very similar. The only differences between countTreeRec and toListRec are:

  • the initial value passed to the go function (i.e. 0 vs Nil)
  • the calculation in the base case of go (i.e. i + 1 vs xs <> Cons x Nil)
  • the way the recursive case combines the result of the recursive calls (i.e. + vs <>)

What’s described above is exactly what the Foldable typeclass captures. Let’s see how that looks in code.

Using the Foldable Typeclass

The Foldable typeclass captures the idea of “folding” a structure down into a single value, which may itself be another structure, such as a List.

In this case, we could have used foldr or foldl to achieve the same results. But foldMap is a tad more elegant. The way it works is simple:

  • It first runs each element of the tree through the function passed to it (i.e. (\_ -> Additive 1) vs (\x -> Cons x Nil)). That function must map each element of the tree to a value of some Monoid.
  • It combines all of those monoidal values using the <> operator. Since <> is implemented as + for Additive and as list concatenation (append) for List, everything works as before.

Try comparing countTreeFold with countTreeRec, and toListFold with toListRec.

Using the State Monad

The Foldable trick is totally cool. But why not go overboard by implementing and using a State Monad?

-- | A stateful computation: given the current state it yields a result
-- | paired with the next state.
newtype State s a
    = State (s -> Tuple a s)

-- | Unwrap a `State` computation and run it from the given initial state,
-- | returning the result together with the final state.
runState :: forall s a. State s a -> s -> Tuple a s
runState (State step) = step

-- | Map over the result of a stateful computation; the state is threaded
-- | through untouched.
instance functorState :: Functor (State s) where
    -- map :: forall a b. (a -> b) -> f a -> f b
    map fn m = State $ \s ->
        case runState m s of
            Tuple a s' -> Tuple (fn a) s'

-- | Run the function-producing computation first, then the argument
-- | computation on the state it left behind, and apply the function to
-- | the value. The Functor superclass is already satisfied by
-- | `functorState`, so no instance constraint is needed here.
instance applyState :: Apply (State s) where
    -- apply :: forall a b. f (a -> b) -> f a -> f b
    apply fg f = State $ \s ->
        let Tuple g s'  = runState fg s
            Tuple a s'' = runState f s'
        in Tuple (g a) s''

-- | Wrap a plain value as a computation that returns it and leaves the
-- | state unchanged. The Apply superclass is provided by `applyState`,
-- | so no instance constraint is needed.
instance applicativeState :: Applicative (State s) where
    -- pure :: forall a. a -> f a
    pure a = State (\s -> Tuple a s)

-- | Sequence two computations: run `m`, feed its result to `mg`, and run
-- | the resulting computation on the updated state. The Apply superclass
-- | is provided by `applyState`, so no instance constraint is needed.
instance bindState :: Bind (State s) where
    -- bind :: forall a b. m a -> (a -> m b) -> m b
    bind m mg = State $ \s ->
        let Tuple a s' = runState m s
        in runState (mg a) s'

-- | With Applicative and Bind in place, `State s` can be declared a Monad
-- | (the class has no members of its own); this also enables do-notation.
instance monadState :: Monad (State s)

-- | Return the current counter value and store the counter incremented by one.
addOne :: State Int Int
addOne = State $ \n -> Tuple n (n + 1)

-- | Replace every leaf with the counter value current when it is visited
-- | (left subtree first), bumping the counter once per leaf. The final
-- | state is therefore the initial counter plus the number of leaves.
countTreeState :: forall a. Tree a -> State Int (Tree Int)
countTreeState tree = case tree of
    Leaf _   -> Leaf <$> addOne
    Node l r -> Node <$> countTreeState l <*> countTreeState r

-- | Return `x` unchanged and append it to the end of the list kept in the
-- | state. Appending to a linked list walks the whole accumulated list.
appendValue :: forall a. a -> State (List a) a
appendValue x = State $ \acc -> Tuple x (acc <> Cons x Nil)

-- | Rebuild the tree unchanged while appending each leaf value, left
-- | subtree first, to the list held in the state.
toListState :: forall a. Tree a -> State (List a) (Tree a)
toListState tree = case tree of
    Leaf x   -> Leaf <$> appendValue x
    Node l r -> Node <$> toListState l <*> toListState r

-- | Print the final state of both State-based traversals of `exampleTree`.
main :: Effect Unit
main = do
  logShow (snd (runState (countTreeState exampleTree) 0))
  -- 3
  logShow (snd (runState (toListState exampleTree) Nil))
  -- ('a' : 'b' : 'c' : Nil)

I’m gonna cover State in a future post, so stay tuned!

The Whole Code

module Main where

import Prelude
  ( class Applicative, class Apply, class Bind, class Functor, class Monad
  , class Show, Unit, discard, show, ($), (+), (<$>), (<*>), (<>)
  )

import Data.Foldable
import Data.List (List(..), foldMap)
import Data.Monoid.Additive (Additive(..))
import Data.Tuple (Tuple(..), snd)
import Effect (Effect)
import Effect.Console (logShow)

-- | A binary tree that stores values only in its leaves; a `Node` just
-- | joins a left and a right subtree.
data Tree a = Leaf a | Node (Tree a) (Tree a)

-- | Render a tree in constructor syntax with every subtree parenthesised,
-- | e.g. `(Node (Leaf 'a') (Leaf 'b'))`.
instance showTree :: Show a => Show (Tree a) where
    show t = case t of
        Leaf x   -> "(Leaf " <> show x <> ")"
        Node l r -> "(Node " <> show l <> " " <> show r <> ")"

-- | Sample tree with three leaves, 'a', 'b' and 'c' from left to right.
exampleTree :: Tree Char
exampleTree = Node left (Leaf 'c')
    where
    left = Node (Leaf 'a') (Leaf 'b')

-- | Count the leaves of a tree by plain structural recursion.
-- |
-- | `go` is seeded with 0: a leaf yields `i + 1` and a node adds up the
-- | counts of its two subtrees. The seed is passed unchanged to both
-- | subtrees, so it is 0 at every leaf and never accumulates anything.
countTreeRec :: forall a. Tree a -> Int
countTreeRec tree = go 0 tree
    where
    go i t = case t of
        Leaf _   -> i + 1
        Node l r -> go i l + go i r

-- | Collect the leaf values, left to right, into a List by structural
-- | recursion.
-- |
-- | `go` is seeded with `Nil`: a leaf yields `xs <> Cons x Nil` and a node
-- | concatenates the lists of its two subtrees. The seed is passed
-- | unchanged to both subtrees, so it is `Nil` at every leaf.
toListRec :: forall a. Tree a -> List a
toListRec tree = go Nil tree
    where
    go xs t = case t of
        Leaf x   -> xs <> Cons x Nil
        Node l r -> go xs l <> go xs r

-- | Folding a tree visits its leaves from left to right. Only `foldMap` is
-- | written by hand; `foldr` and `foldl` are derived from it.
instance foldableTree :: Foldable Tree where
    -- foldMap :: forall a m. Monoid m => (a -> m) -> f a -> m
    foldMap f t = case t of
        Leaf x   -> f x
        Node l r -> foldMap f l <> foldMap f r

    -- Deliberately eta-expanded (`foldr f = ...`, not `foldr = ...`).
    -- NOTE(review): a point-free member would need this instance's own
    -- dictionary while it is still being built; confirm before simplifying.
    foldr f = foldrDefault f
    foldl f = foldlDefault f

-- | Count the leaves with `foldMap`: every element becomes `Additive 1`
-- | and the Additive monoid sums them; the Int is then unwrapped.
countTreeFold :: forall a. Tree a -> Int
countTreeFold tree =
    case foldMap (\_ -> Additive 1) tree of
        Additive count -> count

-- | Collect the leaf values into a List with `foldMap`: each element
-- | becomes a one-element list and the List monoid concatenates them.
toListFold :: forall a. Tree a -> List a
toListFold tree = foldMap toSingleton tree
    where
    toSingleton x = Cons x Nil

-- | A stateful computation: given the current state it yields a result
-- | paired with the next state.
newtype State s a
    = State (s -> Tuple a s)

-- | Unwrap a `State` computation and run it from the given initial state,
-- | returning the result together with the final state.
runState :: forall s a. State s a -> s -> Tuple a s
runState (State step) = step

-- | Map over the result of a stateful computation; the state is threaded
-- | through untouched.
instance functorState :: Functor (State s) where
    -- map :: forall a b. (a -> b) -> f a -> f b
    map fn m = State $ \s ->
        case runState m s of
            Tuple a s' -> Tuple (fn a) s'

-- | Run the function-producing computation first, then the argument
-- | computation on the state it left behind, and apply the function to
-- | the value. The Functor superclass is already satisfied by
-- | `functorState`, so no instance constraint is needed here.
instance applyState :: Apply (State s) where
    -- apply :: forall a b. f (a -> b) -> f a -> f b
    apply fg f = State $ \s ->
        let Tuple g s'  = runState fg s
            Tuple a s'' = runState f s'
        in Tuple (g a) s''

-- | Wrap a plain value as a computation that returns it and leaves the
-- | state unchanged. The Apply superclass is provided by `applyState`,
-- | so no instance constraint is needed.
instance applicativeState :: Applicative (State s) where
    -- pure :: forall a. a -> f a
    pure a = State (\s -> Tuple a s)

-- | Sequence two computations: run `m`, feed its result to `mg`, and run
-- | the resulting computation on the updated state. The Apply superclass
-- | is provided by `applyState`, so no instance constraint is needed.
instance bindState :: Bind (State s) where
    -- bind :: forall a b. m a -> (a -> m b) -> m b
    bind m mg = State $ \s ->
        let Tuple a s' = runState m s
        in runState (mg a) s'

-- | With Applicative and Bind in place, `State s` can be declared a Monad
-- | (the class has no members of its own); this also enables do-notation.
instance monadState :: Monad (State s)

-- | Return the current counter value and store the counter incremented by one.
addOne :: State Int Int
addOne = State $ \n -> Tuple n (n + 1)

-- | Replace every leaf with the counter value current when it is visited
-- | (left subtree first), bumping the counter once per leaf. The final
-- | state is therefore the initial counter plus the number of leaves.
countTreeState :: forall a. Tree a -> State Int (Tree Int)
countTreeState tree = case tree of
    Leaf _   -> Leaf <$> addOne
    Node l r -> Node <$> countTreeState l <*> countTreeState r

-- | Return `x` unchanged and append it to the end of the list kept in the
-- | state. Appending to a linked list walks the whole accumulated list.
appendValue :: forall a. a -> State (List a) a
appendValue x = State $ \acc -> Tuple x (acc <> Cons x Nil)

-- | Rebuild the tree unchanged while appending each leaf value, left
-- | subtree first, to the list held in the state.
toListState :: forall a. Tree a -> State (List a) (Tree a)
toListState tree = case tree of
    Leaf x   -> Leaf <$> appendValue x
    Node l r -> Node <$> toListState l <*> toListState r

-- | Print the example tree, then its leaf count and leaf list as computed
-- | by each of the three strategies (recursion, Foldable, State).
main :: Effect Unit
main = do
  logShow exampleTree
  -- (Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c'))
  logShow (countTreeRec exampleTree)
  -- 3
  logShow (toListRec exampleTree)
  -- ('a' : 'b' : 'c' : Nil)
  logShow (countTreeFold exampleTree)
  -- 3
  logShow (toListFold exampleTree)
  -- ('a' : 'b' : 'c' : Nil)
  logShow (snd (runState (countTreeState exampleTree) 0))
  -- 3
  logShow (snd (runState (toListState exampleTree) Nil))
  -- ('a' : 'b' : 'c' : Nil)