I'm trying to get the country name from a latitude and longitude, so I used the Nominatim API. When I pass the function as a regular UDF it works, but when I use it as a pandas_udf I get the following error:
An exception was thrown from a UDF: 'RuntimeError: Result vector from pandas_udf was not the required length: expected 1, got 2'
This is my code:

```python
import requests
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType, udf

@pandas_udf("string", PandasUDFType.SCALAR)
def country_name(lat, lon):
    url = f"https://nominatim.openstreetmap.org/reverse?format=json&lat={lat}&lon={lon}"
    response = requests.get(url)
    data = response.json()
    if 'error' in data:
        return 'NA'
    else:
        return data['address']['country']

df = spark.createDataFrame([(40.730610, -73.935242)], ["lat", "lon"])
df = df.withColumn("country", country_name(df["lat"], df["lon"]))
df.show()
```
As I said, with a regular UDF it works; the problem only appears when I use pandas_udf.
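A likely explanation for "expected 1, got 2", sketched with plain pandas so it runs without Spark. It assumes the API rejected the malformed query and the function fell into the `'NA'` branch; the length check shown is a simplified imitation of what Spark does, not Spark's own code.

```python
import pandas as pd

# In a SCALAR pandas_udf, each argument is a pandas Series holding a batch
# of rows, not a single float. Here the batch has one row, as in the question.
lat = pd.Series([40.730610])
lon = pd.Series([-73.935242])

# Interpolating a Series into an f-string embeds its printed form
# (index, value and "dtype: float64"), so the URL no longer holds plain numbers.
url = f"https://nominatim.openstreetmap.org/reverse?format=json&lat={lat}&lon={lon}"
print(url)

# If the API rejects that query, the function returns the plain string 'NA'.
# Spark compares the length of the returned value with the number of input
# rows, and a str has a length too: len('NA') is 2, while the batch has 1 row.
result = "NA"
print(len(result), len(lat))  # 2 1
```

In other words, a SCALAR pandas_udf must return a Series with one element per input row; returning a single string only happens to "work" when its length matches the batch size.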
Refer to the "Series to Scalar" section of the pandas_udf API guide.
Change your code as follows, based on the example given in that guide. The changes are marked with the comment `# Changed`.
```python
import requests
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType, udf

@pandas_udf("string")  # Changed
def country_name(lat: pd.Series, lon: pd.Series) -> str:  # Changed
    # Series -> scalar: only the first coordinate pair of the input is looked up
    url = f"https://nominatim.openstreetmap.org/reverse?format=json&lat={lat.iloc[0]}&lon={lon.iloc[0]}"  # Changed
    response = requests.get(url)
    data = response.json()
    if 'error' in data:
        return 'NA'
    else:
        return data['address']['country']

# `spark` is the active SparkSession (predefined in the pyspark shell and notebooks)
df = spark.createDataFrame([(40.730610, -73.935242)], ["lat", "lon"])
# df = df.withColumn("country", country_name(df["lat"], df["lon"]))  # Changed
df = df.select(country_name(df["lat"], df["lon"]))  # Changed
df.show()
```
However, this approach only works when the function is expected to return a single scalar value for the whole input series.
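To see the limitation without a Spark cluster, the UDF body can be called directly on pandas Series. In this sketch `fake_lookup` is a hypothetical offline stand-in for the Nominatim request, filled only with the two coordinate pairs and countries from the output table further down. As far as I understand the type-hint rules, Spark treats a `pd.Series -> scalar` UDF as an aggregate, which is why the answer uses `select()` instead of `withColumn()`; without a `groupBy` it yields a single row for the entire DataFrame.

```python
import pandas as pd

# Hypothetical offline stand-in for the Nominatim call (no network access).
KNOWN = {
    (40.730610, -73.935242): "United States",
    (45.0, -75.0): "Canada",
}

def fake_lookup(lat: float, lon: float) -> str:
    return KNOWN.get((lat, lon), "NA")

def country_of_first_row(lat: pd.Series, lon: pd.Series) -> str:
    # Same shape as the Series -> scalar UDF: the whole batch comes in,
    # but only the first element of each Series is used.
    return fake_lookup(lat.iloc[0], lon.iloc[0])

lat = pd.Series([40.730610, 45.0])
lon = pd.Series([-73.935242, -75.0])

# Two input rows, one output value: the Canada row is never looked up.
print(country_of_first_row(lat, lon))  # United States
```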
With real data you would expect vectorized behavior: given a DataFrame of lat-lon pairs, you want one result per row. For that, the API would need to accept a list of lat-lon pairs in a single request. If it does not, then, as the following code shows, you have to call the API once for each lat-lon pair, which defeats the purpose of the vectorization that pandas_udf provides.
```python
import requests
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType, udf

def call_url(row):
    url = f"https://nominatim.openstreetmap.org/reverse?format=json&lat={row['lat']}&lon={row['lon']}"
    response = requests.get(url)
    data = response.json()
    if 'error' in data:
        return 'NA'
    else:
        return data['address']['country']

@pandas_udf("string")
def country_name(lat: pd.Series, lon: pd.Series) -> pd.Series:
    # One API call per row of the batch; the result has the same length as the input
    lat_lon_df = pd.DataFrame({"lat": lat, "lon": lon})
    lat_lon_df["country"] = lat_lon_df.apply(call_url, axis=1)
    return lat_lon_df["country"]

df = spark.createDataFrame([(40.730610, -73.935242), (45.0, -75.0)], ["lat", "lon"])
df = df.withColumn("country", country_name(df["lat"], df["lon"]))
df.show()
```
Output:

```
+--------+----------+-------------+
|     lat|       lon|      country|
+--------+----------+-------------+
|40.73061|-73.935242|United States|
|    45.0|     -75.0|       Canada|
+--------+----------+-------------+
```
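Since the per-row version still makes one HTTP request per row, one possible refinement (not part of the original answer) is to look up each distinct coordinate pair only once per batch. This sketch also builds the query string with `urllib.parse.urlencode` instead of an f-string and sends an identifying User-Agent, because Nominatim's usage policy asks for one and limits the public server to one request per second. It uses the standard-library `urllib` rather than `requests`; `USER_AGENT`, `build_reverse_url`, `parse_country`, `fetch_country` and `countries_for_batch` are hypothetical names chosen for illustration.

```python
import json
import time
import urllib.parse
import urllib.request

import pandas as pd

# Hypothetical identifier: replace with something that identifies your application.
USER_AGENT = "my-spark-geocoder/0.1 (you@example.com)"
NOMINATIM_REVERSE = "https://nominatim.openstreetmap.org/reverse"


def build_reverse_url(lat: float, lon: float) -> str:
    # urlencode formats the numbers and escapes the query string.
    query = urllib.parse.urlencode({"format": "json", "lat": lat, "lon": lon})
    return f"{NOMINATIM_REVERSE}?{query}"


def parse_country(data: dict) -> str:
    # Nominatim answers {"error": ...} when it cannot reverse-geocode a point;
    # fall back to "NA" for that case and for responses without a country.
    if "error" in data:
        return "NA"
    return data.get("address", {}).get("country", "NA")


def fetch_country(lat: float, lon: float) -> str:
    # One HTTP request, then a pause to stay under one request per second.
    req = urllib.request.Request(build_reverse_url(lat, lon),
                                 headers={"User-Agent": USER_AGENT})
    with urllib.request.urlopen(req, timeout=10) as resp:
        data = json.load(resp)
    time.sleep(1)
    return parse_country(data)


def countries_for_batch(lat: pd.Series, lon: pd.Series,
                        lookup=fetch_country) -> pd.Series:
    # Call `lookup` once per distinct (lat, lon) pair, then map the results
    # back onto every row so the output length matches the input batch.
    pairs = list(zip(lat, lon))
    cache = {pair: lookup(*pair) for pair in dict.fromkeys(pairs)}
    return pd.Series([cache[pair] for pair in pairs], index=lat.index)
```

Inside the pandas_udf, the body would then reduce to `return countries_for_batch(lat, lon)`. Note that the sleep only throttles a single Python worker; several Spark executors running in parallel can still exceed the public server's limit, so for large datasets a self-hosted Nominatim instance is usually the safer choice.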