// Listing metadata: 118 lines, 3.3 KiB, Rust.
use std::sync::atomic::{AtomicU8, Ordering};
|
|
use std::thread;
|
|
|
|
use clap::Parser;
|
|
use raw_sync::locks::*;
|
|
use shared_memory::*;
|
|
|
|
/// Spawns N threads that increment a value to 10 using a mutex
|
|
#[derive(Parser)]
|
|
#[clap(author, version, about)]
|
|
struct Args {
|
|
/// Number of threads to spawn
|
|
num_threads: usize,
|
|
|
|
/// Count to this value
|
|
#[clap(long, short, default_value_t = 50)]
|
|
count_to: u8,
|
|
}
|
|
|
|
fn main() {
|
|
env_logger::init();
|
|
let args = Args::parse();
|
|
|
|
if args.num_threads < 1 {
|
|
eprintln!("num_threads should be 2 or more");
|
|
return;
|
|
}
|
|
let mut threads = Vec::with_capacity(args.num_threads);
|
|
let _ = std::fs::remove_file("mutex_mapping");
|
|
|
|
// Spawn N threads
|
|
for i in 0..args.num_threads {
|
|
let thread_id = i + 1;
|
|
threads.push(thread::spawn(move || {
|
|
increment_value("mutex_mapping", thread_id);
|
|
}));
|
|
}
|
|
|
|
// Wait for threads to exit
|
|
for t in threads.drain(..) {
|
|
t.join().unwrap();
|
|
}
|
|
}
|
|
|
|
fn increment_value(shmem_flink: &str, thread_num: usize) {
|
|
// Create or open the shared memory mapping
|
|
let shmem = match ShmemConf::new().size(4096).flink(shmem_flink).create() {
|
|
Ok(m) => m,
|
|
Err(ShmemError::LinkExists) => ShmemConf::new().flink(shmem_flink).open().unwrap(),
|
|
Err(e) => {
|
|
eprintln!(
|
|
"Unable to create or open shmem flink {} : {}",
|
|
shmem_flink, e
|
|
);
|
|
return;
|
|
}
|
|
};
|
|
|
|
let mut raw_ptr = shmem.as_ptr();
|
|
let is_init: &mut AtomicU8;
|
|
|
|
unsafe {
|
|
is_init = &mut *(raw_ptr as *mut u8 as *mut AtomicU8);
|
|
raw_ptr = raw_ptr.add(8);
|
|
};
|
|
|
|
// Initialize or wait for initialized mutex
|
|
let mutex = if shmem.is_owner() {
|
|
is_init.store(0, Ordering::Relaxed);
|
|
// Initialize the mutex
|
|
let (lock, _bytes_used) = unsafe {
|
|
Mutex::new(
|
|
raw_ptr, // Base address of Mutex
|
|
raw_ptr.add(Mutex::size_of(Some(raw_ptr))), // Address of data protected by mutex
|
|
)
|
|
.unwrap()
|
|
};
|
|
is_init.store(1, Ordering::Relaxed);
|
|
lock
|
|
} else {
|
|
// wait until mutex is initialized
|
|
while is_init.load(Ordering::Relaxed) != 1 {}
|
|
// Load existing mutex
|
|
let (lock, _bytes_used) = unsafe {
|
|
Mutex::from_existing(
|
|
raw_ptr, // Base address of Mutex
|
|
raw_ptr.add(Mutex::size_of(Some(raw_ptr))), // Address of data protected by mutex
|
|
)
|
|
.unwrap()
|
|
};
|
|
lock
|
|
};
|
|
|
|
// Loop until mutex data reaches 10
|
|
loop {
|
|
// Scope where mutex will be locked
|
|
{
|
|
let mut guard = mutex.lock().unwrap();
|
|
// Cast mutex data to &mut u8
|
|
let val: &mut u8 = unsafe { &mut **guard };
|
|
if *val > 5 {
|
|
println!("[thread#{}] done !", thread_num);
|
|
return;
|
|
}
|
|
|
|
// Print contents and increment value
|
|
println!("[thread#{}] Val : {}", thread_num, *val);
|
|
*val += 1;
|
|
|
|
// Hold lock for a second
|
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
|
}
|
|
|
|
// Timeout this thread for a second
|
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
|
}
|
|
}
|