I am new to cats-effect and I am trying to implement the classic expression evaluator using cats-effect. I would like `eval` to return an `IO[Double]` instead of a `Double`. My naive code is below, but of course it doesn't type check. What is the right way to approach this? (It seems that pattern matching is generally awkward to combine with `IO`.)
import cats.effect._
import cats.effect.unsafe.implicits.global
/** Abstract syntax tree of arithmetic expressions whose leaves are `Double` constants. */
sealed trait Expression

/** Sum of two sub-expressions. */
final case class Add(x: Expression, y: Expression) extends Expression

/** Product of two sub-expressions. */
final case class Mult(x: Expression, y: Expression) extends Expression

/** Exponential of a sub-expression (the evaluators apply `scala.math.exp` to it). */
final case class Exp(x: Expression) extends Expression

/** A numeric literal leaf. */
final case class Const(x: Double) extends Expression
/** Operator syntax for building expression trees, e.g. `Const(1) + Const(2)`. */
extension (exp: Expression)

  /** Builds the node `Add(exp, other)`; nothing is evaluated. */
  def +(other: Expression): Add =
    Add(exp, other)

  /** Builds the node `Mult(exp, other)`; nothing is evaluated. */
  def *(other: Expression): Mult =
    Mult(exp, other)
/** The question's first attempt. It does NOT compile, and that is the point of the question.
  *
  * The whole match runs inside a single `IO { ... }`, so the body must produce a plain
  * `Double`. Every recursive `eval` call, however, returns an `IO[Double]`:
  * there is no `+`/`*` that combines two `IO[Double]` values, and `scala.math.exp`
  * expects a `Double`, not an `IO[Double]`.
  */
def eval(exp: Expression): IO[Double] = IO{
  exp match
    case Add(x, y) => eval(x) + eval(y) // This does not type check
    case Mult(x, y) => eval(x) * eval(y) // same problem: both operands are IO[Double]
    case Exp(x) => scala.math.exp(eval(x)) // exp needs a Double but receives IO[Double]
    case Const(x) => x
}
/** Sample expression exp((1 + 2) * 9), built from explicit constructors. */
val expression1: Exp =
  Exp(Mult(Add(Const(1), Const(2)), Const(9)))
/** Entry point: runs the `IO` built by `eval` synchronously and prints the result.
  *
  * Uses the global runtime imported from `cats.effect.unsafe.implicits`.
  */
@main def main: Unit =
  val result = eval(expression1).unsafeRunSync()
  println(result)
`IO` is a monad, so you can sequence the recursive evaluations with for-comprehensions:
/** Evaluates `exp` by sequencing the recursive sub-evaluations with for-comprehensions.
  *
  * @param exp expression tree to evaluate
  * @return an `IO` that, when run, yields the numeric value of `exp`
  */
def eval(exp: Expression): IO[Double] =
  exp match
    case Add(x, y) =>
      for {
        x1 <- eval(x)
        y1 <- eval(y)
      } yield x1 + y1
    case Mult(x, y) =>
      for {
        x1 <- eval(x)
        y1 <- eval(y)
      } yield x1 * y1
    case Exp(x) =>
      for {
        x1 <- eval(x)
      } yield scala.math.exp(x1)
    // The constant is already a value: lift it with `pure`; there is no side effect to suspend.
    case Const(x) => IO.pure(x)
Or use applicative syntax (`mapN`):
import cats.syntax.apply.given
/** Evaluates `exp`, combining the two independent sub-evaluations of each binary node
  * with `mapN` (tuple syntax from `cats.syntax.apply`).
  *
  * @param exp expression tree to evaluate
  * @return an `IO` that, when run, yields the numeric value of `exp`
  */
def eval(exp: Expression): IO[Double] =
  exp match
    case Add(x, y)  => (eval(x), eval(y)).mapN(_ + _)
    case Mult(x, y) => (eval(x), eval(y)).mapN(_ * _)
    case Exp(x)     => eval(x).map(scala.math.exp)
    // Already-computed value: `pure` rather than `IO(...)`, which is meant for suspending side effects.
    case Const(x)   => IO.pure(x)
Or define an instance of the `Numeric` type class for `IO[A]`, so that the ordinary `+` and `*` operators work directly on `IO` values:
import cats.syntax.apply.given
import Numeric.Implicits.given
/** `Numeric` instance for `IO[A]`, lifting the arithmetic of `Numeric[A]` into `IO`.
  *
  * Binary operations combine the results of both operands with `mapN`; unary ones `map`
  * over the single operand. `fromInt` and `parseString` wrap already-computed values with
  * `IO.pure`.
  *
  * `toInt`, `toLong`, `toFloat`, `toDouble` and `compare` would have to run the effects
  * to obtain plain values, which a pure type-class method cannot do. They are therefore
  * left unimplemented and throw `NotImplementedError` (`???`) if called.
  */
given [A: Numeric]: Numeric[IO[A]] = new Numeric[IO[A]]:
  override def plus(x: IO[A], y: IO[A]): IO[A] = (x, y).mapN(_ + _)
  override def minus(x: IO[A], y: IO[A]): IO[A] = (x, y).mapN(_ - _)
  override def times(x: IO[A], y: IO[A]): IO[A] = (x, y).mapN(_ * _)
  override def negate(x: IO[A]): IO[A] = x.map(summon[Numeric[A]].negate)
  override def fromInt(x: Int): IO[A] = IO.pure(summon[Numeric[A]].fromInt(x))
  override def parseString(str: String): Option[IO[A]] =
    summon[Numeric[A]].parseString(str).map(a => IO.pure(a))
  // The members below need the value inside the IO, i.e. running it: intentionally unsupported.
  override def toInt(x: IO[A]): Int = ???
  override def toLong(x: IO[A]): Long = ???
  override def toFloat(x: IO[A]): Float = ???
  override def toDouble(x: IO[A]): Double = ???
  override def compare(x: IO[A], y: IO[A]): Int = ???
/** Evaluates `exp` using plain `+` and `*` on `IO[Double]`.
  *
  * These operators come from `Numeric.Implicits` together with the `Numeric[IO[A]]` given.
  *
  * @param exp expression tree to evaluate
  * @return an `IO` that, when run, yields the numeric value of `exp`
  */
def eval(exp: Expression): IO[Double] =
  exp match
    case Add(x, y)  => eval(x) + eval(y)
    case Mult(x, y) => eval(x) * eval(y)
    case Exp(x)     => eval(x).map(scala.math.exp)
    // Already-computed value: lift it with `pure`.
    case Const(x)   => IO.pure(x)
Or define your own extension methods on `IO`:
import cats.syntax.apply.given
import Numeric.Implicits.given
/** Arithmetic operators on effectful numbers. Each runs `x`, then `y`, and combines their results. */
extension [A: Numeric](x: IO[A])

  /** Sum of the values produced by `x` and `y`. */
  def +(y: IO[A]): IO[A] =
    x.flatMap(a => y.map(b => a + b))

  /** Product of the values produced by `x` and `y`. */
  def *(y: IO[A]): IO[A] =
    x.flatMap(a => y.map(b => a * b))
/** Lifts `scala.math.exp` into `IO` by applying it to the value produced by `x`. */
def exp(x: IO[Double]): IO[Double] =
  x.map(value => scala.math.exp(value))
/** Evaluates `expr` using the custom `IO` operators `+`, `*` and the `exp` helper.
  *
  * The parameter is named `expr` so it does not shadow the `exp` function.
  *
  * @param expr expression tree to evaluate
  * @return an `IO` that, when run, yields the numeric value of `expr`
  */
def eval(expr: Expression): IO[Double] =
  expr match
    case Add(x, y)  => eval(x) + eval(y)
    case Mult(x, y) => eval(x) * eval(y)
    case Exp(x)     => exp(eval(x))
    // Already-computed value: lift it with `pure`.
    case Const(x)   => IO.pure(x)
Or simply evaluate the expression to a plain `Double` and wrap the result in `IO` only once:
/** Evaluates `expr` with an ordinary recursive interpreter over `Double`.
  *
  * The whole interpretation is suspended inside a single `IO.delay`.
  *
  * @param expr expression tree to evaluate
  * @return an `IO` that, when run, yields the numeric value of `expr`
  */
def eval(expr: Expression): IO[Double] =
  // Pure, non-effectful interpreter; only its final invocation is wrapped in IO.
  def evalPure(node: Expression): Double = node match
    case Add(l, r)  => evalPure(l) + evalPure(r)
    case Mult(l, r) => evalPure(l) * evalPure(r)
    case Exp(inner) => scala.math.exp(evalPure(inner))
    case Const(v)   => v
  end evalPure

  IO.delay(evalPure(expr))
end eval