haskellmonadsmonad-transformersquickcheckstate-monad

How to test Monad instance for custom StateT?


I'm learning Monad Transformers, and one of the exercises asks to implement the Monad instance for StateT. I want to test that my implementation admits to the Monad laws using the validity package, which is like the checkers package.

Problem is, my Arbitrary instance doesn't compile. I saw this question, but it doesn't quite do what I want because the test basically duplicates the implementation and doesn't check the laws. There's also this question, but it's unanswered, and I've already figured out how to test Monad Transformers not involving functions (like MaybeT).

{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE InstanceSigs #-}

module Ch11.MonadT (StT (..)) where

import Control.Monad.Trans.State (StateT (..))

newtype StT s m a = StT (s -> m (a, s))
  deriving
    (Functor, Applicative)
    via StateT s m

instance (Monad m) => Monad (StT s m) where
  return :: a -> StT s m a
  return = pure

  (>>=) :: StT s m a -> (a -> StT s m b) -> StT s m b
  StT x >>= f = StT $ \s -> do
    (k, s') <- x s
    let StT y = f k
    y s'

  (>>) :: StT s m a -> StT s m b -> StT s m b
  (>>) = (*>)

My test:

{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeApplications #-}

module Ch11.MonadTSpec (spec) where

import Ch11.MonadT (StT (..))
import Test.Hspec
import Test.QuickCheck
import Test.Validity.Monad

spec :: Spec
spec = do
  monadSpecOnArbitrary @(StTArbit Int [] Int)

-- create wrapper to avoid orphan instance error
newtype StTArbit s m a = StTArbit (StT s m a)
  deriving
    (Functor, Applicative, Monad)

instance (Arbitrary s, Function s, Arbitrary1 m, Arbitrary a) => Arbitrary (StTArbit s m a) where
  arbitrary = do
    f <- arbitrary :: Fun s (m (a, s))
    StTArbit . StT <$> f

Error:

• Couldn't match type: (a0, s0)
                 with: s -> m (a, s)
  Expected: Gen (s -> m (a, s))
    Actual: Gen (a0, s0)
• In the second argument of ‘(<$>)’, namely ‘f’
  In a stmt of a 'do' block: StTArbit . StT <$> f

Solution

  • OP here, this is what I ended up doing.

    -- https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/explicit_forall.html
    {-# LANGUAGE ExplicitForAll #-}
    -- https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/type_applications.html
    {-# LANGUAGE TypeApplications #-}
    
    module Ch11.MonadTSpec (spec) where
    
    import Ch11.MonadT (StT (..), runStT)
    import Data.Function as F
    import Test.Hspec
    import Test.Hspec.QuickCheck
    import Test.QuickCheck
    
    spec :: Spec
    spec = do
      describe "Monad (StT Int [])" $ do
        describe "satisfies Monad laws" $ do
          -- the types are in the same order as in `forall`
          prop "right identity law" (prop_monadRightId @Int @Int @[])
          prop "left identity law" (prop_monadLeftId @Int @Int @Int @[])
          prop "associative law" (prop_monadAssoc @Int @Int @Int @Int @[])
    
    {- HLINT ignore -}
    
    {-
    the types in `forall` are specified in the order of dependency.
    since `m` needs `a` and `s`, those appear before `m` in the list.
    -}
    
    -- (x >>= return) == x
    prop_monadRightId ::
      forall a s m.
      (Monad m, Eq (m (a, s)), Show (m (a, s))) =>
      s ->
      Fun s (m (a, s)) ->
      Property
    prop_monadRightId s f = ((===) `F.on` go) (m >>= return) m
      where
        m = StT $ applyFun f
        go st = runStT st s
    
    -- (return x >>= f) == (f x)
    prop_monadLeftId ::
      forall a b s m.
      (Monad m, Eq (m (b, s)), Show (m (b, s))) =>
      a ->
      s ->
      Fun (a, s) (m (b, s)) ->
      Property
    prop_monadLeftId a s f = ((===) `F.on` go) (return a >>= h) m
      where
        g = applyFun2 f
        m = StT $ g a
        h = StT . g
        go st = runStT st s
    
    -- ((x >>= f) >>= g) == (x >>= (\x' -> f x' >>= g))
    prop_monadAssoc ::
      forall a b c s m.
      (Monad m, Eq (m (b, s)), Show (m (b, s)), Eq (m (c, s)), Show (m (c, s))) =>
      s ->
      Fun s (m (a, s)) ->
      Fun (a, s) (m (b, s)) ->
      Fun (b, s) (m (c, s)) ->
      Property
    prop_monadAssoc s h f g =
      ((===) `F.on` go)
        ((m >>= f') >>= g')
        (m >>= (\x -> f' x >>= g'))
      where
        m = StT $ applyFun h
        f' = StT . applyFun2 f
        g' = StT . applyFun2 g
        go st = runStT st s