Tags: haskell, time-complexity, lazy-evaluation

Why is this Haskell program so much slower than its expected time complexity?


-- | A finite probability distribution: each outcome paired with its weight.
newtype Prob a = Prob { getProb :: [(a, Rational)] }
  deriving (Show, Eq, Functor)

-- | Collapse a distribution of distributions by scaling every inner
-- weight by the weight of the distribution that contains it.
flatten :: Prob (Prob a) -> Prob a
flatten (Prob outer) =
  Prob [ (x, w * v) | (Prob inner, w) <- outer, (x, v) <- inner ]

-- | Independent combination: all pairs of outcomes, weights multiplied.
instance Applicative Prob where
  pure v = Prob [(v, 1 % 1)]
  liftA2 combine (Prob lefts) (Prob rights) =
    Prob [ (combine l r, pl * pr) | (l, pl) <- lefts, (r, pr) <- rights ]

-- | Bind: apply the continuation to each outcome, then flatten.
instance Monad Prob where
  mx >>= k = flatten (k <$> mx)
-- | A fair die: faces 1..x, each with probability 1/x.
makedie :: Integer -> Prob Integer
makedie x = Prob [ (face, 1 % x) | face <- [1 .. x] ]
-- | Distribution of op applied to two independent draws, equal results merged.
proboperate2 :: Eq c => (a -> b -> c) -> Prob a -> Prob b -> Prob c
proboperate2 op a b = sumprobs (op <$> a <*> b)
-- | Apply op to two independent rolls of the same distribution.
rerolldie :: Eq c => (a -> a -> c) -> Prob a -> Prob c
rerolldie op die = proboperate2 op die die
-- | Merge entries with equal outcomes by summing their probabilities;
-- outcomes appear in order of first occurrence.
sumprobs :: Eq a => Prob a -> Prob a
sumprobs (Prob entries) = Prob (map total (nub (map fst entries)))
  where
    total v = (v, sum [ p | (w, p) <- entries, w == v ])
-- | Distribution of the final verdict: True = survived, False = died.
-- dc is the value a roll must reach to count as a success; die is the
-- distribution rolled at each step.
-- NOTE: this is the slow version: each call to survivegiven rolls the die
-- again via (=<< die) and recurses once per face, so the whole tree of roll
-- sequences is enumerated; sumprobs only merges results after the recursion.
survivedeath :: Integer -> Prob Integer -> Prob Bool 
survivedeath dc die = sumprobs (survivegiven (0,0) =<< die) 
    where 
        -- survivegiven (a,b) n: a = failures so far, b = successes so far,
        -- n = the roll just made. The tally is checked before the roll is
        -- applied, so the order of the equations matters.
        survivegiven :: (Integer,Integer) -> Integer -> Prob Bool
        survivegiven (a,_) _ | a >= 3 = return False 
        survivegiven (_,a) _ | a >= 3 = return True 
        -- a roll of 1 adds two failures
        survivegiven (a,b) 1 = sumprobs ((survivegiven (a+2,b)) =<< die)
        -- a roll of 20 survives immediately
        survivegiven (a,b) 20 = return True 
        -- a roll of at least dc is a success; anything else is a failure
        survivegiven (a,b) n | n >= dc = sumprobs ((survivegiven (a,1+b)) =<< die)
        survivegiven (a,b) n  = sumprobs ((survivegiven (1+a,b)) =<< die)

`survivedeath` quickly starts to take a long time. As far as I can tell, `sumprobs` is O(N^2) and `=<<` is O(N). So if `survivedeath` recurses six times on a d20, it should take about 20^2 × 6 = 400 × 6 = 2400 operations, which should finish quickly. Why is that not the case?


Solution

  • Although sumprobs lets you keep your lists short, you are calling survivegiven for every possible die roll value.

    In other words, if die is a D20,

    survivegiven (a,b) n  = sumprobs ((survivegiven (1+a,b)) =<< die)
    

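    makes 20 recursive calls, one per face of the die. sumprobs merges equal outcomes only after those calls have returned, and calls that reach the same (a, b) state are never shared. So you are still enumerating every sequence of die rolls: up to 20^6 = 64 million of them, though many sequences end early.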

    You can measure the exact number with Debug.Trace: make every call to survivegiven print a line, then count the lines in the output. Here is your survivedeath with only one line changed:

    import Debug.Trace
    
    ...
    
    -- Same as the question's survivedeath except for the trace call below.
    survivedeath :: Integer -> Prob Integer -> Prob Bool 
    survivedeath dc die = sumprobs (survivegiven (0,0) =<< die) 
        where 
            survivegiven :: (Integer,Integer) -> Integer -> Prob Bool
            -- The only change: this first equation matches any arguments, so
            -- its guard is evaluated on every call and prints one "X" line to
            -- stderr per evaluated call of survivegiven.
            survivegiven (a,_) _ | trace "X" $ a >= 3 = return False 
            survivegiven (_,a) _ | a >= 3 = return True 
            survivegiven (a,b) 1 = sumprobs ((survivegiven (a+2,b)) =<< die)
            survivegiven (a,b) 20 = return True 
            survivegiven (a,b) n | n >= dc = sumprobs ((survivegiven (a,1+b)) =<< die)
            survivegiven (a,b) n  = sumprobs ((survivegiven (1+a,b)) =<< die)
    

    You can make your code compilable by adding a main function:

    -- | A fair twenty-sided die.
    d20 :: Prob Integer
    d20 = makedie 20
    
    main :: IO ()
    main = print (survivedeath 6 d20)
    
    $ ghc -O A.hs
    $ ./A 2> log      # run the compiled binary; store its stderr (the trace output) in log
    $ wc -l log       # count lines in log
    8664020 log
    

    There are about 8.6 million calls to survivegiven, each working with a rather inefficient list-based representation of probability distributions, so it is expected to take at least a few seconds.


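    A faster solution is to change the purpose of the inner function so that it computes the distribution over states after n rolls. Each step makes only one recursive call, to get the distribution after n-1 rolls. It then mixes in the next die roll and wraps the result in sumprobs, so that equal states are merged. There are only a handful of distinct states (survived, died, or a and b each at most 2), so every step stays cheap instead of multiplying the work by 20.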

    survivedeath :: Integer -> Prob Integer -> Prob Bool
    survivedeath dc die = outcome <$> surviveafter 6
        where
            -- Extract the final verdict. Every roll either resolves the state
            -- (a 20) or adds at least one to a + b, and an unresolved state has
            -- a, b <= 2, so every path is resolved after at most 5 rolls; 6
            -- leaves a margin. A Right here would be a logic error, so fail
            -- with a clear message instead of a bare non-exhaustive lambda.
            outcome (Left survived) = survived
            outcome (Right _) =
                error "survivedeath: state still unresolved after 6 rolls"

            -- Distribution over survival states after (up to) n rolls:
            -- - Left True (survived)
            -- - Left False (died)
            -- - Right (a, b) ('a' failed throws, 'b' successful throws)
            -- Each step makes a single recursive call and merges equal
            -- states with sumprobs, so the list stays short.
            surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
            surviveafter 0 = pure (Right (0, 0))
            surviveafter n = sumprobs $ do
              state <- surviveafter (n-1)
              case state of
                Left _ -> pure state   -- already resolved; no further rolls
                Right (a, b) -> do
                  roll <- die
                  pure $ case roll of
                    1 -> step (a+2, b)   -- a 1 counts as two failures
                    20 -> Left True      -- a 20 survives outright
                    r | r >= dc -> step (a, 1+b)
                      | otherwise -> step (1+a, b)

            -- Resolve a tally: 3 failures means death, 3 successes survival.
            step (a, b) | a >= 3 = Left False
                        | b >= 3 = Left True
                        | otherwise = Right (a, b)
    

    Compilable file:

    import Data.Ratio ((%))
    import Data.List (nub)
    
    -- | A finite probability distribution: outcomes paired with their weights.
    newtype Prob a = Prob { getProb :: [(a, Rational)] }
      deriving (Show, Eq, Functor)

    -- | Collapse a distribution of distributions, scaling each inner
    -- weight by the weight of the distribution it came from.
    flatten :: Prob (Prob a) -> Prob a
    flatten (Prob outer) =
      Prob [ (x, outerP * innerP) | (Prob inner, outerP) <- outer, (x, innerP) <- inner ]
    -- | Independent combination: every pair of outcomes, probabilities
    -- multiplied. Defined via (<*>) instead of liftA2 because an instance may
    -- only bind class methods whose names are in scope, and liftA2 is only
    -- exported from the Prelude since GHC 9.6 (base-4.18); this file does not
    -- import it. The default liftA2 f x y = fmap f x <*> y yields the same
    -- outcomes, in the same order, as a direct liftA2 definition would.
    instance Applicative Prob where
        pure x = Prob [(x,1%1)]
        Prob fs <*> Prob xs = Prob [(f x, pf * px) | (f, pf) <- fs, (x, px) <- xs]
    -- | Bind: apply the continuation to every outcome, then flatten the
    -- resulting distribution of distributions.
    instance Monad Prob where
        mx >>= k = flatten (k <$> mx)

    -- | A fair die with faces 1..x, each with probability 1 % x.
    makedie :: Integer -> Prob Integer
    makedie x = Prob [(face, 1 % x) | face <- [1 .. x]]
    -- | Distribution of `op x y` for independent draws x from a and y from b,
    -- with equal results merged by sumprobs. Written with <$>/<*> rather than
    -- liftA2: liftA2 is only exported from the Prelude since GHC 9.6, so this
    -- form needs no extra import on older compilers.
    proboperate2 :: Eq c => (a -> b -> c) -> Prob a -> Prob b -> Prob c
    proboperate2 op a b = sumprobs (op <$> a <*> b)
    -- | Combine two independent rolls of the same distribution with op
    -- (equivalent to the reader-applicative form proboperate2 op <*> id).
    rerolldie :: Eq c => (a -> a -> c) -> Prob a -> Prob c
    rerolldie op die = proboperate2 op die die
    -- | Merge entries that share an outcome by adding their probabilities.
    -- Outcomes keep the order of their first occurrence. Only Eq is required,
    -- so this does a linear scan of the entries per distinct outcome.
    sumprobs :: Eq a => Prob a -> Prob a
    sumprobs (Prob entries) = Prob (map total (nub (map fst entries)))
      where
        total v = (v, sum [p | (w, p) <- entries, w == v])
    survivedeath :: Integer -> Prob Integer -> Prob Bool
    survivedeath dc die = outcome <$> surviveafter 6
        where
            -- Extract the final verdict. Every roll either resolves the state
            -- (a 20) or adds at least one to a + b, and an unresolved state has
            -- a, b <= 2, so every path is resolved after at most 5 rolls; 6
            -- leaves a margin. A Right here would be a logic error, so fail
            -- with a clear message instead of a bare non-exhaustive lambda.
            outcome (Left survived) = survived
            outcome (Right _) =
                error "survivedeath: state still unresolved after 6 rolls"

            -- Distribution over survival states after (up to) n rolls:
            -- - Left True (survived)
            -- - Left False (died)
            -- - Right (a, b) ('a' failed throws, 'b' successful throws)
            -- Each step makes a single recursive call and merges equal
            -- states with sumprobs, so the list stays short.
            surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
            surviveafter 0 = pure (Right (0, 0))
            surviveafter n = sumprobs $ do
              state <- surviveafter (n-1)
              case state of
                Left _ -> pure state   -- already resolved; no further rolls
                Right (a, b) -> do
                  roll <- die
                  pure $ case roll of
                    1 -> step (a+2, b)   -- a 1 counts as two failures
                    20 -> Left True      -- a 20 survives outright
                    r | r >= dc -> step (a, 1+b)
                      | otherwise -> step (1+a, b)

            -- Resolve a tally: 3 failures means death, 3 successes survival.
            step (a, b) | a >= 3 = Left False
                        | b >= 3 = Left True
                        | otherwise = Right (a, b)
    -- | A fair twenty-sided die.
    d20 :: Prob Integer
    d20 = makedie 20
    
    main :: IO ()
    main = print (survivedeath 6 d20)