I created a function that combs strings for mentions of countries. It is based on a .txt file that contains the many different ways people mention a country in text. The file looks like this:
"afghanistan": ["afghan", "afghans"], "albania": ["albanian", "albanians"], "algeria": ["algerian", "algerians"], "angola": ["angolan", "angolans"],
...
and so on, for every country on earth.
I then created a function that combs the string and searches for those mentions - but it runs a bit slow on large datasets, and I really want to make it run faster, but I don't know how.
The function looks like this:
import json
import string
from re import sub
from typing import List, Union
def find_countries(text: str, exclude: Union[str, List[str]] = [], extra: Union[str, List[str]] = []) -> Union[List[str], str]:
    """
    Parameters
    ----------
    `text` : `str`
        The text to extract countries from.
    `exclude` : `list or str`
        Optional. Countries to exclude from search.
    `extra` : `list or str`
        Optional. Additional terms to search for (usually orgs).
    """
    # Load country names from file
    with open('country_names.txt') as file:
        country_names = json.load(file)
    # Convert 'exclude' and 'extra' to lists
    exclude = [exclude] if isinstance(exclude, str) else exclude
    extra = [extra] if isinstance(extra, str) else extra
    # Include 'extra' countries or orgs
    for i in extra:
        country_names[i.lower()] = []
    # Remove 'exclude' countries using set operations
    exclude_set = set(exclude)
    countries = {country for country in country_names.keys() if country.lower() not in exclude_set}
    # Clean and preprocess the input text
    my_punct = string.punctuation + '”“'
    replace_punct_string = "['’-]"
    text = sub(replace_punct_string, " ", text)
    text = text.translate(str.maketrans('', '', my_punct)).lower()
    # Search for country mentions using a set comprehension
    countries_mentioned = {country for country in countries
                           if any(f' {name} ' in f' {text} ' for name in {country} | set(country_names[country]))}
    return list(countries_mentioned)
The function receives a string and combs it for mentions of countries, which it then returns as a list of countries. I usually apply it to a Pandas Series.
I think the code as it is now is "fine" - it isn't long and it does the job. I hope you can help me make it run faster, so that when I apply it to tens of thousands of texts it won't take years to finish. Also, any tips on writing better code would help a lot!
You do a lot of converting on the fly, which seems completely unnecessary. If you only use set functionality, you should pass sets in the first place: as far as I can see you don't need the ordering of the lists, so just put sets into the arguments rather than lists. That way you can drop all the conversion code inside the function.
Additionally, if the file is not too large and you call the function many times, you can save a lot of time by loading the data only once and keeping it in memory, instead of re-reading the file on every call. You could, e.g., create a data structure that loads the data automatically and caches it to prevent reloads. The @property decorator is well-suited for such use cases.
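A minimal sketch of such a lazily-cached loader (the class and attribute names here are my own, not from your code):

```python
import json


class CountryData:
    """Loads the country file once and caches it for all later calls."""

    def __init__(self, path: str = "country_names.txt"):
        self._path = path
        self._data = None  # not loaded yet

    @property
    def data(self) -> dict:
        # Load lazily on first access, then reuse the cached dict.
        if self._data is None:
            with open(self._path) as file:
                self._data = json.load(file)
        return self._data
```

A module-level instance of this class then serves every call to your function without touching the disk again.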
I would also create a dictionary which maps the variants to the correct value. Something like
{
"afghan": "afghanistan",
"afghans": "afghanistan",
# ...
}
With this you can save / outsource one loop in your function.
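Building that lookup table from the loaded data could look like this (the `country_names` dict below is a tiny stand-in for the parsed file):

```python
country_names = {
    "afghanistan": ["afghan", "afghans"],
    "albania": ["albanian", "albanians"],
}

# Map every variant (and the country name itself) back to its country.
variant_to_country = {
    variant: country
    for country, variants in country_names.items()
    for variant in [country, *variants]
}
```

With this table, each word found in the text resolves to its country in a single dict lookup.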
One warning though: you should almost never use an empty list as a default argument value. Here is why - found at this SO Post.
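The pitfall, in short: the default list is created once at function definition time and shared across all calls, so a `None` sentinel is the usual fix:

```python
def buggy(items=[]):
    # The SAME list object is reused across calls, so state leaks.
    items.append("x")
    return items


def safe(items=None):
    # Use None as the sentinel and create a fresh list per call.
    if items is None:
        items = []
    items.append("x")
    return items
```

Calling `buggy()` twice returns `["x"]` and then `["x", "x"]`, while `safe()` returns `["x"]` every time.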
Actually, this "flipped" dictionary turns out not to be helpful. As you mentioned, there is also the problem of matching subwords, e.g. "Oman" in "woman". You can prevent this, and possibly even speed things up a bit, by using regex (I didn't run a performance test).
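The word-boundary behaviour can be seen with the stdlib `re` module (the third-party `regex` module handles `\b` the same way):

```python
import re

# \b anchors the match at word boundaries, so "oman" no longer
# matches inside "woman".
pattern = re.compile(r"\b(oman|omani|omanis)\b", re.IGNORECASE)

print(bool(pattern.search("A woman walked by")))    # → False
print(bool(pattern.search("He travelled to Oman")))  # → True
```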
import json
from typing import Optional

import regex


class CountryProvider:
    def __init__(self):
        self._countries: Optional[set[str]] = None
        self._patterns: Optional[dict[str, regex.Pattern]] = None

    def _load_countries(self):
        with open("country_names.txt") as file:
            countries = json.load(file)
        self._patterns = {
            country: regex.compile(
                r"\b(" + "|".join([country, *variants]) + r")\b",
                regex.IGNORECASE,
            )
            for country, variants in countries.items()
        }
        self._countries = set(countries.keys())

    @property
    def countries(self) -> set[str]:
        if self._countries is None:
            self._load_countries()
        return self._countries

    @property
    def patterns(self) -> dict[str, regex.Pattern]:
        if self._patterns is None:
            self._load_countries()
        return self._patterns


COUNTRY_PROVIDER = CountryProvider()


def find_countries(
    text: str,
    exclude: Optional[set[str]] = None,
    extra: Optional[dict[str, list[str]]] = None,
) -> list[str]:
    # Use empty containers if 'exclude' / 'extra' were not given
    exclude = exclude or set()
    extra = extra or {}

    countries = COUNTRY_PROVIDER.countries
    patterns = COUNTRY_PROVIDER.patterns

    extra_patterns = {
        country: regex.compile(
            r"\b(" + "|".join([country, *variants]) + r")\b",
            regex.IGNORECASE,
        )
        for country, variants in extra.items()
        if country not in exclude
    }

    mentioned_countries: list[str] = []
    for country in countries:
        if country in exclude:
            continue
        # IGNORECASE is already baked into the compiled pattern
        if patterns[country].search(text) is not None:
            mentioned_countries.append(country)
    for country, pattern in extra_patterns.items():
        if pattern.search(text) is not None:
            mentioned_countries.append(country)
    return mentioned_countries
Note that the patterns dictionary contains a regex pattern for each country which should match all variants.