Tags: monads, tail-recursion, scala-cats, monix, tagless-final

Scala Cats: tail-recursive tailRecM method for a Monad instance of Task[Validated[String, ?]]


In cats, when a Monad is created using the Monad trait, ideally a tail-recursive implementation for method tailRecM should be provided to ensure stack safety.

I am using the tagless final approach and wish to have an effect of Task[Validated[String, ?]] (Monix Task) for my program.

I can't work out how to write a tail-recursive implementation. My non-tail-recursive solution is:

import cats.Monad
import cats.data.Validated
import cats.data.Validated.{Invalid, Valid}
import monix.eval.Task

/** Wrapper around a Monix `Task` whose result is a `Validated` with `String` errors.
  *
  * Wrapping the nested type in a single-parameter case class gives it the `F[_]` shape
  * needed to declare a `Monad[TaskValidated]` instance.
  *
  * @param value the underlying task producing either an `Invalid` message or a `Valid` result
  * @tparam A the type of the valid result
  */
final case class TaskValidated[A](value: Task[Validated[String, A]])

/** Fail-fast `Monad` instance for [[TaskValidated]].
  *
  * Sequencing stops at the first `Invalid`: the error message is propagated unchanged
  * and the continuation is never invoked.
  */
implicit val taskValidatedMonad: Monad[TaskValidated] =
    new Monad[TaskValidated] {

        /** Lifts an already-computed value into a valid result.
          *
          * `Task.now` is used rather than `Task(...)` because the value is already
          * evaluated, so there is nothing to suspend.
          */
        override def pure[A](a: A): TaskValidated[A] =
            TaskValidated(Task.now(Valid(a)))

        /** Runs `fa`, then feeds a `Valid` result to `f`; an `Invalid` result short-circuits. */
        override def flatMap[A, B](fa: TaskValidated[A])(f: A => TaskValidated[B]): TaskValidated[B] =
            TaskValidated(
                fa.value.flatMap {
                    case Valid(a)   => f(a).value
                    case Invalid(s) => Task.now(Invalid(s))
                }
            )

        /** Repeatedly applies `fn`, starting from `init`, until it yields `Right` or `Invalid`.
          *
          * `@annotation.tailrec` cannot be applied here: the recursive call sits inside the
          * function passed to `Task#flatMap`, not in tail position of this method. The call is
          * therefore deferred until the task runs, and stack safety relies on Monix's
          * `Task#flatMap` being stack-safe (see the Monix `Task` documentation).
          *
          * @param init the initial loop state
          * @param fn   one loop step: `Left(a)` continues with `a`, `Right(b)` finishes with `b`
          * @return the final `Valid(b)`, or the first `Invalid` produced by a step
          */
        override def tailRecM[A, B](init: A)(fn: A => TaskValidated[Either[A, B]]): TaskValidated[B] =
            TaskValidated(fn(init).value.flatMap {
                case Valid(Left(a))  => tailRecM(a)(fn).value
                case Valid(Right(b)) => Task.now(Valid(b))
                case Invalid(s)      => Task.now(Invalid(s))
            })
    }

Solution

  • Task has its own tailRecM, so it makes sense to use it. Try

    /** Stack-safe monadic loop for `TaskValidated`, delegating to `Task.tailRecM`.
      *
      * The loop's result type is `Validated[String, B]` rather than `B`, so an `Invalid`
      * produced by any step terminates the loop *as a value*. It is not converted into a
      * failed `Task`, which keeps this method consistent with a `flatMap` that
      * short-circuits on `Invalid`.
      *
      * @param init the initial loop state
      * @param fn   one loop step: `Left(a)` continues with `a`, `Right(b)` finishes with `b`
      * @return the final `Valid(b)`, or the first `Invalid` produced by a step
      */
    def tailRecM[A, B](init: A)(fn: A => TaskValidated[Either[A, B]]): TaskValidated[B] = {
      // Re-associate each step's result into the Either shape Task.tailRecM expects:
      // only Valid(Left(_)) continues; both Valid(Right(_)) and Invalid(_) end the loop.
      def step(a: A): Task[Either[A, Validated[String, B]]] =
        fn(a).value.map[Either[A, Validated[String, B]]] {
          case Valid(Left(next)) => Left(next)
          case Valid(Right(b))   => Right(Valid(b))
          case Invalid(s)        => Right(Invalid(s))
        }

      TaskValidated(Task.tailRecM(init)(step))
    }