Tags: pyspark, databricks, azure-databricks, databricks-connect

Can't repartition RDD when connecting with databricks-connect


When connecting to a Databricks cluster with databricks-connect, I get a Py4JJavaError exception when I call repartition on a simple RDD:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
rdd = spark.sparkContext.parallelize(range(0, 10), 3)  # RDD with 3 partitions
print(rdd.sum())                 # works, prints 45
print(rdd.repartition(5).sum())  # fails with Py4JJavaError

The first print statement executes fine and prints 45, but the second fails with the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-...> in <module>
      5 rdd = spark.sparkContext.parallelize(range(0, 10), 3)
      6 print(rdd.sum())
----> 7 print(rdd.repartition(5).sum())

d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\rdd.py in sum(self)
   1256         6.0
   1257         """
-> 1258         return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
   1259 
   1260     def count(self):

d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\rdd.py in fold(self, zeroValue, op)
   1110         # zeroValue provided to each partition is unique from the one provided
   1111         # to the final reduce call
-> 1112         vals = self.mapPartitions(func).collect()
   1113         return reduce(op, vals, zeroValue)
   1114 

d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\rdd.py in collect(self)
    964         # Default path used in OSS Spark / for non-credential passthrough clusters:
    965         with SCCallSiteSync(self.context) as css:
--> 966             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    967         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
    968 

d:\source\repos\...\.venv_3_8_10\lib\site-packages\py4j\java_gateway.py in __call__(self, *args)
   1302 
   1303         answer = self.gateway_client.send_command(command)
-> 1304         return_value = get_return_value(
   1305             answer, self.gateway_client, self.target_id, self.name)
   1306 

d:\source\repos\...\.venv_3_8_10\lib\site-packages\pyspark\sql\utils.py in deco(*a, **kw)
    115     def deco(*a, **kw):
    116         try:
--> 117             return f(*a, **kw)
    118         except py4j.protocol.Py4JJavaError as e:
    119             converted = convert_exception(e.java_exception)

d:\source\repos\...\.venv_3_8_10\lib\site-packages\py4j\protocol.py in get_return_value(answer, gateway_client, target_id, name)
    324             value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325             if answer[1] == REFERENCE_TYPE:
--> 326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
    328                     format(target_id, ".", name), value)

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: java.lang.ClassCastException: cannot assign instance of org.apache.spark.serializer.KryoSerializer to field org.apache.spark.ShuffleDependency.rowBasedChecksums of type [Lorg.apache.spark.shuffle.checksum.RowBasedChecksum; in instance of org.apache.spark.ShuffleDependency
    at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2301)
    at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1431)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2437)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
    at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
    at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
    at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
    at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:527)
    at sun.reflect.GeneratedMethodAccessor257.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2322)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.readArray(ObjectInputStream.java:2119)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1657)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461)
    at org.apache.spark.sql.util.ProtoSerializer.$anonfun$deserializeObject$1(ProtoSerializer.scala:7058)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
    at org.apache.spark.sql.util.ProtoSerializer.deserializeObject(ProtoSerializer.scala:7043)
    at com.databricks.service.SparkServiceRPCHandler.execute0(SparkServiceRPCHandler.scala:728)
    at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC0$1(SparkServiceRPCHandler.scala:477)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
    at com.databricks.service.SparkServiceRPCHandler.executeRPC0(SparkServiceRPCHandler.scala:372)
    at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:323)
    at com.databricks.service.SparkServiceRPCHandler$$anon$2.call(SparkServiceRPCHandler.scala:309)
    at java.util.concurrent.FutureTask.run(FutureTask.java:266)
    at com.databricks.service.SparkServiceRPCHandler.$anonfun$executeRPC$1(SparkServiceRPCHandler.scala:359)
    at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
    at com.databricks.service.SparkServiceRPCHandler.executeRPC(SparkServiceRPCHandler.scala:336)
    at com.databricks.service.SparkServiceRPCServlet.doPost(SparkServiceRPCServer.scala:167)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:523)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:590)
    at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:799)
    at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:550)
    at org.eclipse.jetty.server.handler.ScopedHandler.nextScope(ScopedHandler.java:190)
    at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:501)
    at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)
    at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:127)
    at org.eclipse.jetty.server.Server.handle(Server.java:516)
    at org.eclipse.jetty.server.HttpChannel.lambda$handle$1(HttpChannel.java:388)
    at org.eclipse.jetty.server.HttpChannel.dispatch(HttpChannel.java:633)
    at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:380)
    at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:277)
    at org.eclipse.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:311)
    at org.eclipse.jetty.io.FillInterest.fillable(FillInterest.java:105)
    at org.eclipse.jetty.io.ChannelEndPoint$1.run(ChannelEndPoint.java:104)
    at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.runTask(EatWhatYouKill.java:338)
    at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.doProduce(EatWhatYouKill.java:315)
    at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.tryProduce(EatWhatYouKill.java:173)
    at org.eclipse.jetty.util.thread.strategy.EatWhatYouKill.run(EatWhatYouKill.java:131)
    at org.eclipse.jetty.util.thread.ReservedThreadExecutor$ReservedThread.run(ReservedThreadExecutor.java:386)
    at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:883)
    at org.eclipse.jetty.util.thread.QueuedThreadPool$Runner.run(QueuedThreadPool.java:1034)
    at java.lang.Thread.run(Thread.java:750)

Instead of the error, I would like the second print statement to execute and also print 45. The same snippet runs fine in a Databricks notebook in the web interface.

Repartitioning DataFrames, on the other hand, works without errors. The snippet below runs fine with databricks-connect and prints 45:

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(0, 10)
df.agg(F.sum("id")).collect()[0]["sum(id)"]
df_repar = df.repartition(5)
print(df_repar.agg(F.sum("id")).collect()[0]["sum(id)"])
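
Since the DataFrame path works, a possible interim workaround (an untested sketch, not something I have verified against the cluster) is to route the repartition through the DataFrame API and only drop back to an RDD afterwards:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
rdd = spark.sparkContext.parallelize(range(0, 10), 3)
# Repartition via the DataFrame API instead of RDD.repartition; the
# assumption is that the shuffle is then planned on the cluster side,
# sidestepping the client-side ShuffleDependency (de)serialization.
df = rdd.map(lambda x: (x,)).toDF(["value"])
repartitioned = df.repartition(5).rdd.map(lambda row: row["value"])
print(repartitioned.sum())  # expected: 45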

I have a fairly large setup using DataFrames that runs without any apparent issues. I would like to move to Databricks Runtime 11.3 LTS, but this issue is preventing me from upgrading.

I run Python 3.8.10 and have verified that the package versions on the cluster match the locally installed ones. I use databricks-connect==10.4.22 and connect to a cluster running Databricks Runtime 10.4 LTS.
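
For reference, one way to check the locally installed client version from Python (a minimal sketch using only the standard library; the cluster side can be read off the cluster configuration page) is:

from importlib.metadata import version  # standard library in Python 3.8+

# Prints the installed databricks-connect distribution version,
# e.g. "10.4.22"; it should match the cluster's Databricks Runtime
# major.minor version.
print(version("databricks-connect"))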


Solution

  • This issue was resolved by upgrading to the newly released databricks-connect==10.4.25. Presumably the older client serialized ShuffleDependency with a class layout that no longer matched the server's (note the rowBasedChecksums field in the ClassCastException), and the point release restored compatibility.
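
For completeness, the upgrade is a plain pip reinstall of the pinned client (the databricks-connect docs recommend not having a standalone pyspark installed alongside it):

pip uninstall -y pyspark  # avoid conflicts with the client bundled in databricks-connect
pip install --upgrade databricks-connect==10.4.25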