scala

nested case class generic filter method


I have some nested case classes that may look like the following:

case class Lowest(intVal: Int, stringVal: String)
case class Mid   (lowestSeq: Seq[Lowest])
case class High  (midSeq: Seq[Mid])

So if I had an instance of Mid, I could easily do:

myMid.lowestSeq.filter(_.intVal == 42)

but I would like to create a custom filter method that filters Lowest elements by an arbitrary condition on their fields, chosen by the caller. I don't mind writing custom methods like the following:

case class Mid   (lowestSeq: Seq[Lowest]){
  def filterLowest(/*predicate*/): Seq[Lowest] = {
    lowestSeq.filter(/*predicate*/)
  }
}

case class High   (midSeq: Seq[Mid]){
  def filterLowest(/*predicate*/): Seq[Lowest] = {
    midSeq.foldLeft(Seq.empty[Lowest]) { case (acc, mid) =>
      acc ++ mid.filterLowest(/*predicate*/)
    }
  }
}

But I'm having a hard time understanding how to define the predicate.

For example, when calling it on an instance of High, I want to be able to do the following:

val allIntValsEq42    = myHigh.filterLowest(/*???*/intVal == 42)
val allStringValsEqYo = myHigh.filterLowest(/*???*/stringVal == "Yo")

How can you pass in a predicate in this way for the Mid or High class?


Solution

  • Well, all you need is a predicate for Lowest. A predicate is nothing more than a function that returns a Boolean, so the parameter type is simply:

    predicate: Lowest => Boolean
    

    We can also improve your setup a little to make it more efficient:

    final case class Lowest (intVal: Int, stringVal: String)
    final case class Mid (lowestSeq: Seq[Lowest])
    
    final case class High (midSeq: Seq[Mid]) {
      def filterLowest(predicate: Lowest => Boolean): Seq[Lowest] =
        midSeq.flatMap(mid => mid.lowestSeq.filter(predicate))
    }
    

    You can then use it just like filtering the Seq directly:

    // Using placeholder syntax:
    myHigh.filterLowest(_.stringVal == "Yo")
    
    // Using full lambda syntax:
    myHigh.filterLowest(low => low.stringVal == "Yo")
    
    // Using an auxiliary function:
    def lowChecker(low: Lowest): Boolean = low.intVal == 42 // example condition; put your own logic here
    myHigh.filterLowest(lowChecker)
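Since the question also asked for a filterLowest on Mid, here is a self-contained sketch that gives Mid its own predicate-taking method and lets High delegate to it via flatMap. The sample data and the predicate names isFortyTwo and isYo are hypothetical, chosen only to illustrate that predicates are ordinary function values you can name, reuse, and combine.

```scala
// Sketch: predicate-based filtering over nested case classes.
final case class Lowest(intVal: Int, stringVal: String)

final case class Mid(lowestSeq: Seq[Lowest]) {
  // Keep only the Lowest elements of this Mid that satisfy `predicate`.
  def filterLowest(predicate: Lowest => Boolean): Seq[Lowest] =
    lowestSeq.filter(predicate)
}

final case class High(midSeq: Seq[Mid]) {
  // Delegate to each Mid and concatenate the results, preserving order.
  def filterLowest(predicate: Lowest => Boolean): Seq[Lowest] =
    midSeq.flatMap(_.filterLowest(predicate))
}

object FilterLowestDemo {
  // Hypothetical sample data, for illustration only.
  val myHigh: High = High(Seq(
    Mid(Seq(Lowest(42, "Yo"), Lowest(1, "a"))),
    Mid(Seq(Lowest(42, "b"), Lowest(7, "Yo")))
  ))

  // Predicates are plain function values of type Lowest => Boolean.
  val isFortyTwo: Lowest => Boolean = _.intVal == 42
  val isYo: Lowest => Boolean = low => low.stringVal == "Yo"

  def main(args: Array[String]): Unit = {
    println(myHigh.filterLowest(isFortyTwo)) // List(Lowest(42,Yo), Lowest(42,b))
    println(myHigh.filterLowest(isYo))       // List(Lowest(42,Yo), Lowest(7,Yo))
    // Combine two predicates into a new one inline.
    println(myHigh.filterLowest(l => isFortyTwo(l) && isYo(l))) // List(Lowest(42,Yo))
  }
}
```

Because High.filterLowest reuses Mid.filterLowest, the filtering logic lives in one place, and flatMap avoids rebuilding an accumulator on every step as the foldLeft version does.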