scala scalaz shapeless state-monad scalaz7

State transformations with a shapeless State monad


The Scalaz State monad's modify has the following signature:

def modify[S](f: S => S): State[S, Unit]

This only allows the state to be replaced by a state of the same type, which does not work when the state includes a shapeless value such as a Record whose type changes as new fields are added. In that case what we need is something like:

def modify[S, T](f: S => T): State[T, Unit]

What is a good way to adapt Scalaz's State monad to use shapeless state so that one can use Records as opposed to, say, the dreaded Map[String, Any]?

Example:

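import scalaz._, Scalaz._
import shapeless._, shapeless.syntax.singleton._

case class S[L <: HList](total: Int, scratch: L)

// Does not typecheck: prepending a field changes the scratch type from L to a
// longer record type, but modify only accepts a function S[L] => S[L].
def contrivedAdd[L <: HList](n: Int): State[S[L], Int] =
  for {
    a <- init
    _ <- modify(s => S(s.total + n, ('latestAddend ->> n) :: s.scratch))
    r <- get
  } yield r.total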

Update:

The full code for Travis's answer is here.


Solution

  • State is a type alias for a more general type, IndexedStateT, which is designed specifically so that functions that change the type of the state can be represented as state computations:

    type StateT[F[_], S, A] = IndexedStateT[F, S, S, A]
    type State[S, A] = StateT[Id, S, A]
    

    While it's not possible to write your modify[S, T] using State, it is possible with IndexedState (which is another type alias for IndexedStateT that fixes the effect type to Id):

    import scalaz._, Scalaz._
    
    def transform[S, T](f: S => T): IndexedState[S, T, Unit] =
      IndexedState(s => (f(s), ()))
    

    You can even use this in for-comprehensions (which has always seemed a little odd to me, since the monadic type changes between operations, but it works):

    val s = for {
      a <- init[Int]
      _ <- transform[Int, Double](_.toDouble)
      _ <- transform[Double, String](_.toString)
      r <- get
    } yield r * a
    

    And then:

    scala> s(5)
    res5: scalaz.Id.Id[(String, String)] = (5.0,5.05.05.05.05.0)
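
    To see why the state type is allowed to change from one generator to the next, here is a minimal, library-free model of IndexedState written purely for illustration. The names IndexedStateModel and example are hypothetical, and this is a sketch of the idea, not Scalaz's actual implementation. The key point is that flatMap takes the output state type of one step as the input state type of the next, and ordinary State is just the case where the two types coincide. It also makes the odd-looking output above easy to trace: the final state is "5.0", and r * a repeats that string 5 times.

    ```scala
    // Hypothetical, library-free model of IndexedStateT with the effect fixed
    // to Id (i.e. IndexedState). Not Scalaz's actual implementation.
    object IndexedStateModel {
      // Takes an input state S1 and produces an output state S2 (possibly of a
      // different type) together with a value A.
      final case class IndexedState[S1, S2, A](run: S1 => (S2, A)) {
        def apply(initial: S1): (S2, A) = run(initial)

        def map[B](f: A => B): IndexedState[S1, S2, B] =
          IndexedState { s1 => val (s2, a) = run(s1); (s2, f(a)) }

        // The next computation starts from this one's output state type S2 and
        // may end in another state type S3, so the state type can change per step.
        def flatMap[S3, B](f: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
          IndexedState { s1 => val (s2, a) = run(s1); f(a).run(s2) }
      }

      // Ordinary State is the special case where input and output types agree.
      type State[S, A] = IndexedState[S, S, A]

      // Both return the current state as the value and leave the state unchanged.
      def init[S]: State[S, S] = IndexedState(st => (st, st))
      def get[S]: State[S, S] = IndexedState(st => (st, st))

      // Same shape as the transform defined above.
      def transform[S, T](f: S => T): IndexedState[S, T, Unit] =
        IndexedState(st => (f(st), ()))

      // Desugars into flatMap/map calls whose state type moves from Int to
      // Double to String; r * a uses StringOps.* to repeat the string a times.
      val example: IndexedState[Int, String, String] = for {
        a <- init[Int]
        _ <- transform[Int, Double](_.toDouble)
        _ <- transform[Double, String](_.toString)
        r <- get[String]
      } yield r * a
    }
    ```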
    

    In your case you might write something like this:

    import shapeless._, shapeless.labelled.{ FieldType, field }
    
    case class S[L <: HList](total: Int, scratch: L)
    
    def addField[K <: Symbol, A, L <: HList](k: Witness.Aux[K], a: A)(
      f: Int => Int
    ): IndexedState[S[L], S[FieldType[K, A] :: L], Unit] =
      IndexedState(s => (S(f(s.total), field[K](a) :: s.scratch), ()))
    

    And then:

    def contrivedAdd[L <: HList](n: Int) = for {
      a <- init[S[L]]
      _ <- addField('latestAdded, n)(_ + n)
      r <- get
    } yield r.total
    

    (This may not be the best way of factoring out the pieces of the update operation, but it shows how the basic idea works.)
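
    Since the shapeless pieces can obscure the basic mechanics, here is a rough, library-free analog of addField in which nested pairs stand in for the record. This is an assumption made for illustration only; AddFieldModel, One, Two and twice are hypothetical names. Unlike a shapeless record, the label here is only a runtime String and is not tracked in the type, but each added entry still grows the scratch type by one layer, which is exactly what an ordinary State cannot express.

    ```scala
    // Hypothetical analog of addField: nested pairs stand in for a shapeless
    // record, so each added entry grows the scratch type by one layer.
    object AddFieldModel {
      final case class IndexedState[S1, S2, A](run: S1 => (S2, A)) {
        def map[B](f: A => B): IndexedState[S1, S2, B] =
          IndexedState { s1 => val (s2, a) = run(s1); (s2, f(a)) }
        def flatMap[S3, B](f: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
          IndexedState { s1 => val (s2, a) = run(s1); f(a).run(s2) }
      }

      def get[T]: IndexedState[T, T, T] = IndexedState(st => (st, st))

      final case class S[L](total: Int, scratch: L)

      // Prepends a (label, value) pair to the scratch and updates the total; the
      // output state type gains one layer, much as FieldType[K, A] :: L does.
      def addField[A, L](label: String, a: A)(f: Int => Int)
          : IndexedState[S[L], S[((String, A), L)], Unit] =
        IndexedState(s => (S(f(s.total), ((label, a), s.scratch)), ()))

      def contrivedAdd[L](n: Int): IndexedState[S[L], S[((String, Int), L)], Int] =
        for {
          _ <- addField[Int, L]("latestAdded", n)(_ + n)
          r <- get[S[((String, Int), L)]]
        } yield r.total

      // Two additions: the scratch type grows from Unit to One to Two.
      type One = ((String, Int), Unit)
      type Two = ((String, Int), One)

      val twice: IndexedState[S[Unit], S[Two], Int] =
        for {
          _ <- addField[Int, Unit]("first", 1)(_ + 1)
          _ <- addField[Int, One]("second", 2)(_ + 2)
          r <- get[S[Two]]
        } yield r.total
    }
    ```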

    It's also worth noting that if you don't care about representing the state transformation as a state computation, you can just use imap on any old State:

    init[S[HNil]].imap(s =>
      S(1, field[Witness.`'latestAdded`.T](1) :: s.scratch)
    )
    

    This doesn't allow you to use these operations compositionally in the same way, but it may be all you need in some situations.
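
    To make the difference concrete, here is a small library-free model (hypothetical names; it mirrors the behavior described here, not Scalaz's source) in which imap applies a function to the final state of an existing computation. The same state change written as a separate transform step gives the same result, but as its own step it can be sequenced with others in a for-comprehension.

    ```scala
    // Hypothetical, library-free model of imap; not Scalaz's implementation.
    object ImapModel {
      final case class IndexedState[S1, S2, A](run: S1 => (S2, A)) {
        def map[B](f: A => B): IndexedState[S1, S2, B] =
          IndexedState { s1 => val (s2, a) = run(s1); (s2, f(a)) }
        def flatMap[S3, B](f: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
          IndexedState { s1 => val (s2, a) = run(s1); f(a).run(s2) }
        // Applies f to the final state after this computation runs; keeps the value.
        def imap[S3](f: S2 => S3): IndexedState[S1, S3, A] =
          IndexedState { s1 => val (s2, a) = run(s1); (f(s2), a) }
      }

      def init[S]: IndexedState[S, S, S] = IndexedState(st => (st, st))
      def transform[S, T](f: S => T): IndexedState[S, T, Unit] =
        IndexedState(st => (f(st), ()))

      // imap: the state change is attached to an existing computation.
      val viaImap: IndexedState[Int, String, Int] = init[Int].imap(_.toString)

      // The same change as a separate step that can be sequenced with others.
      val viaTransform: IndexedState[Int, String, Int] = for {
        a <- init[Int]
        _ <- transform[Int, String](_.toString)
      } yield a
    }
    ```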