use std::{
    sync::{mpsc, Arc, Mutex},
    thread,
};

/// A unit of work that can be executed on a pool worker and yields an `O`.
pub trait Task<O>
where
    O: Send + 'static,
{
    /// Runs the task and returns its output.
    fn process(&self) -> O;
}

/// Any sendable `Fn() -> O` closure is a task: processing it calls the closure.
impl<F, O> Task<O> for F
where
    F: Fn() -> O,
    F: Send + 'static,
    O: Send + 'static,
{
    fn process(&self) -> O {
        self()
    }
}

/// A type-erased, heap-allocated task that can be moved to a worker thread.
type BoxTask<O> = Box<dyn Task<O> + Send + 'static>;
/// Receiving end of the task queue, shared by every worker.
type TaskReceiver<O> = Arc<Mutex<mpsc::Receiver<BoxTask<O>>>>;
/// Sending end of the task queue, owned by the pool.
type TaskSender<O> = mpsc::Sender<BoxTask<O>>;
/// Sending end of the result channel, shared by every worker.
type OutputSender<O> = Arc<Mutex<mpsc::Sender<O>>>;
/// Receiving end of the result channel, owned by the pool.
type OutputReceiver<O> = mpsc::Receiver<O>;

/// Channel endpoints handed to each worker thread.
struct Shared<O>
where
    O: Send + 'static,
{
    incoming: TaskReceiver<O>,
    output: OutputSender<O>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would demand
// `O: Clone`, but only the `Arc` handles are duplicated here, never an `O`.
impl<O> Clone for Shared<O>
where
    O: Send + 'static,
{
    fn clone(&self) -> Self {
        Self {
            incoming: Arc::clone(&self.incoming),
            output: Arc::clone(&self.output),
        }
    }
}

/// A fixed-size pool of worker threads that run [`Task`]s and collect their
/// outputs on a result channel.
///
/// Tasks are queued with [`ThreadPool::add_task`] and their outputs are polled
/// with [`ThreadPool::next_result`]. Results arrive in completion order, which
/// is not necessarily submission order.
///
/// Workers are not joined when the pool is dropped. Dropping the pool closes
/// both channels, and each worker leaves its loop the next time receiving a
/// task or sending a result fails.
pub struct ThreadPool<O>
where
    O: Send + 'static,
{
    workers: Vec<thread::JoinHandle<()>>,
    shared: Shared<O>,
    tasks: TaskSender<O>,
    results: OutputReceiver<O>,
}

impl<O> ThreadPool<O>
where
    O: Send + 'static,
{
    /// Creates a pool and spawns `size` worker threads named `worker-0`,
    /// `worker-1`, and so on.
    ///
    /// With `size == 0` no workers are spawned, so queued tasks are never run.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to spawn a worker thread.
    pub fn with_threads(size: usize) -> Self {
        let (tasks, incoming) = mpsc::channel();
        let (output, results) = mpsc::channel();

        let shared = Shared {
            incoming: Arc::new(Mutex::new(incoming)),
            output: Arc::new(Mutex::new(output)),
        };

        let mut pool = Self {
            workers: Vec::with_capacity(size),
            shared,
            tasks,
            results,
        };
        for worker_id in 0..size {
            pool.spawn_worker(worker_id);
        }
        pool
    }

    /// Queues `task` for execution on the next free worker.
    ///
    /// # Panics
    ///
    /// Panics if the task queue has been disconnected.
    pub fn add_task(&self, task: impl Task<O> + Send + 'static) {
        self.tasks
            .send(Box::new(task))
            .expect("Failed to enqueue task");
    }

    /// Returns the next finished output without blocking.
    ///
    /// Yields `Ok(Some(output))` when a result is ready and `Ok(None)` when
    /// none is available yet.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`mpsc::TryRecvError`] for any failure other
    /// than the channel being momentarily empty.
    pub fn next_result(&self) -> Result<Option<O>, mpsc::TryRecvError> {
        match self.results.try_recv() {
            Ok(output) => Ok(Some(output)),
            Err(mpsc::TryRecvError::Empty) => Ok(None),
            Err(e) => Err(e),
        }
    }

    /// Spawns one worker thread that repeatedly takes a task, runs it, and
    /// sends the output back, until either channel is closed.
    fn spawn_worker(&mut self, worker_id: usize) {
        let shared = self.shared.clone();

        let worker = thread::Builder::new()
            .name(format!("worker-{}", worker_id))
            .spawn(move || loop {
                let task = {
                    // The guard lives only for this block, so the queue lock
                    // is released before the task runs and other workers can
                    // pick up work in parallel.
                    let incoming = shared.incoming.lock().expect("Lock was poisoned");
                    match incoming.recv() {
                        Ok(task) => task,
                        // Every task sender is gone: no more work will arrive.
                        Err(_) => break,
                    }
                };

                let output = task.process();

                // The output lock is held only for the duration of the send.
                let receiver_gone = shared
                    .output
                    .lock()
                    .expect("Lock was poisoned")
                    .send(output)
                    .is_err();
                if receiver_gone {
                    break;
                }
            })
            .expect("Failed to spawn thread pool worker");

        self.workers.push(worker);
    }
}