scala, apache-spark, scala-quasiquotes, quasiquotes, scalafix

Quasiquotes in Scalafix


Here is Spark 2.4 code using unionAll:

import org.apache.spark.sql.{DataFrame, Dataset}

object UnionRewrite {
  def inSource(
    df1: DataFrame,
    df2: DataFrame,
    df3: DataFrame,
    ds1: Dataset[String],
    ds2: Dataset[String]
  ): Unit = {
    val res1 = df1.unionAll(df2)
    val res2 = df1.unionAll(df2).unionAll(df3)
    val res3 = Seq(df1, df2, df3).reduce(_ unionAll _)
    val res4 = ds1.unionAll(ds2)
    val res5 = Seq(ds1, ds2).reduce(_ unionAll _)
  }
}

In Spark 3+, unionAll is deprecated. Here is the equivalent code using union:

import org.apache.spark.sql.{DataFrame, Dataset}

object UnionRewrite {
  def inSource(
    df1: DataFrame,
    df2: DataFrame,
    df3: DataFrame,
    ds1: Dataset[String],
    ds2: Dataset[String]
  ): Unit = {
    val res1 = df1.union(df2)
    val res2 = df1.union(df2).union(df3)
    val res3 = Seq(df1, df2, df3).reduce(_ union _)
    val res4 = ds1.union(ds2)
    val res5 = Seq(ds1, ds2).reduce(_ union _)
  }
}

The question is: how do I write a Scalafix rule (using quasiquotes) that replaces unionAll with union?

Without quasiquotes I implemented the rule, and it works:

override def fix(implicit doc: SemanticDocument): Patch = {
  def matchOnTree(t: Tree): Patch = {
    t.collect {
      case Term.Apply(
          Term.Select(_, deprecated @ Term.Name(name)),
          _
          ) if config.deprecatedMethod.contains(name) =>
        Patch.replaceTree(
          deprecated,
          config.deprecatedMethod(name)
        )
      case Term.Apply(
          Term.Select(_, _ @Term.Name(name)),
          List(
            Term.AnonymousFunction(
              Term.ApplyInfix(
                _,
                deprecatedAnm @ Term.Name(nameAnm),
                _,
                _
              )
            )
          )
          ) if "reduce".contains(name) && config.deprecatedMethod.contains(nameAnm) =>
        Patch.replaceTree(
          deprecatedAnm,
          config.deprecatedMethod(nameAnm)
        )
    }.asPatch
  }

  matchOnTree(doc.tree)
}
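
Here config.deprecatedMethod is a mapping from a deprecated method name to its replacement; it is not shown above, but a minimal sketch of such a configuration (the class name is illustrative) could be:

// Minimal sketch of the configuration referenced above (class name is an assumption).
// deprecatedMethod maps a deprecated method name to its replacement.
case class UnionRewriteConfig(
  deprecatedMethod: Map[String, String] = Map("unionAll" -> "union")
)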

Solution

  • Try the rule:

    override def fix(implicit doc: SemanticDocument): Patch = {
    
      def isDatasetSubtype(expr: Tree): Boolean =
        expr.symbol.info.flatMap(_.signature match {
          case ValueSignature(tpe)        => Some(tpe)
          case MethodSignature(_, _, tpe) => Some(tpe)
          case _                          => None
        }) match {
          case Some(TypeRef(_, symbol, _)) =>
            Seq("package.DataFrame", "Dataset")
              .map(tp => Symbol(s"org/apache/spark/sql/$tp#"))
              .contains(symbol)
          case _ => false
        }
    
      def mkPatch(ename: Tree): Patch = Patch.replaceTree(ename, "union")
    
      def matchOnTree(t: Tree): Patch =
        t.collect {
            case q"$expr.${ename@q"unionAll"}($expr1)" if isDatasetSubtype(expr) =>
              mkPatch(ename)
    
            // infix application
            case q"$expr ${ename@q"unionAll"} $expr1" /*if isDatasetSubtype(expr)*/ =>
              mkPatch(ename)
        }.asPatch
    
      matchOnTree(doc.tree)
    }
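
    For reference, this fix method is meant to live inside an ordinary semantic rule; a minimal wrapper with the imports it relies on might look like this (the class and rule names are assumptions):

    import scalafix.v1._
    import scala.meta._

    // Illustrative wrapper only; the body of fix is the code shown above.
    class UnionRewrite extends SemanticRule("UnionRewrite") {
      override def fix(implicit doc: SemanticDocument): Patch = {
        // isDatasetSubtype, mkPatch, matchOnTree as above
        Patch.empty
      }
    }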
    

    It transforms

    import org.apache.spark.sql.{DataFrame, Dataset}
    
    object UnionRewrite {
      def inSource(
                    df1: DataFrame,
                    df2: DataFrame,
                    df3: DataFrame,
                    ds1: Dataset[String],
                    ds2: Dataset[String]
                  ): Unit = {
        val res1 = df1.unionAll(df2)
        val res2 = df1.unionAll(df2).unionAll(df3)
        val res3 = Seq(df1, df2, df3).reduce(_ unionAll _)
        val res4 = ds1.unionAll(ds2)
        val res5 = Seq(ds1, ds2).reduce(_ unionAll _)
        val res6 = Seq(ds1, ds2).reduce(_ unionAll (_))
    
        val unionAll = 42
      }
    }
    

    into

    import org.apache.spark.sql.{DataFrame, Dataset}
    
    object UnionRewrite {
      def inSource(
                    df1: DataFrame,
                    df2: DataFrame,
                    df3: DataFrame,
                    ds1: Dataset[String],
                    ds2: Dataset[String]
                  ): Unit = {
        val res1 = df1.union(df2)
        val res2 = df1.union(df2).union(df3)
        val res3 = Seq(df1, df2, df3).reduce(_ union _)
        val res4 = ds1.union(ds2)
        val res5 = Seq(ds1, ds2).reduce(_ union _)
        val res6 = Seq(ds1, ds2).reduce(_ union (_))
    
        val unionAll = 42
      }
    }
    

    https://scalacenter.github.io/scalafix/docs/developers/setup.html

    https://scalameta.org/docs/trees/quasiquotes.html

    https://scalameta.org/docs/semanticdb/guide.html
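
    Following the setup and SemanticDB guides above, wiring the rule into a build with sbt-scalafix might look roughly like this (the rule coordinates and versions below are placeholders, not something from the question):

    // project/plugins.sbt of the project being rewritten
    addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.10.4") // use the current release

    // build.sbt: semantic rules need SemanticDB output
    semanticdbEnabled := true
    semanticdbVersion := scalafixSemanticdb.revision

    // make the published rule available to the scalafix task (placeholder coordinates)
    ThisBuild / scalafixDependencies += "com.example" %% "union-rewrite" % "0.1.0"

    The rule can then be run with sbt "scalafix UnionRewrite".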

    Your Ver. 1 implementation erroneously transformed val unionAll = 42 into val union = 42.

    Sadly, <: Dataset[_] can't be checked for the infix application, since SemanticDB seems not to have type information in this case (an underscore _ in a lambda). This seems to be a SemanticDB limitation. If you really need the subtype check in this case, you might need a compiler plugin.
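
    To see why, it helps to look at how scalameta parses the underscore lambda; a quick sketch (run it e.g. in a REPL; exact tree field names may vary slightly across scalameta versions):

    import scala.meta._

    val stat = "Seq(ds1, ds2).reduce(_ unionAll _)".parse[Stat].get
    // The lambda argument is parsed (roughly) as:
    //   Term.AnonymousFunction(
    //     Term.ApplyInfix(Term.Placeholder(), Term.Name("unionAll"), Nil, List(Term.Placeholder()))
    //   )
    // The left operand is a Term.Placeholder with no name of its own, so there is
    // no symbol for SemanticDB to attach a type to, and isDatasetSubtype has
    // nothing to inspect.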


    Update. We can use multiple rules: first apply a rule replacing underscore lambdas with parameter lambdas

    override def fix(implicit doc: SemanticDocument): Patch = {
      def matchOnTree(t: Tree): Patch =
        t.collect {
          case t1@q"_.unionAll(_)" =>
            Patch.replaceTree(t1, "(x, y) => x.unionAll(y)")
          case t1@q"_ unionAll _" =>
            Patch.replaceTree(t1, "(x, y) => x unionAll y")
        }.asPatch
    
      matchOnTree(doc.tree)
    }
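
    Applied to the example above, this first rule rewrites the underscore lambdas into parameter lambdas, for instance:

    val res3 = Seq(df1, df2, df3).reduce((x, y) => x unionAll y)
    val res5 = Seq(ds1, ds2).reduce((x, y) => x unionAll y)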
    

    then re-compile the code (new .semanticdb files will be generated) and apply the second rule, which replaces unionAll with union (if the types match)

    override def fix(implicit doc: SemanticDocument): Patch = {
    
      def isDatasetSubtype(expr: Tree): Boolean = {
        expr.symbol.info.flatMap(_.signature match {
          case ValueSignature(tpe)        => Some(tpe)
          case MethodSignature(_, _, tpe) => Some(tpe)
          case _                          => None
        }) match {
          case Some(TypeRef(_, symbol, _)) =>
            Seq("package.DataFrame", "Dataset")
              .map(tp => Symbol(s"org/apache/spark/sql/$tp#"))
              .contains(symbol)
          case _ => false
        }
      }
    
      def mkPatch(ename: Tree): Patch = Patch.replaceTree(ename, "union")
    
      def matchOnTree(t: Tree): Patch =
        t.collect {
          case q"$expr.${ename@q"unionAll"}($_)" if isDatasetSubtype(expr) =>
            mkPatch(ename)
          case q"$expr ${ename@q"unionAll"} $_" if isDatasetSubtype(expr) =>
            mkPatch(ename)
        }.asPatch
    
      matchOnTree(doc.tree)
    }
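
    After re-compilation the lambda parameters have SemanticDB symbols of their own, so the infix case can now be checked, for example:

    // intermediate form produced by the first rule:
    //   Seq(ds1, ds2).reduce((x, y) => x unionAll y)
    // x now resolves to a symbol whose signature is Dataset[String], so
    // isDatasetSubtype(x) holds and the second rule rewrites the call to:
    //   Seq(ds1, ds2).reduce((x, y) => x union y)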
    

    then apply the third rule replacing parameter lambdas back with underscore lambdas

    override def fix(implicit doc: SemanticDocument): Patch = {
      def matchOnTree(t: Tree): Patch =
        t.collect {
          case t1@q"(x, y) => x.union(y)" =>
            Patch.replaceTree(t1, "_.union(_)")
          case t1@q"(x, y) => x union y" =>
            Patch.replaceTree(t1, "_ union _")
        }.asPatch
    
      matchOnTree(doc.tree)
    }
    

    The 1st and 3rd rules can be syntactic.
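
    For instance, the first rule as a purely syntactic rule (no SemanticDB needed) could be sketched as follows (the class and rule names are assumptions):

    import scalafix.v1._
    import scala.meta._

    // Syntactic version of the first rule: rewrites underscore lambdas around
    // unionAll into parameter lambdas, without consulting SemanticDB.
    class UnderscoreUnionAllToLambda extends SyntacticRule("UnderscoreUnionAllToLambda") {
      override def fix(implicit doc: SyntacticDocument): Patch =
        doc.tree.collect {
          case t1 @ q"_.unionAll(_)" =>
            Patch.replaceTree(t1, "(x, y) => x.unionAll(y)")
          case t1 @ q"_ unionAll _" =>
            Patch.replaceTree(t1, "(x, y) => x unionAll y")
        }.asPatch
    }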