My use case is as follows: I have a server that handles requests, and for each request I'd like the log record to contain a user_id
. I'd like to make it as seamless as possible for other developers in my team, such that they can simply import logging
and use it, without passing user_id around.
Here's an MWE, but, as you can see, it's not always working:
import logging
import threading
import time
class ContextFilter(logging.Filter):
    """Filter that stamps every LogRecord with a ``user_id`` attribute.

    NOTE(review): each instance creates its OWN threading.local() in
    __init__, so the stored user_id is visible only to the thread that
    constructed this filter; any other thread falls back to 'NoValue'.
    (This is the bug the question is about.)
    """

    def __init__(self, user_id: str):
        super().__init__()
        self.local = threading.local()
        # Set on the constructing thread only — thread-local storage.
        self.local.user_id = user_id

    def filter(self, record):
        # 'NoValue' when the *current* thread never stored a user_id here.
        record.user_id = getattr(self.local, 'user_id', 'NoValue')
        return True  # Returning True ensures the log message is processed
# Set up logging
# Module-level logger shared by every worker thread below.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)
# Add a handler
# The format string references %(user_id)s, so every record MUST carry a
# user_id attribute — that is what ContextFilter injects.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(user_id)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
# logger.info("This is a test message.")
def worker(user_id: str):
    # NOTE(review): every call adds ANOTHER filter to the shared logger,
    # and each new filter's threading.local() has user_id set only on the
    # thread that constructed it — which produces the 'NoValue' lines in
    # the output shown after this snippet.
    logger.addFilter(ContextFilter(user_id=user_id))
    logger.info("message 1")
    time.sleep(0.5)
    logger.info("message 2")
# Run two workers with different ids concurrently to expose the problem.
t1 = threading.Thread(target=worker, args=("user1",))
t2 = threading.Thread(target=worker, args=("user2",))
t1.start()
t2.start()
t1.join()
t2.join()
This is the output:
2024-11-20 17:46:39,780 - INFO - user1 - message 1
2024-11-20 17:46:39,780 - INFO - user2 - message 1
2024-11-20 17:46:40,284 - INFO - NoValue - message 2
2024-11-20 17:46:40,285 - INFO - user2 - message 2
What am I missing here?
Each thread is creating its own thread-local object, which is wrong; a thread-local object should be created only once and then accessed from many threads.
Just have one filter, and have each thread modify its thread-local version of this id. (Although all threads read and write to the same threading.local()
object, each thread sees different attributes when it reads or writes to it — that's what thread-local means.)
import logging
import threading
import time
class ContextFilter(logging.Filter):
    """Attach a per-thread ``user_id`` attribute to every log record.

    A single instance is shared by all threads; because the storage is a
    ``threading.local``, each thread only ever sees the id it stored
    itself via :meth:`set_id`.
    """

    def __init__(self):
        super().__init__()
        # Created exactly once for the shared instance; every thread gets
        # its own independent view of the attributes stored on it.
        self.local = threading.local()

    def set_id(self, user_id):
        """Bind *user_id* to the calling thread only."""
        self.local.user_id = user_id

    def filter(self, record):
        """Stamp the record with this thread's id and keep the record."""
        uid = getattr(self.local, 'user_id', 'NoValue')
        record.user_id = uid
        return True  # Returning True ensures the log message is processed
# Set up logging
logger = logging.getLogger("my_logger")
logger.setLevel(logging.DEBUG)
# Add a handler
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(user_id)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
# ONE filter shared by the whole process: its threading.local() is
# created a single time, and each thread stores its own user_id into it.
global_filter = ContextFilter()
logger.addFilter(global_filter)
def worker(user_id: str):
    """Simulate one request: bind this thread's id, then log twice."""
    global_filter.set_id(user_id)  # affects only the calling thread
    logger.info("message 1")
    time.sleep(0.5)
    logger.info("message 2")
# Same two concurrent workers as the question; each now logs its own id.
t1 = threading.Thread(target=worker, args=("user1",))
t2 = threading.Thread(target=worker, args=("user2",))
t1.start()
t2.start()
t1.join()
t2.join()
INFO - user1 - message 1
INFO - user2 - message 1
INFO - user1 - message 2
INFO - user2 - message 2
(time stamp omitted for privacy)