108 lines
2.8 KiB
Rust
108 lines
2.8 KiB
Rust
use std::{
|
|
thread,
|
|
sync::{Arc, Mutex, mpsc},
|
|
};
|
|
|
|
pub trait Task<O>
|
|
where O: Send + 'static
|
|
{
|
|
fn process(&self) -> O;
|
|
}
|
|
|
|
impl<F, O> Task<O> for F
|
|
where F: Fn() -> O,
|
|
F: Send + 'static,
|
|
O: Send + 'static
|
|
{
|
|
fn process(&self) -> O {
|
|
self()
|
|
}
|
|
}
|
|
|
|
type BoxTask<O> = Box<dyn Task<O> + Send + 'static>;
|
|
type TaskReceiver<O> = Arc<Mutex<mpsc::Receiver<BoxTask<O>>>>;
|
|
type TaskSender<O> = mpsc::Sender<BoxTask<O>>;
|
|
type OutputSender<O> = Arc<Mutex<mpsc::Sender<O>>>;
|
|
type OutputReceiver<O> = mpsc::Receiver<O>;
|
|
|
|
#[derive(Clone)]
|
|
struct Shared<O>
|
|
where O: Send + Clone + 'static
|
|
{
|
|
incoming: TaskReceiver<O>,
|
|
output: OutputSender<O>,
|
|
}
|
|
|
|
pub struct ThreadPool<O>
|
|
where O: Send + Clone + 'static
|
|
{
|
|
workers: Vec<thread::JoinHandle<()>>,
|
|
shared: Shared<O>,
|
|
tasks: TaskSender<O>,
|
|
results: OutputReceiver<O>,
|
|
}
|
|
|
|
impl<O> ThreadPool<O>
|
|
where O: Send + Clone + 'static
|
|
{
|
|
pub fn with_threads(size: usize) -> Self {
|
|
let (tasks, incoming) = mpsc::channel();
|
|
let (output, results) = mpsc::channel();
|
|
let shared = Shared {
|
|
incoming: Arc::new(Mutex::new(incoming)),
|
|
output: Arc::new(Mutex::new(output)),
|
|
};
|
|
let mut pool = Self {
|
|
workers: Vec::new(),
|
|
shared,
|
|
tasks,
|
|
results
|
|
};
|
|
|
|
for worker_id in 0..size {
|
|
pool.spawn_worker(worker_id);
|
|
}
|
|
|
|
pool
|
|
}
|
|
|
|
pub fn add_task(&self, task: impl Task<O> + Send + 'static) {
|
|
self.tasks.send(Box::new(task)).expect("Failed to enqueue task");
|
|
}
|
|
|
|
pub fn next_result(&self) -> Result<Option<O>, mpsc::TryRecvError> {
|
|
match self.results.try_recv() {
|
|
Ok(output) => Ok(Some(output)),
|
|
Err(mpsc::TryRecvError::Empty) => Ok(None),
|
|
Err(e) => Err(e),
|
|
}
|
|
}
|
|
|
|
fn spawn_worker(&mut self, worker_id: usize) {
|
|
let shared = self.shared.clone();
|
|
let worker = thread::Builder::new()
|
|
.name(format!("worker-{}", worker_id))
|
|
.spawn(move || {
|
|
loop {
|
|
let task = {
|
|
let incoming = shared.incoming.lock().expect("Lock was poisoned");
|
|
match incoming.recv() {
|
|
Ok(task) => task,
|
|
Err(_) => break,
|
|
}
|
|
};
|
|
|
|
let output = task.process();
|
|
|
|
let output_sender = shared.output.lock().expect("Lock was poisoned");
|
|
if let Err(_) = output_sender.send(output) {
|
|
break;
|
|
}
|
|
}
|
|
})
|
|
.expect("Failed to spawn thread pool worker");
|
|
|
|
self.workers.push(worker);
|
|
}
|
|
}
|