
Adding random number generation to the STM monad in Haskell


I am currently working on some transactional memory benchmarks in Haskell and would like to be able to make use of random numbers within a transaction. I am currently using the Random monad/monad transformer from here. In the following example, I have an array of TVars containing integers and a transaction that randomly selects 10 TVars in the array to increment:

tvars :: STM (TArray Int Int)
tvars = newArray (0, numTVars) 0

write :: Int -> RandT StdGen STM [Int]
write 0 = return []
write i = do
    tvars <- lift tvars
    rn <- getRandomR (0, numTVars)
    temp <- lift $ readArray tvars rn
    lift $ writeArray tvars rn (temp + 1)
    rands <- write (i-1)
    lift $ return $ rn : rands

I guess my question is "Is this the best way to go about doing this?" It seems like it would be more natural/efficient to go the other way around, i.e. lift the random monad into the STM monad. Each transaction does a lot of STM operations and very few random operations, and I would assume that each lift adds some amount of overhead. Wouldn't it be more efficient to lift only the random computations and leave the STM computations alone? Is this even safe to do? It seems that defining an STM monad transformer would break the nice static separation properties that we get with the STM monad (e.g. we could lift IO into the STM monad, but then we would have to worry about undoing IO actions if a transaction aborts, which presents a number of issues). My knowledge of monad transformers is pretty limited. A brief explanation regarding the performance and relative overhead of using a transformer would be much appreciated.


Solution

  • STM is a base monad; there is no STMT transformer. Think about what atomically, which currently has type STM a -> IO a, would have to look like if we had STMT.

    I have a few solutions to your particular problem in mind. The simplest one is probably to rearrange the code:

    write :: Int -> RandT StdGen STM [Int]
    write n = do
        -- Draw all n random indices up front, so random and STM code
        -- do not need to be interleaved at all. getRandomRs yields an
        -- infinite list, hence the take.
        rn <- take n <$> getRandomRs (0, numTVars)
        lift $ go rn
      where
        go []     = return []
        go (i:is) = do
            arr <- tvars  -- this is redundant, could be taken out of the loop
            temp <- readArray arr i
            writeArray arr i (temp + 1)
            rands <- go is
            return $ i : rands
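
As a rough, self-contained sketch of the same rearrangement, the version below draws the indices first and then runs all array updates under a single lift. It makes several assumptions that are not in the original code: StateT StdGen from transformers stands in for RandT StdGen, the array is created once and passed in as an argument instead of being rebuilt by tvars, and numTVars is given an arbitrary value.

```haskell
import Control.Concurrent.STM (STM, TArray, atomically)
import Control.Monad (replicateM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, runStateT, state)
import Data.Array.MArray (getElems, newArray, readArray, writeArray)
import System.Random (StdGen, mkStdGen, randomR)

-- Assumed value; the original post does not say what numTVars is.
numTVars :: Int
numTVars = 15

-- StateT StdGen is used here as a stand-in for RandT StdGen.
-- All n indices are drawn first; the STM work then needs one lift only.
write :: TArray Int Int -> Int -> StateT StdGen STM [Int]
write arr n = do
    idxs <- replicateM n (state (randomR (0, numTVars)))
    lift (mapM_ bump idxs)
    return idxs
  where
    -- Increment a single slot of the shared array.
    bump i = do
      v <- readArray arr i
      writeArray arr i (v + 1)

main :: IO ()
main = do
  arr <- atomically (newArray (0, numTVars) 0)
  (idxs, _gen') <- atomically (runStateT (write arr 10) (mkStdGen 42))
  total <- sum <$> atomically (getElems arr)
  print (length idxs, total)  -- prints (10,10): ten indices, ten increments
```

Because the generator is threaded through as ordinary state, a transaction that is retried starts again from the same StdGen and so picks the same indices; whether that is acceptable depends on the benchmark.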
    

    RandT, though, is essentially StateT, whose lift is defined as:

    instance MonadTrans (StateT s) where
        lift m = StateT $ \ s -> do
            a <- m
            return (a, s)
    

    So code of the form:

    do x <- lift baseAction1
       y <- lift baseAction2
       return $ f x y
    

    becomes

    do x <- StateT $ \s -> do { a <- baseAction1; return (a, s) }
       y <- StateT $ \s -> do { a <- baseAction2; return (a, s) }
       return $ f x y
    

    which, after desugaring the do notation, is

    StateT (\s -> do { a <- baseAction1; return (a, s) }) >>= \ x ->
    StateT (\s -> do { a <- baseAction2; return (a, s) }) >>= \ y ->
    return $ f x y
    

    Inlining the first >>= gives:

    StateT $ \s -> do
      ~(a, s') <- runStateT (StateT (\s -> do { a <- baseAction1; return (a, s) })) s
      runStateT ((\ x -> StateT (\s -> do { a <- baseAction2; return (a, s) }) >>= \ y -> return $ f x y) a) s'
    

    StateT and runStateT cancel out:

    StateT $ \s -> do
      ~(x, s') <- do { a <- baseAction1; return (a, s) }
      runStateT ((\ x -> StateT (\s -> do { a <- baseAction2; return (a, s) }) >>= \ y -> return $ f x y) x) s'
    

    And after a few more inlining / reduction steps:

    StateT $ \s -> do
      ~(x, s') <- do { a <- baseAction1; return (a, s) }
      ~(y, s'') <- do { a <- baseAction2; return (a, s') }
      return (f x y, s'')
    

    GHC is probably smart enough to reduce this even further, so that the state is just passed through without creating intermediate pairs (though I'm not sure; one should use the monad laws to justify that):

    StateT $ \s -> do
       x <- baseAction1
       y <- baseAction2
       return (f x y, s)
    

    which is what you get from

    lift $ do x <- baseAction1
              y <- baseAction2
              return $ f x y
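
The equivalence derived above can be checked on a small example. The sketch below is only an illustration under assumptions: a Writer [String] base monad stands in for STM so that the order of base actions is observable, and baseAction1, baseAction2 and f are given arbitrary placeholder definitions. It compares results only and says nothing about what GHC actually optimizes or about run-time overhead.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, runStateT)
import Control.Monad.Trans.Writer.Strict (Writer, runWriter, tell)

-- Stand-in base monad (instead of STM) that logs which actions ran.
type Base = Writer [String]

-- Placeholder base actions and combining function.
baseAction1 :: Base Int
baseAction1 = tell ["action1"] >> return 2

baseAction2 :: Base Int
baseAction2 = tell ["action2"] >> return 3

f :: Int -> Int -> Int
f = (+)

-- Every base action is lifted separately.
liftedEach :: StateT Char Base Int
liftedEach = do
  x <- lift baseAction1
  y <- lift baseAction2
  return (f x y)

-- The whole block is lifted once.
liftedOnce :: StateT Char Base Int
liftedOnce = lift $ do
  x <- baseAction1
  y <- baseAction2
  return (f x y)

main :: IO ()
main = do
  let run m = runWriter (runStateT m 's')
  print (run liftedEach == run liftedOnce)  -- True
  print (run liftedOnce)
```

Both versions return the same value, leave the state untouched, and run the base actions in the same order, which is what the monad transformer laws for lift require.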