I have a simple use case: I want to perform some operation on each file in a directory. For example, let's say I just want to print the file name. A very simple single-threaded option for this is below.
Single-threaded implementation
use std::{env, fs, path::Path, rc::Rc};

fn main() {
    let args = env::args().collect::<Vec<String>>();
    let process: ProcessFunc = Rc::new(|p| println!("{p}"));
    process_file_path(String::from("/playground/target/release/deps"), Rc::clone(&process));
}

type ProcessFunc = Rc<dyn Fn(String) + Send + 'static>;

fn process_file_path(source: String, f: ProcessFunc) {
    let source_path = Path::new(&source)
        .canonicalize()
        .expect("invalid source path");
    let source_path_str = source_path.as_path().to_str().unwrap().to_string();
    f(source_path_str);
    if source_path.is_dir() {
        for child in fs::read_dir(source_path).unwrap() {
            let child = child.unwrap();
            process_file_path(
                child.path().as_path().to_str().unwrap().to_string(),
                Rc::clone(&f),
            )
        }
    }
}
Since no file depends on any other file, I want to make this concurrent. To do that I am planning to use a thread pool where each thread waits on a channel to receive file paths and then runs the processing function on them.
The part I am not able to figure out is the termination of the threads. The complexity comes from the fact that each job, i.e. a path, can itself enqueue more jobs: one for the path of each child of the current path. So the question is: how would the threads know when to stop?
I can think of two approaches -
I tried to use two channels so that deciding whether more jobs remain could be delegated to a dedicated single thread (the main thread), but I soon realized the problem stays the same: how do we know that all jobs are complete?
Multi-threaded implementation (this has only the thread pool with job and result channels; there is no usage in the main function, because that is the part I cannot figure out):
use std::{
    sync::{
        mpsc::{channel, Receiver, Sender},
        Arc, Mutex,
    },
    thread::{self, JoinHandle},
};

struct ThreadPool<T> {
    job_sender: Option<Sender<Box<dyn FnOnce() -> T + Send + 'static>>>,
    handles: Vec<JoinHandle<()>>,
}

impl<T> ThreadPool<T>
where
    T: Send + 'static,
{
    fn new(cap: usize) -> (Self, Receiver<T>) {
        assert_ne!(cap, 0);
        let (job_sender, job_receiver) = channel::<Box<dyn FnOnce() -> T + Send + 'static>>();
        let (result_sender, result_receiver) = channel::<T>();
        let result_sender = Arc::new(Mutex::new(result_sender));
        let job_receiver = Arc::new(Mutex::new(job_receiver));
        let mut handles = vec![];
        for _i in 0..cap {
            let job_receiver = Arc::clone(&job_receiver);
            let result_sender = Arc::clone(&result_sender);
            handles.push(thread::spawn(move || loop {
                // Hold the lock only long enough to receive one job; the guard is
                // dropped at the end of this statement, so other workers can keep
                // receiving while this job runs. (A `while let` on recv() would
                // keep the lock held for the whole loop body.)
                let job = job_receiver.lock().unwrap().recv();
                match job {
                    Ok(job) => result_sender.lock().unwrap().send(job()).unwrap(),
                    Err(_) => break,
                }
            }));
        }
        (
            ThreadPool {
                handles,
                job_sender: Some(job_sender),
            },
            result_receiver,
        )
    }

    fn add(&self, job: Box<dyn FnOnce() -> T + Send + 'static>) {
        self.job_sender.as_ref().unwrap().send(job).unwrap();
    }
}

impl<T> Drop for ThreadPool<T> {
    fn drop(&mut self) {
        self.job_sender = None;
        while let Some(handle) = self.handles.pop() {
            handle.join().unwrap();
        }
    }
}
The main problem is that the ThreadPool will not know when all jobs are done. One way to let it know is to make sure all instances of Sender<Box<dyn FnOnce() -> T + Send + 'static>> are dropped when all jobs are done. This works because the receiver returns an error when (and only when) all senders have been dropped, which in turn terminates all threads. We can make sure they are dropped at the right time by only using the sender inside the jobs being executed, in this case inside process_file_path.
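As a minimal, self-contained illustration of the channel behaviour this relies on (this is just the std::sync::mpsc guarantee, not part of the solution below):

use std::{sync::mpsc::channel, thread};

fn main() {
    let (sender, receiver) = channel::<u32>();
    let worker = thread::spawn(move || {
        // recv() blocks while at least one Sender is alive and returns
        // Err(RecvError) once every Sender has been dropped.
        while let Ok(n) = receiver.recv() {
            println!("got {n}");
        }
        println!("all senders dropped, worker exits");
    });
    sender.send(1).unwrap();
    drop(sender); // the last Sender is gone, so the worker's loop ends
    worker.join().unwrap();
}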
One way to achieve this is to not keep the sender inside the ThreadPool, because then there would always be one sender that is never dropped. To keep the abstraction of the ThreadPool, I will create a separate newtype struct (JobSender) that contains the sender. This makes it easy to change the internals of how jobs are created without changing all the code that uses it.
(playground)
use std::{
    fs,
    path::Path,
    sync::{
        mpsc::{channel, Receiver, Sender},
        Arc, Mutex,
    },
    thread::{self, JoinHandle},
};

fn main() {
    let process: ProcessFunc<()> = Arc::new(|p| println!("{p}"));
    let (_thread_pool, _results, job_spawner) = ThreadPool::new(2);
    let job_sender = Arc::new(job_spawner);
    process_file_path(String::from("."), Arc::clone(&process), &job_sender);
}
type ProcessFunc<T> = Arc<dyn Fn(String) -> T + Send + Sync + 'static>;

fn process_file_path<T: Send + 'static>(
    source: String,
    f: ProcessFunc<T>,
    job_sender: &Arc<JobSender<()>>,
) {
    let source_path = Path::new(&source)
        .canonicalize()
        .expect("invalid source path");
    let source_path_str = source_path.as_path().to_str().unwrap().to_string();
    f(source_path_str);
    if source_path.is_dir() {
        for child in fs::read_dir(source_path).unwrap() {
            let child = child.unwrap();
            let job_sender_clone = job_sender.clone();
            let f_clone = f.clone();
            job_sender.add(Box::new(move || {
                process_file_path(
                    child.path().as_path().to_str().unwrap().to_string(),
                    f_clone,
                    &job_sender_clone,
                )
            }))
        }
    }
}
struct JobSender<T>(Sender<Box<dyn FnOnce() -> T + Send + 'static>>);

impl<T> JobSender<T>
where
    T: Send + 'static,
{
    fn add(&self, job: Box<dyn FnOnce() -> T + Send + 'static>) {
        self.0.send(job).unwrap();
    }
}
struct ThreadPool {
    handles: Vec<JoinHandle<()>>,
}

impl ThreadPool {
    fn new<T: Send + 'static>(cap: usize) -> (Self, Receiver<T>, JobSender<T>) {
        assert_ne!(cap, 0);
        let (job_sender, job_receiver) = channel::<Box<dyn FnOnce() -> T + Send + 'static>>();
        let (result_sender, result_receiver) = channel::<T>();
        let result_sender = Arc::new(Mutex::new(result_sender));
        let job_receiver = Arc::new(Mutex::new(job_receiver));
        let mut handles = vec![];
        for _i in 0..cap {
            let job_receiver = Arc::clone(&job_receiver);
            let result_sender = Arc::clone(&result_sender);
            handles.push(thread::spawn(move || loop {
                // Receive under the lock, then drop the guard before running the
                // job so other workers can pick up jobs concurrently.
                let job = job_receiver.lock().unwrap().recv();
                match job {
                    Ok(job) => {
                        let _ = result_sender.lock().unwrap().send(job());
                    }
                    Err(_) => break,
                }
            }));
        }
        (
            ThreadPool { handles },
            result_receiver,
            JobSender(job_sender),
        )
    }
}
impl Drop for ThreadPool {
    fn drop(&mut self) {
        while let Some(handle) = self.handles.pop() {
            handle.join().unwrap();
        }
    }
}
Note that if the results of the calculations are actually needed, the current ThreadPool implementation requires some changes.
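For example, consuming results generally comes down to dropping every Sender handle you hold yourself and then iterating the Receiver, which ends once the worker threads have exited and dropped their Sender clones. A minimal, self-contained sketch of that pattern (independent of the ThreadPool above, with u64 results chosen purely for illustration):

use std::{sync::mpsc::channel, thread};

fn main() {
    let (result_sender, result_receiver) = channel::<u64>();
    let mut handles = Vec::new();
    for i in 0..4u64 {
        let result_sender = result_sender.clone();
        handles.push(thread::spawn(move || {
            // Each worker sends one result; its Sender clone is dropped
            // when the thread finishes.
            result_sender.send(i * 10).unwrap();
        }));
    }
    // Drop the original Sender so the iteration below can terminate.
    drop(result_sender);
    // iter() blocks for each result and ends once every Sender is gone.
    let total: u64 = result_receiver.iter().sum();
    for handle in handles {
        handle.join().unwrap();
    }
    println!("total = {total}");
}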