Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 133 additions & 126 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{cmp, convert::TryInto, env::current_dir, fs, path::PathBuf, str::FromStr, sync::Arc, thread, time::Instant};
use std::{cmp, convert::TryInto, env::current_dir, fs, path::PathBuf, process, str::FromStr, sync::Arc, thread, time::Instant};

use anyhow::{anyhow, Context as AnyContext};
use clap::Parser;
Expand Down Expand Up @@ -223,152 +223,159 @@ async fn main_inner() -> Result<(), anyhow::Error> {

let submit = true;

#[cfg(not(any(feature = "nvidia", feature = "opencl3")))] {
eprintln!("No GPU engine available");
process::exit(1);
}

#[cfg(feature = "nvidia")]
let mut gpu_engine = GpuEngine::new(CudaEngine::new());

#[cfg(feature = "opencl3")]
let mut gpu_engine = GpuEngine::new(OpenClEngine::new());

gpu_engine.init().unwrap();

// http server
let mut shutdown = Shutdown::new();
let (stats_tx, stats_rx) = tokio::sync::broadcast::channel(100);
if config.http_server_enabled {
let mut stats_collector = stats_collector::StatsCollector::new(shutdown.to_signal(), stats_rx);
let stats_client = stats_collector.create_client();
info!(target: LOG_TARGET, "Stats collector started");
tokio::spawn(async move {
stats_collector.run().await;
info!(target: LOG_TARGET, "Stats collector shutdown");
});
let http_server_config = Config::new(config.http_server_port);
info!(target: LOG_TARGET, "HTTP server runs on port: {}", &http_server_config.port);
let http_server = HttpServer::new(shutdown.to_signal(), http_server_config, stats_client);
info!(target: LOG_TARGET, "HTTP server enabled");
tokio::spawn(async move {
if let Err(error) = http_server.start().await {
println!("Failed to start HTTP server: {error:?}");
error!(target: LOG_TARGET, "Failed to start HTTP server: {:?}", error);
} else {
info!(target: LOG_TARGET, "Success to start HTTP server");
}
});
}
#[cfg(any(feature = "nvidia", feature = "opencl3"))] {
gpu_engine.init().unwrap();

// http server
let mut shutdown = Shutdown::new();
let (stats_tx, stats_rx) = tokio::sync::broadcast::channel(100);
if config.http_server_enabled {
let mut stats_collector = stats_collector::StatsCollector::new(shutdown.to_signal(), stats_rx);
let stats_client = stats_collector.create_client();
info!(target: LOG_TARGET, "Stats collector started");
tokio::spawn(async move {
stats_collector.run().await;
info!(target: LOG_TARGET, "Stats collector shutdown");
});
let http_server_config = Config::new(config.http_server_port);
info!(target: LOG_TARGET, "HTTP server runs on port: {}", &http_server_config.port);
let http_server = HttpServer::new(shutdown.to_signal(), http_server_config, stats_client);
info!(target: LOG_TARGET, "HTTP server enabled");
tokio::spawn(async move {
if let Err(error) = http_server.start().await {
println!("Failed to start HTTP server: {error:?}");
error!(target: LOG_TARGET, "Failed to start HTTP server: {:?}", error);
} else {
info!(target: LOG_TARGET, "Success to start HTTP server");
}
});
}

let num_devices = gpu_engine.num_devices()?;
let num_devices = gpu_engine.num_devices()?;

// just create the context to test if it can run
if let Some(_detect) = cli.detect {
let gpu = gpu_engine.clone();
let mut is_any_available = false;
// just create the context to test if it can run
if let Some(_detect) = cli.detect {
let gpu = gpu_engine.clone();
let mut is_any_available = false;

let mut gpu_devices = match gpu.detect_devices() {
Ok(gpu_stats) => gpu_stats,
Err(error) => {
warn!(target: LOG_TARGET, "No gpu device detected");
return Err(anyhow::anyhow!("Gpu detect error: {:?}", error));
},
};
if num_devices > 0 {
for i in 0..num_devices {
match gpu.create_context(i) {
Ok(_) => {
info!(target: LOG_TARGET, "Gpu detected. Created context for device nr: {:?}", i+1);
if let Some(gpstat) = gpu_devices.get_mut(i as usize) {
gpstat.is_available = true;
is_any_available = true;
}
},
Err(error) => {
warn!(target: LOG_TARGET, "Failed to create context for gpu device nr: {:?}", i+1);
continue;
},
let mut gpu_devices = match gpu.detect_devices() {
Ok(gpu_stats) => gpu_stats,
Err(error) => {
warn!(target: LOG_TARGET, "No gpu device detected");
return Err(anyhow::anyhow!("Gpu detect error: {:?}", error));
},
};
if num_devices > 0 {
for i in 0..num_devices {
match gpu.create_context(i) {
Ok(_) => {
info!(target: LOG_TARGET, "Gpu detected. Created context for device nr: {:?}", i+1);
if let Some(gpstat) = gpu_devices.get_mut(i as usize) {
gpstat.is_available = true;
is_any_available = true;
}
},
Err(error) => {
warn!(target: LOG_TARGET, "Failed to create context for gpu device nr: {:?}", i+1);
continue;
},
}
}
}
}

let status_file = GpuStatusFile::new(gpu_devices);
let default_path = {
let mut path = current_dir().expect("no current directory");
path.push("gpu_status.json");
path
};
let path = cli.gpu_status_file.unwrap_or_else(|| default_path.clone());

let _ = match GpuStatusFile::load(&path) {
Ok(_) => {
if let Err(err) = status_file.save(&path) {
warn!(target: LOG_TARGET,"Error saving gpu status: {}", err);
}
status_file
},
Err(err) => {
if let Err(err) = fs::create_dir_all(path.parent().expect("no parent")) {
warn!(target: LOG_TARGET, "Error creating directory: {}", err);
}
if let Err(err) = status_file.save(&path) {
warn!(target: LOG_TARGET,"Error saving gpu status: {}", err);
}
status_file
},
};
let status_file = GpuStatusFile::new(gpu_devices);
let default_path = {
let mut path = current_dir().expect("no current directory");
path.push("gpu_status.json");
path
};
let path = cli.gpu_status_file.unwrap_or_else(|| default_path.clone());

let _ = match GpuStatusFile::load(&path) {
Ok(_) => {
if let Err(err) = status_file.save(&path) {
warn!(target: LOG_TARGET,"Error saving gpu status: {}", err);
}
status_file
},
Err(err) => {
if let Err(err) = fs::create_dir_all(path.parent().expect("no parent")) {
warn!(target: LOG_TARGET, "Error creating directory: {}", err);
}
if let Err(err) = status_file.save(&path) {
warn!(target: LOG_TARGET,"Error saving gpu status: {}", err);
}
status_file
},
};

if is_any_available {
return Ok(());
if is_any_available {
return Ok(());
}
return Err(anyhow::anyhow!("No available gpu device detected"));
}
return Err(anyhow::anyhow!("No available gpu device detected"));
}

// create a list of devices (by index) to use
let devices_to_use: Vec<u32> = (0..num_devices)
.filter(|x| {
if let Some(use_devices) = &cli.use_devices {
use_devices.contains(x)
} else {
true
}
})
.filter(|x| {
if let Some(excluded_devices) = &cli.exclude_devices {
!excluded_devices.contains(x)
} else {
true
// create a list of devices (by index) to use
let devices_to_use: Vec<u32> = (0..num_devices)
.filter(|x| {
if let Some(use_devices) = &cli.use_devices {
use_devices.contains(x)
} else {
true
}
})
.filter(|x| {
if let Some(excluded_devices) = &cli.exclude_devices {
!excluded_devices.contains(x)
} else {
true
}
})
.collect();

info!(target: LOG_TARGET, "Device indexes to use: {:?} from the total number of devices: {:?}", devices_to_use, num_devices);

let mut threads = vec![];
for i in 0..num_devices {
if devices_to_use.contains(&i) {
let c = config.clone();
let gpu = gpu_engine.clone();
let curr_stats_tx = stats_tx.clone();
threads.push(thread::spawn(move || {
run_thread(gpu, num_devices as u64, i as u32, c, benchmark, curr_stats_tx)
}));
}
})
.collect();

info!(target: LOG_TARGET, "Device indexes to use: {:?} from the total number of devices: {:?}", devices_to_use, num_devices);

let mut threads = vec![];
for i in 0..num_devices {
if devices_to_use.contains(&i) {
let c = config.clone();
let gpu = gpu_engine.clone();
let curr_stats_tx = stats_tx.clone();
threads.push(thread::spawn(move || {
run_thread(gpu, num_devices as u64, i as u32, c, benchmark, curr_stats_tx)
}));
}
}

// for t in threads {
// t.join().unwrap()?;
// }
// let mut res = Ok(());
for t in threads {
if let Err(err) = t.join() {
error!(target: LOG_TARGET, "Thread join failed: {:?}", err);
// if res.is_ok() {
// res = Err(anyhow!(err));
// }
// err?;
// for t in threads {
// t.join().unwrap()?;
// }
// let mut res = Ok(());
for t in threads {
if let Err(err) = t.join() {
error!(target: LOG_TARGET, "Thread join failed: {:?}", err);
// if res.is_ok() {
// res = Err(anyhow!(err));
// }
// err?;
}
}
}

shutdown.trigger();
shutdown.trigger();

Ok(())
Ok(())
}
}

fn run_thread<T: EngineImpl>(
Expand Down