-- | A finite probability distribution: a list of outcomes, each paired with
-- its probability. The same outcome may appear more than once.
newtype Prob a = Prob {getProb :: [(a, Rational)]}
  deriving (Show, Eq, Functor)
-- | Collapse a distribution over distributions into a single distribution.
-- Each inner outcome's probability is multiplied by the probability of the
-- inner distribution it came from. Duplicate outcomes are not merged.
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob (concatMap multAll xs)
  where
    -- Scale every entry of one inner distribution by its outer weight p.
    multAll (Prob innerxs, p) = [(x, p * r) | (x, r) <- innerxs]
-- | Independent combination: every pair of outcomes is produced, and their
-- probabilities are multiplied.
--
-- Defines the minimal method (<*>) rather than liftA2. Defining liftA2 in an
-- instance requires the name to be in scope, and Prelude only exports it
-- from base-4.18 (GHC 9.6). The default liftA2 f x y = fmap f x <*> y yields
-- exactly the same list, in the same order.
instance Applicative Prob where
  pure x = Prob [(x, 1)]
  Prob fs <*> Prob xs =
    Prob [(f x, pf * px) | (f, pf) <- fs, (x, px) <- xs]
-- | Dependent sequencing: run the continuation on every outcome, and weight
-- each resulting entry by the probability of the outcome that produced it.
instance Monad Prob where
  Prob outcomes >>= k =
    Prob [(y, p * q) | (x, p) <- outcomes, (y, q) <- getProb (k x)]
-- | A fair die with faces 1..x, each with probability 1 % x.
-- For x <= 0 the face list is empty, so the distribution is empty.
makedie :: Integer -> Prob Integer
makedie x = Prob [(face, 1 % x) | face <- [1 .. x]]
-- | Combine two independent distributions with a binary operator, then merge
-- entries whose outcomes are equal.
proboperate2 :: Eq c => (a -> b -> c) -> Prob a -> Prob b -> Prob c
proboperate2 op a b = sumprobs (liftA2 op a b)
-- | Roll the same distribution twice independently and combine the two
-- results with op. This is equivalent to the original point-free
-- `proboperate2 op <*> id`, written out explicitly.
rerolldie :: Eq c => (a -> a -> c) -> Prob a -> Prob c
rerolldie op d = proboperate2 op d d
-- | Merge entries with equal outcomes by adding their probabilities.
-- Outcomes keep the order of their first occurrence. The cost is quadratic
-- in the number of entries (nub plus one scan per distinct value), so keep
-- the input small.
sumprobs :: Eq a => Prob a -> Prob a
sumprobs (Prob a) =
  Prob [(v, sum [p | (w, p) <- a, w == v]) | v <- indivalues]
  where
    indivalues = nub (map fst a)
-- | Distribution of the final result (True = success, False = failure) of
-- repeatedly rolling die. A roll >= dc is a success and any other roll is a
-- failure; a 1 counts as two failures and a 20 succeeds immediately. Three
-- failures or three successes end the sequence.
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = sumprobs (survivegiven (0, 0) =<< die)
  where
    -- State (a, b) = (failures, successes) so far, plus the roll just made.
    -- A finished state is only noticed on the next call, after one more roll
    -- that is ignored. Every unfinished step branches once per face of die,
    -- so this enumerates whole roll sequences.
    survivegiven :: (Integer, Integer) -> Integer -> Prob Bool
    survivegiven (a, _) _ | a >= 3 = return False
    survivegiven (_, b) _ | b >= 3 = return True
    survivegiven (a, b) 1 = sumprobs (survivegiven (a + 2, b) =<< die)
    survivegiven _ 20 = return True
    survivegiven (a, b) n
      | n >= dc = sumprobs (survivegiven (a, b + 1) =<< die)
      | otherwise = sumprobs (survivegiven (a + 1, b) =<< die)
survivedeath starts to take a long time very quickly. As far as I can tell, sumprobs is O(N^2) and =<< is O(N). So if survivedeath goes six times on a d20, it should run 20^2 * 6 = 400 * 6 = 2400 operations, which should happen quickly. Why is that not the case?
Although sumprobs lets you keep your lists short, you are still calling survivegiven once for every possible die roll value.
In other words, if die is a D20, the clause
survivegiven (a,b) n = sumprobs ((survivegiven (1+a,b)) =<< die)  -- =<< die calls survivegiven once per face of die
makes 20 recursive calls. So you are still enumerating every sequence of dice rolls, and there are millions of them (somewhat fewer than 20^6).
You can easily measure the exact number with Debug.Trace, by making every call to survivegiven print a line and then counting the lines in the output. Here is your survivedeath with only one line changed:
import Debug.Trace
-- | Same computation as the untraced survivedeath, but every call to
-- survivegiven first emits the line "X" via trace. Counting those lines
-- therefore counts the recursive calls.
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = sumprobs (survivegiven (0, 0) =<< die)
  where
    -- State (a, b) = (failures, successes) so far, plus the roll just made.
    survivegiven :: (Integer, Integer) -> Integer -> Prob Bool
    -- The trace sits in the first guard, which is evaluated on every call.
    survivegiven (a, _) _ | trace "X" $ a >= 3 = return False
    survivegiven (_, b) _ | b >= 3 = return True
    survivegiven (a, b) 1 = sumprobs (survivegiven (a + 2, b) =<< die)
    survivegiven _ 20 = return True
    survivegiven (a, b) n
      | n >= dc = sumprobs (survivegiven (a, b + 1) =<< die)
      | otherwise = sumprobs (survivegiven (a + 1, b) =<< die)
You can make your code compilable by adding a main:
-- | A fair twenty-sided die.
d20 :: Prob Integer
d20 = makedie 20

main :: IO ()
main = print (survivedeath 6 d20)
$ ghc -O A.hs
$ ./A 2> log # run the compiled executable, storing its stderr output in log
$ wc -l log # count lines in log
8664020 log
There are 8.6 million recursive calls, each working with a rather inefficient representation of probability distributions, so it's expected that this takes at least a few seconds.
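A faster solution is to change the purpose of the inner function so that it computes the state space after n rolls. Then you only make one recursive call at a time, to get the state space after n-1 rolls. Mix in the next die roll, and wrap the result at every step in sumprobs. The distribution then never holds more than about a dozen distinct states, instead of one entry per roll sequence.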
-- | Distribution of the final result (True = survived, False = died),
-- computed by evolving the distribution of survival states one roll at a
-- time and merging equal states after every roll.
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = fromState <$> surviveafter 6
  where
    -- Every roll either decides the outcome or raises a + b by at least 1.
    -- An undecided state has a < 3 and b < 3, so a + b <= 4. Hence no state
    -- is still undecided after 6 rolls, and the Right case is unreachable.
    fromState :: Either Bool (Integer, Integer) -> Bool
    fromState =
      either id (const (error "survivedeath: state still undecided after 6 rolls"))

    -- Survival states after (up to) n rolls:
    -- - Left True (survived)
    -- - Left False (died)
    -- - Right (a, b) ('a' failed throws, 'b' successful throws)
    surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
    surviveafter 0 = pure (Right (0, 0))
    surviveafter n = sumprobs $ do
      state <- surviveafter (n - 1)
      case state of
        -- A decided state stays decided; no further roll is made.
        Left _ -> pure state
        Right (a, b) -> do
          roll <- die
          pure $ case roll of
            1 -> step (a + 2, b)
            20 -> Left True
            _ | roll >= dc -> step (a, b + 1)
              | otherwise -> step (a + 1, b)

    -- Turn a tally into a decided outcome once either count reaches 3.
    step :: (Integer, Integer) -> Either Bool (Integer, Integer)
    step (a, b)
      | a >= 3 = Left False
      | b >= 3 = Left True
      | otherwise = Right (a, b)
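Compilable file (together with the Prob type, its instances, sumprobs and makedie from the question, and the imports Data.List (nub) and Data.Ratio ((%)) that they need):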
-- | Distribution of the final result (True = survived, False = died),
-- computed by evolving the distribution of survival states one roll at a
-- time and merging equal states after every roll.
survivedeath :: Integer -> Prob Integer -> Prob Bool
survivedeath dc die = fromState <$> surviveafter 6
  where
    -- Every roll either decides the outcome or raises a + b by at least 1.
    -- An undecided state has a < 3 and b < 3, so a + b <= 4. Hence no state
    -- is still undecided after 6 rolls, and the Right case is unreachable.
    fromState :: Either Bool (Integer, Integer) -> Bool
    fromState =
      either id (const (error "survivedeath: state still undecided after 6 rolls"))

    -- Survival state after (up to) n rolls:
    -- - Left True (survived)
    -- - Left False (died)
    -- - Right (a, b) ('a' failed throws, 'b' successful throws)
    surviveafter :: Integer -> Prob (Either Bool (Integer, Integer))
    surviveafter 0 = pure (Right (0, 0))
    surviveafter n = sumprobs $ do
      state <- surviveafter (n - 1)
      case state of
        -- A decided state stays decided; no further roll is made.
        Left _ -> pure state
        Right (a, b) -> do
          roll <- die
          pure $ case roll of
            1 -> step (a + 2, b)
            20 -> Left True
            _ | roll >= dc -> step (a, b + 1)
              | otherwise -> step (a + 1, b)

    -- Turn a tally into a decided outcome once either count reaches 3.
    step :: (Integer, Integer) -> Either Bool (Integer, Integer)
    step (a, b)
      | a >= 3 = Left False
      | b >= 3 = Left True
      | otherwise = Right (a, b)
-- | A fair twenty-sided die.
d20 :: Prob Integer
d20 = makedie 20

main :: IO ()
main = print (survivedeath 6 d20)