apache-spark, pyspark

How to use PySpark UDF in Java / Scala Spark project


There are a lot of questions like "How to call Java code from PySpark", but almost none about calling Python code from a Java Spark project. This is useful for big old Java projects that need functionality that was implemented in Python.


Solution

  • I've also shared the answer on my Medium blog.

    As you know, Apache Spark is written in Scala. PySpark is not a separate, fully Python project. There is an org.apache.spark.deploy.PythonRunner class that:

    • runs a Py4J server,
    • exports the Py4J server's credentials via environment variables,
    • launches the provided Python script, and
    • waits for the script's process to finish.

    In turn, when a SparkContext is created in the Python script, it connects to that Py4J server using the credentials from the environment variables. Py4J allows you to use any JVM object via the Java Reflection API. In other words, PySpark is a wrapper around the Java SparkContext.
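
    For illustration, here is a minimal sketch (not part of the answer's code) of what a script launched by PythonRunner can do; the environment variable names are the ones PySpark's own launch_gateway() looks for:

    #!/usr/bin/env python
    # attach_sketch.py -- hypothetical illustration
    import os

    from pyspark.java_gateway import launch_gateway

    # PythonRunner exports the Py4J server's port and secret before it starts
    # the script, so both variables are expected to be present here.
    assert "PYSPARK_GATEWAY_PORT" in os.environ
    assert "PYSPARK_GATEWAY_SECRET" in os.environ

    # With the variables set, launch_gateway() attaches to the already running
    # JVM instead of spawning a new one.
    gateway = launch_gateway()

    # Any JVM class is now reachable through reflection.
    gateway.jvm.java.lang.System.out.println("hello from Python via Py4J")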

    Example of a simple Java app that uses Apache Spark’s Python Runner:

    package example.python;
    
    import org.apache.spark.deploy.PythonRunner;
    import org.apache.spark.sql.SparkSession;
    
    public class Main {
    
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                    .appName("Shared Spark Context Example")
                    .master("local[*]")
                    .getOrCreate();
            spark.sparkContext().setLogLevel("ERROR");
    
            // PythonRunner.main(pythonFile, pyFiles, [script args...]):
            // the path is passed twice because the second element is the
            // comma-separated list of files to add to PYTHONPATH
            PythonRunner.main(new String[]{
                    "src/main/python/example.py",
                    "src/main/python/example.py"
            });
    
            spark.stop();
        }
    }
    

    But if you try to initialize a Spark session in example.py, you get an exception: only one SparkContext may exist in a JVM process. So the first question is: how do we pass the existing Java SparkContext to PySpark? And the next one: how do we share a DataFrame with PySpark?
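
    For example, a naive example.py like this (just a sketch of the failing approach, not the final code) crashes when launched by the Java app above:

    # example.py -- the naive attempt that fails
    from pyspark.sql import SparkSession

    # PySpark asks the JVM to construct a brand-new SparkContext here, and the
    # JVM refuses because the Java app has already created one.
    spark = SparkSession.builder.getOrCreate()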

    To share an existing SparkContext, you need to connect to the JVM over the Py4J gateway, expose an instance of org.apache.spark.api.java.JavaSparkContext via a public static variable, and initialize pyspark.conf.SparkConf from JavaSparkContext#getConf().

    A DataFrame can be shared through Spark's temporary view functionality.

    Here is the updated code for Java:

    package example.python;
    
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.deploy.PythonRunner;
    import org.apache.spark.sql.SparkSession;
    
    public class Main {
    
        // exposed so that the Python script can reach it over the Py4J gateway
        public static JavaSparkContext jsc;
    
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                    .appName("Spark Python Runner")
                    .master("local[*]")
                    .getOrCreate();
            spark.sparkContext().setLogLevel("ERROR");
            jsc = new JavaSparkContext(spark.sparkContext());
    
            // register the dataset as a temporary view so that the Python
            // script can read it from the shared SparkSession
            var df = spark.read().textFile("src/main/resources/dataset.txt");
            df.createOrReplaceTempView("tbl");
    
    
            PythonRunner.main(new String[]{
                    "src/main/python/example.py",
                    "src/main/python/example.py"
            });
    
            // the view now contains whatever the Python script put into it
            spark.sql("SELECT * FROM tbl").show();
    
            spark.stop();
        }
    }
    

    And Python:

    #!/usr/bin/env python
    # coding: utf-8
    import sys
    
    import pyspark
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import length, udf
    from pyspark.sql.types import StringType
    
    if __name__ == "__main__":
        # attach to the JVM started by the Java app; PythonRunner has already
        # exported the Py4J port and secret via environment variables
        gateway = pyspark.java_gateway.launch_gateway()
        # the JavaSparkContext exposed as a public static field of Main
        jsc = gateway.jvm.example.python.Main.jsc
        # build a Python SparkConf around the existing Java configuration
        conf = pyspark.conf.SparkConf(True, gateway.jvm, jsc.getConf())

        # wrap the existing JavaSparkContext instead of creating a new one
        sc = pyspark.SparkContext(gateway=gateway, jsc=jsc, conf=conf)
        spark = SparkSession(sc)
    
        # read the DataFrame shared by the Java side
        df = spark.sql("SELECT * FROM tbl")

        df = df.withColumn("len", length('value'))

        # publish the result back through the shared temporary view
        df.createOrReplaceTempView("tbl")

        sys.exit(0)
    

    There is more: it's possible to register a Python UDF in PySpark and call it from the Java code afterwards.

    Python:

    # ...
    # register the lambda as a SQL function in the shared SparkSession,
    # so the Java side can call it by name
    py_concat_of2_udf = udf(lambda x, y: str(x) + str(y), StringType())
    spark.udf.register("py_concat_of2", py_concat_of2_udf)
    # ...
    

    Java:

    // ...
    // callUDF and col are static imports from org.apache.spark.sql.functions
    spark.sql("SELECT * FROM tbl")
            .withColumn("pyfunc", callUDF("py_concat_of2", col("value"), col("len")))
            .show();
    // ...
    

    The output of the code:

    +----------+---+------------+
    |     value|len|      pyfunc|
    +----------+---+------------+
    |       one|  3|        one3|
    |       two|  3|        two3|
    |three four| 10|three four10|
    |      five|  4|       five4|
    +----------+---+------------+
    

    How does it work? There is an org.apache.spark.sql.catalyst.expressions.PythonUDF Scala class that contains an org.apache.spark.api.python.PythonFunction object. That object has a command: Seq[Byte] field, which is the Python lambda serialized with Pickle.
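
    To get a feel for that payload, here is a small illustrative sketch (not code from Spark itself; the real command also carries extra metadata such as the UDF's return type):

    import pickle

    from pyspark.serializers import CloudPickleSerializer

    # PySpark relies on cloudpickle, so even a lambda together with its closure
    # can be turned into a byte string...
    command = CloudPickleSerializer().dumps(lambda x, y: str(x) + str(y))
    print(type(command), len(command))

    # ...and turned back into a callable on the other side.
    restored = pickle.loads(command)
    assert restored("three ", 4) == "three 4"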

    The downside of this approach is a stack trace printed to stdout for every action:

    ERROR DAGScheduler: Failed to update accumulator 37 (org.apache.spark.api.python.PythonAccumulatorV2) for task 0
    java.net.ConnectException: Connection refused
     at java.base/sun.nio.ch.Net.connect0(Native Method)
     at java.base/sun.nio.ch.Net.connect(Net.java:579)
     at java.base/sun.nio.ch.Net.connect(Net.java:568)
     at java.base/sun.nio.ch.NioSocketImpl.connect(NioSocketImpl.java:588)
     at java.base/java.net.SocksSocketImpl.connect(SocksSocketImpl.java:327)
     at java.base/java.net.Socket.connect(Socket.java:633)
     at java.base/java.net.Socket.connect(Socket.java:583)
     at java.base/java.net.Socket.<init>(Socket.java:507)
     at java.base/java.net.Socket.<init>(Socket.java:287)
     at org.apache.spark.api.python.PythonAccumulatorV2.openSocket(PythonRDD.scala:701)
     at org.apache.spark.api.python.PythonAccumulatorV2.merge(PythonRDD.scala:723)
     at org.apache.spark.scheduler.DAGScheduler.$anonfun$updateAccumulators$1(DAGScheduler.scala:1610)
     at org.apache.spark.scheduler.DAGScheduler.$anonfun$updateAccumulators$1$adapted(DAGScheduler.scala:1601)
     at scala.collection.immutable.List.foreach(List.scala:333)
     at org.apache.spark.scheduler.DAGScheduler.updateAccumulators(DAGScheduler.scala:1601)
     at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:1749)
     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2857)
     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
     at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    

    The org.apache.spark.api.python.PythonAccumulatorV2 object is created by pyspark.SparkContext and is used for Apache Spark metrics. It tries to connect back to an update server in the Python process, but by the time the Java side runs its actions that process has already exited, hence the Connection refused.