I want to suppress all output printed to stdout from within one specific thread. Here is what I have tried:
import sys, io, time
from threading import Thread
def do_thread_action():
    # Disable stdout.
    # NOTE: this rebinds the process-wide ``sys.stdout`` attribute, not a
    # per-thread value, so prints from every thread are redirected into this
    # StringIO while it is installed.
    sys.stdout = io.StringIO()
    print("don't print this 1")
    time.sleep(1)
    print("don't print this 2")
    time.sleep(1)
    print("don't print this 3")
    # Re-enable stdout.
    # NOTE: this resets to the interpreter's original stream rather than to
    # whatever was installed before, and it is skipped entirely if anything
    # above raises (there is no try/finally).
    sys.stdout = sys.__stdout__
# Start the worker. It swaps sys.stdout right away and only restores it
# after its two 1-second sleeps (roughly t=2s).
thread = Thread(target=do_thread_action)
thread.start()
time.sleep(1.5)
# Print this to stdout
# (at roughly t=1.5s the worker's StringIO is still installed as the global
# sys.stdout, so this line is swallowed as well; that is the problem).
print('Print this')
thread.join()
However, this does not work because `sys.stdout` is global: it is shared by the worker thread and the main thread.
How do I suppress the prints made inside `do_thread_action` in that thread, without suppressing prints made anywhere else?
The solution is to replace the `sys.stdout` object with an object whose `write` method (and, to cover all cases, `flush` method) selects where the output should go for the currently running thread.
That choice can be made per thread by checking the name of the currently running thread.
Here is an almost "production-ready" class that takes care of this, including decorators for patching the code paths whose printing should be controlled:
import time
import threading
import io
from unittest import mock
import sys
# Module-level placeholder. main() assigns a fresh StringIO to it, but
# nothing in this listing reads it; the captured text lives in
# ``select_output.text_io`` instead.
delayed_outputs = None
class SelectOutput:
    """A ``sys.stdout`` stand-in that routes writes per thread.

    A thread that enters a function wrapped with :meth:`filter` is given its
    own destination stream, kept in thread-local storage. Threads whose name
    maps to ``"capture"`` in *config* write into the shared :attr:`text_io`
    buffer; all other threads write to the real ``sys.__stdout__``.
    :meth:`instrument` installs this object as ``sys.stdout`` for the
    duration of a call, so ``print`` ends up in :meth:`write`.

    Args:
        config: mapping from thread name to action. Only the value
            ``"capture"`` is acted upon.
    """

    def __init__(self, config):
        # Shared buffer that receives the output of every capturing thread.
        self.text_io = io.StringIO()
        # Thread-local store; ``ns.stdout`` is the current thread's target.
        self.ns = threading.local()
        self.config = config

    def filter(self, func):
        """Decorate *func* so that it selects its output stream on entry.

        The stream is chosen from the name of the thread running the call.
        The previous choice is restored when the call returns, so repeated or
        nested calls on a reused thread do not leak the selection.
        """
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Threads are identified by name; callers are expected to name
            # them so that they match the keys of ``config``.
            thread_name = threading.current_thread().name
            if self.config.get(thread_name) == "capture":
                target = self.text_io
            else:
                target = sys.__stdout__
            previous = getattr(self.ns, "stdout", None)
            self.ns.stdout = target
            try:
                return func(*args, **kwargs)
            finally:
                if previous is None:
                    del self.ns.stdout
                else:
                    self.ns.stdout = previous

        return wrapper

    def instrument(self, func):
        """Decorate *func* so that ``sys.stdout`` is this object during the call."""
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # mock.patch swaps sys.stdout process-wide and restores the
            # previous object on exit, even if func raises.
            with mock.patch("sys.stdout", self):
                return func(*args, **kwargs)

        return wrapper

    def _current_stream(self):
        # A thread that never entered a ``filter``-wrapped call (e.g. the
        # main thread while ``instrument`` is active) has no ``ns.stdout``.
        # Send its output to the real stdout instead of raising AttributeError.
        return getattr(self.ns, "stdout", sys.__stdout__)

    def write(self, text):
        """Write *text* to the current thread's stream and return its result."""
        return self._current_stream().write(text)

    def flush(self):
        """Flush the current thread's stream."""
        self._current_stream().flush()
# Capture the output of the thread named "2". Other threads running a
# ``filter``-wrapped function write to ``sys.__stdout__``. Keys are thread
# names; ``instrumented`` names each thread after its numeric id.
select_output = SelectOutput(config={"2": "capture"})
@select_output.filter
def target(thread_id):
    """Print an opening line and a closing line for *thread_id*.

    Both waits grow with the id, which staggers the output of the
    different threads.
    """
    opening_delay = 0.1 * thread_id
    closing_delay = 0.2 * thread_id
    time.sleep(opening_delay)
    print(f"At thread ID: {thread_id}")
    time.sleep(closing_delay)
    print(f"closing thread ID: {thread_id}")
@select_output.instrument
def instrumented():
    """Run ``target`` in threads named "1", "2" and "3" and wait for all of them.

    The decorator replaces ``sys.stdout`` with ``select_output`` for the
    duration of the call, so each thread's prints are routed according to
    that thread's name.
    """
    threads = []
    for thread_id in (1, 2, 3):
        # name the threads so that each is identified when they are running:
        thread = threading.Thread(
            target=target, args=(thread_id,), name=str(thread_id)
        )
        threads.append(thread)
        thread.start()
    # Plain loop: join() is called for its side effect, not to build a list.
    for thread in threads:
        thread.join()
def main():
    """Run the threaded demo, then print whatever the capturing threads wrote."""
    global delayed_outputs
    delayed_outputs = io.StringIO()
    instrumented()
    captured = select_output.text_io.getvalue()
    print("Delayed outputs:", captured)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()