
In the context of STM, what does a transaction log conceptually look like, and how does it evolve when the transaction succeeds after a few retries?


For instance, consider this function, which could be used in a window manager to move a window from one desktop to another on a given display:

moveWindowSTM :: Display -> Window -> Desktop -> Desktop -> STM ()
moveWindowSTM disp win a b = do
  wa <- readTVar ma
  wb <- readTVar mb
  writeTVar ma (Set.delete win wa)
  writeTVar mb (Set.insert win wb)
 where
  ma = disp ! a
  mb = disp ! b

and obviously its IO wrapper,

moveWindow :: Display -> Window -> Desktop -> Desktop -> IO ()
moveWindow disp win a b = atomically $ moveWindowSTM disp win a b

and then assume that

How would the transaction log evolve in this case, and at which point would the validation fail?
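
For anyone who wants to run the snippet, here is a minimal self-contained sketch of it. The concrete types are assumptions: `Display` as a `Map` of `TVar`s is only the shape implied by `disp ! a`, `Window` and `Desktop` as `Int` are placeholders, and `newDisplay`, `moveAll` and `snapshot` are hypothetical helpers that are not part of the original code.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Concurrent.STM
  (STM, TVar, atomically, newTVarIO, readTVar, readTVarIO, writeTVar)
import Control.Monad (forM_, replicateM_)
import Data.Map (Map, (!))
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set

-- Assumed types; the question does not define them.
type Window  = Int
type Desktop = Int
type Display = Map Desktop (TVar (Set Window))

moveWindowSTM :: Display -> Window -> Desktop -> Desktop -> STM ()
moveWindowSTM disp win a b = do
  wa <- readTVar ma
  wb <- readTVar mb
  writeTVar ma (Set.delete win wa)
  writeTVar mb (Set.insert win wb)
 where
  ma = disp ! a
  mb = disp ! b

moveWindow :: Display -> Window -> Desktop -> Desktop -> IO ()
moveWindow disp win a b = atomically $ moveWindowSTM disp win a b

-- Hypothetical helper: desktop 0 holds all windows, desktop 1 is empty.
newDisplay :: [Window] -> IO Display
newDisplay wins = do
  d0 <- newTVarIO (Set.fromList wins)
  d1 <- newTVarIO Set.empty
  pure (Map.fromList [(0, d0), (1, d1)])

-- Hypothetical helper: one thread per window, each moving its window
-- from desktop 0 to desktop 1; returns once every thread has finished.
moveAll :: Display -> [Window] -> IO ()
moveAll disp wins = do
  done <- newEmptyMVar
  forM_ wins $ \w -> forkIO $ do
    moveWindow disp w 0 1
    putMVar done ()
  replicateM_ (length wins) (takeMVar done)

-- Hypothetical helper: read every desktop's current contents.
snapshot :: Display -> IO (Map Desktop (Set Window))
snapshot = traverse readTVarIO
```

All threads in `moveAll` touch the same two `TVar`s, so some transactions may fail validation and rerun. The log is not observable from user code, but after `moveAll` returns, `snapshot` should show desktop 0 empty and desktop 1 holding every window, whatever the interleaving was.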

The relevant excerpt from Parallel and Concurrent Programming in Haskell by Simon Marlow is below (but similar information is available in the paper Beautiful concurrency by Simon Peyton Jones):

An STM transaction works by accumulating a log of readTVar and writeTVar operations that have happened so far during the transaction. The log is used in three ways:

  • By storing writeTVar operations in the log rather than applying them to main memory immediately, discarding the effects of a transaction is easy; we just throw away the log. Hence, aborting a transaction has a fixed small cost.

  • Each readTVar must traverse the log to check whether the TVar was written by an earlier writeTVar. Hence, readTVar is an O(n) operation in the length of the log.

  • Because the log contains a record of all the readTVar operations, it can be used to discover the full set of TVars read during the transaction, which we need to know in order to implement retry.

When a transaction reaches the end, the STM implementation compares the log against the contents of memory. If the current contents of memory match the values read by readTVar, the effects of the transaction are committed to memory, and if not, the log is discarded and the transaction runs again from the beginning. This process takes place atomically by locking all the TVars involved in the transaction for the duration. The STM implementation in GHC does not use global locks; only the TVars involved in the transaction are locked during commit, so transactions operating on disjoint sets of TVars can proceed without interference.
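
To make the excerpt concrete, here is a small pure sketch of the log applied to the `moveWindowSTM` body. It is a toy model under stated assumptions, not GHC's implementation: variables are identified by a `String` name, memory is an ordinary `Map`, validation compares values with `(==)`, and the log is a `Map` rather than a list that is traversed. All names (`Entry`, `readVar`, `writeVar`, `moveWindowLog`, `validate`, `commit`) are hypothetical.

```haskell
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

type Var    = String             -- hypothetical stand-in for a TVar's identity
type Value  = Set.Set Int        -- window ids on one desktop
type Memory = Map.Map Var Value  -- committed contents of "main memory"

-- One log entry: the value found in memory at first access, and the
-- value the transaction would write back on commit.
data Entry = Entry { seen :: Value, pending :: Value } deriving (Eq, Show)
type Log = Map.Map Var Entry

current :: Memory -> Var -> Value
current mem v = Map.findWithDefault Set.empty v mem

-- Model of readTVar: consult the log first; otherwise read memory and
-- record what was seen.
readVar :: Memory -> Var -> Log -> (Value, Log)
readVar mem v lg = case Map.lookup v lg of
  Just e  -> (pending e, lg)
  Nothing -> let x = current mem v in (x, Map.insert v (Entry x x) lg)

-- Model of writeTVar: the new value is buffered in the log only.
writeVar :: Memory -> Var -> Value -> Log -> Log
writeVar mem v x lg =
  let (_, lg') = readVar mem v lg
  in  Map.adjust (\e -> e { pending = x }) v lg'

-- The body of moveWindowSTM, run against a snapshot of memory.
moveWindowLog :: Memory -> Int -> Var -> Var -> Log
moveWindowLog mem win a b =
  let (wa, l1) = readVar mem a Map.empty
      (wb, l2) = readVar mem b l1
      l3       = writeVar mem a (Set.delete win wa) l2
  in  writeVar mem b (Set.insert win wb) l3

-- Commit-time check: every logged variable still holds the value seen.
validate :: Memory -> Log -> Bool
validate mem = all ok . Map.toList
  where ok (v, e) = current mem v == seen e

-- Apply the buffered writes; variables not in the log are untouched.
commit :: Memory -> Log -> Memory
commit mem lg = Map.union (Map.map pending lg) mem
```

With memory `a = {1,2}`, `b = {3}`, running `moveWindowLog mem 1 "a" "b"` yields a log with `a` (seen `{1,2}`, pending `{2}`) and `b` (seen `{3}`, pending `{1,3}`). If another transaction commits `b = {3,4}` before this one commits, `validate` fails on the entry for `b`, the log is thrown away, and the body is rerun against the new memory. The second log then validates and `commit` produces `a = {2}`, `b = {1,3,4}`.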


Solution

  • I think something like this should serve as a pretty decent mental model:

    data Generational a = Generational
        { generation :: Int
        , value :: a
        }
    instance Eq (Generational a) where (==) = (==) `on` generation
    
    -- choose a generation number that's never been chosen before
    -- e.g. by starting at minBound and incrementing one at each call
    freshGeneration :: IO Int
    
    newtype TVar a = TVar (IORef (Generational a))
    
    unsafeReadTVar :: MonadIO m => TVar a -> m (Generational a)
    unsafeReadTVar (TVar ref) = liftIO $ readIORef ref
    
    -- takes the logged entry and stamps it with the new generation
    unsafeWriteTVar :: Int -> TVar a -> Generational a -> IO ()
    unsafeWriteTVar g (TVar ref) ga = writeIORef ref ga { generation = g }
    
    lock, unlock :: TVar a -> IO ()
    
    -- like Map, but with an extra type argument
    data Map1 (key :: k -> *) (value :: k -> *)
    
    type STMLog = Map1 TVar Generational
    
    newtype STM a = STM (StateT STMLog IO a)
        deriving (Functor, Applicative, Monad)
    
    updateLog :: TVar a -> STM (Generational a)
    updateLog v = STM $ do
        log <- get
        case lookup1 v log of
            Just ga -> pure ga
            Nothing -> do
                ga <- unsafeReadTVar v
                ga <$ put (insert1 v ga log)
    
    readTVar :: TVar a -> STM a
    readTVar v = value <$> updateLog v
    
    writeTVar :: TVar a -> a -> STM ()
    writeTVar v a = do
        ga <- updateLog v
        -- keep the generation seen at first access; only the pending value changes
        STM . modify $ insert1 v ga { value = a }
    
    atomically :: STM a -> IO a
    atomically (STM act) = do
        (a, log) <- runStateT act empty1
        let allVars = keys1 log
        traverse1_ lock allVars
        log' <- traverseWithKey1 (\v _ -> unsafeReadTVar v) log
        -- N.B. (==) only compares generations here
        -- (the real implementation probably uses some form
        -- of pointer equality instead of storing generations)
        if log == log'
            then do
                g <- freshGeneration
                traverseWithKey1_ (unsafeWriteTVar g) log
                a <$ traverse1_ unlock allVars
            else do
                traverse1_ unlock allVars
                atomically (STM act) -- try again
    

    The real thing is probably somewhat more complicated/optimized/robust. I'm not bothering to deal with exceptions at all here, I'm assuming locking a collection of variables all at once is easy, I'm not showing the implementation of retry, I haven't thought about whether any memory fencing is needed, etc.

    Notice that even a write goes through updateLog, so it also records which generation atomically expects to find at commit time (as pointed out by K. A. Buhr). That surprised me!