diff --git a/src/fd/mod.rs b/src/fd/mod.rs index 7c8096a411..5629a37709 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -233,25 +233,27 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug { /// `accept` a connection on a socket #[cfg(any(feature = "net", feature = "vsock"))] - async fn accept(&self) -> io::Result<(Arc, Endpoint)> { + async fn accept( + &mut self, + ) -> io::Result<(Arc>, Endpoint)> { Err(Errno::Inval) } /// initiate a connection on a socket #[cfg(any(feature = "net", feature = "vsock"))] - async fn connect(&self, _endpoint: Endpoint) -> io::Result<()> { + async fn connect(&mut self, _endpoint: Endpoint) -> io::Result<()> { Err(Errno::Inval) } /// `bind` a name to a socket #[cfg(any(feature = "net", feature = "vsock"))] - async fn bind(&self, _name: ListenEndpoint) -> io::Result<()> { + async fn bind(&mut self, _name: ListenEndpoint) -> io::Result<()> { Err(Errno::Inval) } /// `listen` for connections on a socket #[cfg(any(feature = "net", feature = "vsock"))] - async fn listen(&self, _backlog: i32) -> io::Result<()> { + async fn listen(&mut self, _backlog: i32) -> io::Result<()> { Err(Errno::Inval) } @@ -310,7 +312,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug { } /// Sets the file status flags. - async fn set_status_flags(&self, _status_flags: StatusFlags) -> io::Result<()> { + async fn set_status_flags(&mut self, _status_flags: StatusFlags) -> io::Result<()> { Err(Errno::Nosys) } @@ -337,19 +339,19 @@ pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result { return Ok(0); } - block_on(obj.read(buf), None) + block_on(async { obj.read().await.read(buf).await }, None) } pub(crate) fn lseek(fd: FileDescriptor, offset: isize, whence: SeekWhence) -> io::Result { let obj = get_object(fd)?; - block_on(obj.lseek(offset, whence), None) + block_on(async { obj.read().await.lseek(offset, whence).await }, None) } pub(crate) fn chmod(fd: FileDescriptor, mode: AccessPermission) -> io::Result<()> { let obj = get_object(fd)?; - block_on(obj.chmod(mode), None) + block_on(async { obj.read().await.chmod(mode).await }, None) } pub(crate) fn write(fd: FileDescriptor, buf: &[u8]) -> io::Result { @@ -359,12 +361,12 @@ pub(crate) fn write(fd: FileDescriptor, buf: &[u8]) -> io::Result { return Ok(0); } - block_on(obj.write(buf), None) + block_on(async { obj.read().await.write(buf).await }, None) } pub(crate) fn truncate(fd: FileDescriptor, length: usize) -> io::Result<()> { let obj = get_object(fd)?; - block_on(obj.truncate(length), None) + block_on(async { obj.read().await.truncate(length).await }, None) } async fn poll_fds(fds: &mut [PollFd]) -> io::Result { @@ -375,7 +377,7 @@ async fn poll_fds(fds: &mut [PollFd]) -> io::Result { let fd = i.fd; i.revents = PollEvent::empty(); if let Ok(obj) = core_scheduler().get_object(fd) { - let mut pinned = core::pin::pin!(obj.poll(i.events)); + let mut pinned = core::pin::pin!(async { obj.read().await.poll(i.events).await }); if let Ready(Ok(e)) = pinned.as_mut().poll(cx) && !e.is_empty() { @@ -416,7 +418,7 @@ pub fn poll(fds: &mut [PollFd], timeout: Option) -> io::Result { pub fn fstat(fd: FileDescriptor) -> io::Result { let obj = get_object(fd)?; - block_on(obj.fstat(), None) + block_on(async { obj.read().await.fstat().await }, None) } /// Wait for some event on a file descriptor. @@ -440,16 +442,20 @@ pub fn fstat(fd: FileDescriptor) -> io::Result { pub fn eventfd(initval: u64, flags: EventFlags) -> io::Result { let obj = self::eventfd::EventFd::new(initval, flags); - let fd = core_scheduler().insert_object(Arc::new(obj))?; + let fd = core_scheduler().insert_object(Arc::new(async_lock::RwLock::new(obj)))?; Ok(fd) } -pub(crate) fn get_object(fd: FileDescriptor) -> io::Result> { +pub(crate) fn get_object( + fd: FileDescriptor, +) -> io::Result>> { core_scheduler().get_object(fd) } -pub(crate) fn insert_object(obj: Arc) -> io::Result { +pub(crate) fn insert_object( + obj: Arc>, +) -> io::Result { core_scheduler().insert_object(obj) } @@ -465,11 +471,13 @@ pub(crate) fn dup_object2(fd1: FileDescriptor, fd2: FileDescriptor) -> io::Resul core_scheduler().dup_object2(fd1, fd2) } -pub(crate) fn remove_object(fd: FileDescriptor) -> io::Result> { +pub(crate) fn remove_object( + fd: FileDescriptor, +) -> io::Result>> { core_scheduler().remove_object(fd) } pub(crate) fn isatty(fd: FileDescriptor) -> io::Result { let obj = get_object(fd)?; - block_on(obj.isatty(), None) + block_on(async { obj.read().await.isatty().await }, None) } diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index af3dee8d47..e432f7abd7 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -110,7 +110,10 @@ impl Socket { }) .await } +} +#[async_trait] +impl ObjectInterface for Socket { async fn poll(&self, event: PollEvent) -> io::Result { future::poll_fn(|cx| { self.with(|socket| match socket.state() { @@ -275,7 +278,7 @@ impl Socket { } } - async fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> { #[allow(irrefutable_let_patterns)] if let Endpoint::Ip(endpoint) = endpoint { self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port())) @@ -298,7 +301,9 @@ impl Socket { } } - async fn accept(&mut self) -> io::Result<(Socket, Endpoint)> { + async fn accept( + &mut self, + ) -> io::Result<(Arc>, Endpoint)> { if !self.is_listen { self.listen(DEFAULT_BACKLOG).await?; } @@ -357,7 +362,7 @@ impl Socket { is_listen: false, }; - Ok((socket, endpoint)) + Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint)) } async fn getpeername(&self) -> io::Result> { @@ -473,63 +478,3 @@ impl Drop for Socket { } } } - -#[async_trait] -impl ObjectInterface for async_lock::RwLock { - async fn poll(&self, event: PollEvent) -> io::Result { - self.read().await.poll(event).await - } - - async fn read(&self, buffer: &mut [u8]) -> io::Result { - self.read().await.read(buffer).await - } - - async fn write(&self, buffer: &[u8]) -> io::Result { - self.read().await.write(buffer).await - } - - async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { - self.write().await.bind(endpoint).await - } - - async fn connect(&self, endpoint: Endpoint) -> io::Result<()> { - self.read().await.connect(endpoint).await - } - - async fn accept(&self) -> io::Result<(Arc, Endpoint)> { - let (socket, endpoint) = self.write().await.accept().await?; - Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint)) - } - - async fn getpeername(&self) -> io::Result> { - self.read().await.getpeername().await - } - - async fn getsockname(&self) -> io::Result> { - self.read().await.getsockname().await - } - - async fn listen(&self, backlog: i32) -> io::Result<()> { - self.write().await.listen(backlog).await - } - - async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> { - self.read().await.setsockopt(opt, optval).await - } - - async fn getsockopt(&self, opt: SocketOption) -> io::Result { - self.read().await.getsockopt(opt).await - } - - async fn shutdown(&self, how: i32) -> io::Result<()> { - self.read().await.shutdown(how).await - } - - async fn status_flags(&self) -> io::Result { - self.read().await.status_flags().await - } - - async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> { - self.write().await.set_status_flags(status_flags).await - } -} diff --git a/src/fd/socket/udp.rs b/src/fd/socket/udp.rs index 2480a446c5..8fb4975d44 100644 --- a/src/fd/socket/udp.rs +++ b/src/fd/socket/udp.rs @@ -74,7 +74,10 @@ impl Socket { }) .await } +} +#[async_trait] +impl ObjectInterface for Socket { async fn poll(&self, event: PollEvent) -> io::Result { future::poll_fn(|cx| { self.with(|socket| { @@ -251,46 +254,3 @@ impl Drop for Socket { NIC.lock().as_nic_mut().unwrap().destroy_socket(self.handle); } } - -#[async_trait] -impl ObjectInterface for async_lock::RwLock { - async fn poll(&self, event: PollEvent) -> io::Result { - self.read().await.poll(event).await - } - - async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { - self.write().await.bind(endpoint).await - } - - async fn connect(&self, endpoint: Endpoint) -> io::Result<()> { - self.write().await.connect(endpoint).await - } - - async fn sendto(&self, buffer: &[u8], endpoint: Endpoint) -> io::Result { - self.read().await.sendto(buffer, endpoint).await - } - - async fn recvfrom(&self, buffer: &mut [MaybeUninit]) -> io::Result<(usize, Endpoint)> { - self.read().await.recvfrom(buffer).await - } - - async fn read(&self, buffer: &mut [u8]) -> io::Result { - self.read().await.read(buffer).await - } - - async fn write(&self, buf: &[u8]) -> io::Result { - self.read().await.write(buf).await - } - - async fn getsockname(&self) -> io::Result> { - self.read().await.getsockname().await - } - - async fn status_flags(&self) -> io::Result { - self.read().await.status_flags().await - } - - async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> { - self.write().await.set_status_flags(status_flags).await - } -} diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index 5b9a59de26..80a36cd883 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -51,7 +51,7 @@ impl NullSocket { } #[async_trait] -impl ObjectInterface for async_lock::RwLock {} +impl ObjectInterface for NullSocket {} #[derive(Debug)] pub struct Socket { @@ -68,7 +68,10 @@ impl Socket { is_nonblocking: false, } } +} +#[async_trait] +impl ObjectInterface for Socket { async fn poll(&self, event: PollEvent) -> io::Result { future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); @@ -232,11 +235,13 @@ impl Socket { )))) } - async fn listen(&self, _backlog: i32) -> io::Result<()> { + async fn listen(&mut self, _backlog: i32) -> io::Result<()> { Ok(()) } - async fn accept(&mut self) -> io::Result<(NullSocket, Endpoint)> { + async fn accept( + &mut self, + ) -> io::Result<(Arc>, Endpoint)> { let port = self.port; let cid = self.cid; @@ -292,7 +297,10 @@ impl Socket { }) .await?; - Ok((NullSocket::new(), Endpoint::Vsock(endpoint))) + Ok(( + Arc::new(async_lock::RwLock::new(NullSocket::new())), + Endpoint::Vsock(endpoint), + )) } async fn shutdown(&self, _how: i32) -> io::Result<()> { @@ -419,55 +427,3 @@ impl Drop for Socket { guard.remove_socket(self.port); } } - -#[async_trait] -impl ObjectInterface for async_lock::RwLock { - async fn poll(&self, event: PollEvent) -> io::Result { - self.read().await.poll(event).await - } - - async fn read(&self, buffer: &mut [u8]) -> io::Result { - self.read().await.read(buffer).await - } - - async fn write(&self, buffer: &[u8]) -> io::Result { - self.read().await.write(buffer).await - } - - async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { - self.write().await.bind(endpoint).await - } - - async fn connect(&self, endpoint: Endpoint) -> io::Result<()> { - self.write().await.connect(endpoint).await - } - - async fn accept(&self) -> io::Result<(Arc, Endpoint)> { - let (handle, endpoint) = self.write().await.accept().await?; - Ok((Arc::new(async_lock::RwLock::new(handle)), endpoint)) - } - - async fn getpeername(&self) -> io::Result> { - self.read().await.getpeername().await - } - - async fn getsockname(&self) -> io::Result> { - self.read().await.getsockname().await - } - - async fn listen(&self, backlog: i32) -> io::Result<()> { - self.write().await.listen(backlog).await - } - - async fn shutdown(&self, how: i32) -> io::Result<()> { - self.read().await.shutdown(how).await - } - - async fn status_flags(&self) -> io::Result { - self.read().await.status_flags().await - } - - async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> { - self.write().await.set_status_flags(status_flags).await - } -} diff --git a/src/fs/fuse.rs b/src/fs/fuse.rs index 984c02f7f6..065b0cb414 100644 --- a/src/fs/fuse.rs +++ b/src/fs/fuse.rs @@ -1123,8 +1123,10 @@ impl VfsNode for FuseDirectory { Ok(self.attr) } - fn get_object(&self) -> io::Result> { - Ok(Arc::new(FuseDirectoryHandle::new(self.prefix.clone()))) + fn get_object(&self) -> io::Result>> { + Ok(Arc::new(async_lock::RwLock::new(FuseDirectoryHandle::new( + self.prefix.clone(), + )))) } fn traverse_readdir(&self, components: &mut Vec<&str>) -> io::Result> { @@ -1253,7 +1255,7 @@ impl VfsNode for FuseDirectory { components: &mut Vec<&str>, opt: OpenOption, mode: AccessPermission, - ) -> io::Result> { + ) -> io::Result>> { let path = self.traversal_path(components); debug!("FUSE open: {path:#?}, {opt:?} {mode:?}"); @@ -1275,7 +1277,9 @@ impl VfsNode for FuseDirectory { if attr.st_mode.contains(AccessPermission::S_IFDIR) { let mut path = path.into_string().unwrap(); path.remove(0); - Ok(Arc::new(FuseDirectoryHandle::new(Some(path)))) + Ok(Arc::new(async_lock::RwLock::new(FuseDirectoryHandle::new( + Some(path), + )))) } else { Err(Errno::Notdir) } @@ -1320,7 +1324,7 @@ impl VfsNode for FuseDirectory { drop(file_guard); - Ok(Arc::new(file)) + Ok(Arc::new(async_lock::RwLock::new(file))) } } diff --git a/src/fs/mem.rs b/src/fs/mem.rs index b4fd4db443..d6dea3d2e6 100644 --- a/src/fs/mem.rs +++ b/src/fs/mem.rs @@ -294,8 +294,10 @@ impl VfsNode for RomFile { NodeKind::File } - fn get_object(&self) -> io::Result> { - Ok(Arc::new(RomFileInterface::new(self.data.clone()))) + fn get_object(&self) -> io::Result>> { + Ok(Arc::new(async_lock::RwLock::new(RomFileInterface::new( + self.data.clone(), + )))) } fn get_file_attributes(&self) -> io::Result { @@ -348,8 +350,10 @@ impl VfsNode for RamFile { NodeKind::File } - fn get_object(&self) -> io::Result> { - Ok(Arc::new(RamFileInterface::new(self.data.clone()))) + fn get_object(&self) -> io::Result>> { + Ok(Arc::new(async_lock::RwLock::new(RamFileInterface::new( + self.data.clone(), + )))) } fn get_file_attributes(&self) -> io::Result { @@ -502,7 +506,7 @@ impl MemDirectory { components: &mut Vec<&str>, opt: OpenOption, mode: AccessPermission, - ) -> io::Result> { + ) -> io::Result>> { if let Some(component) = components.pop() { let node_name = String::from(component); @@ -523,7 +527,9 @@ impl MemDirectory { } else if opt.contains(OpenOption::O_CREAT) { let file = Box::new(RamFile::new(mode)); guard.insert(node_name, file.clone()); - return Ok(Arc::new(RamFileInterface::new(file.data.clone()))); + return Ok(Arc::new(async_lock::RwLock::new(RamFileInterface::new( + file.data.clone(), + )))); } else { return Err(Errno::Noent); } @@ -543,8 +549,10 @@ impl VfsNode for MemDirectory { NodeKind::Directory } - fn get_object(&self) -> io::Result> { - Ok(Arc::new(MemDirectoryInterface::new(self.inner.clone()))) + fn get_object(&self) -> io::Result>> { + Ok(Arc::new(async_lock::RwLock::new( + MemDirectoryInterface::new(self.inner.clone()), + ))) } fn get_file_attributes(&self) -> io::Result { @@ -735,7 +743,7 @@ impl VfsNode for MemDirectory { components: &mut Vec<&str>, opt: OpenOption, mode: AccessPermission, - ) -> io::Result> { + ) -> io::Result>> { block_on(self.async_traverse_open(components, opt, mode), None) } diff --git a/src/fs/mod.rs b/src/fs/mod.rs index 039e221566..ac5dc44e10 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -59,7 +59,7 @@ pub(crate) trait VfsNode: core::fmt::Debug { } /// Determine the syscall interface - fn get_object(&self) -> io::Result> { + fn get_object(&self) -> io::Result>> { Err(Errno::Nosys) } @@ -112,7 +112,7 @@ pub(crate) trait VfsNode: core::fmt::Debug { _components: &mut Vec<&str>, _option: OpenOption, _mode: AccessPermission, - ) -> io::Result> { + ) -> io::Result>> { Err(Errno::Nosys) } @@ -162,7 +162,7 @@ impl Filesystem { path: &str, opt: OpenOption, mode: AccessPermission, - ) -> io::Result> { + ) -> io::Result>> { debug!("Open file {path} with {opt:?}"); let mut components: Vec<&str> = path.split('/').collect(); @@ -205,9 +205,11 @@ impl Filesystem { self.root.traverse_mkdir(&mut components, mode) } - pub fn opendir(&self, path: &str) -> io::Result> { + pub fn opendir(&self, path: &str) -> io::Result>> { debug!("Open directory {path}"); - Ok(Arc::new(DirectoryReader::new(self.readdir(path)?))) + Ok(Arc::new(async_lock::RwLock::new(DirectoryReader::new( + self.readdir(path)?, + )))) } /// List given directory @@ -447,7 +449,7 @@ pub fn truncate(name: &str, size: usize) -> io::Result<()> { with_relative_filename(name, |name| { let fs = FILESYSTEM.get().ok_or(Errno::Inval)?; if let Ok(file) = fs.open(name, OpenOption::O_TRUNC, AccessPermission::empty()) { - block_on(file.truncate(size), None) + block_on(async { file.read().await.truncate(size).await }, None) } else { Err(Errno::Badf) } diff --git a/src/fs/uhyve.rs b/src/fs/uhyve.rs index 2520be008f..d180497f65 100644 --- a/src/fs/uhyve.rs +++ b/src/fs/uhyve.rs @@ -171,7 +171,7 @@ impl VfsNode for UhyveDirectory { components: &mut Vec<&str>, opt: OpenOption, mode: AccessPermission, - ) -> io::Result> { + ) -> io::Result>> { let path = self.traversal_path(components); let mut open_params = OpenParams { @@ -187,7 +187,9 @@ impl VfsNode for UhyveDirectory { uhyve_hypercall(Hypercall::FileOpen(&mut open_params)); if open_params.ret > 0 { - Ok(Arc::new(UhyveFileHandle::new(open_params.ret))) + Ok(Arc::new(async_lock::RwLock::new(UhyveFileHandle::new( + open_params.ret, + )))) } else { Err(Errno::Io) } diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index e297bb044c..21f2fb5da1 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -220,7 +220,11 @@ struct NewTask { prio: Priority, core_id: CoreId, stacks: TaskStacks, - object_map: Arc, RandomState>>>, + object_map: Arc< + RwSpinLock< + HashMap>, RandomState>, + >, + >, } impl From for Task { @@ -457,14 +461,21 @@ impl PerCoreScheduler { #[inline] pub fn get_current_task_object_map( &self, - ) -> Arc, RandomState>>> { + ) -> Arc< + RwSpinLock< + HashMap>, RandomState>, + >, + > { without_interrupts(|| self.current_task.borrow().object_map.clone()) } /// Map a file descriptor to their IO interface and returns /// the shared reference #[inline] - pub fn get_object(&self, fd: FileDescriptor) -> io::Result> { + pub fn get_object( + &self, + fd: FileDescriptor, + ) -> io::Result>> { without_interrupts(|| { let current_task = self.current_task.borrow(); let object_map = current_task.object_map.read(); @@ -477,9 +488,11 @@ impl PerCoreScheduler { #[cfg(feature = "common-os")] #[cfg_attr(not(target_arch = "x86_64"), expect(dead_code))] pub fn recreate_objmap(&self) -> io::Result<()> { - let mut map = HashMap::, RandomState>::with_hasher( - RandomState::with_seeds(0, 0, 0, 0), - ); + let mut map = HashMap::< + FileDescriptor, + Arc>, + RandomState, + >::with_hasher(RandomState::with_seeds(0, 0, 0, 0)); without_interrupts(|| { let mut current_task = self.current_task.borrow_mut(); @@ -501,7 +514,10 @@ impl PerCoreScheduler { /// Insert a new IO interface and returns a file descriptor as /// identifier to this object - pub fn insert_object(&self, obj: Arc) -> io::Result { + pub fn insert_object( + &self, + obj: Arc>, + ) -> io::Result { without_interrupts(|| { let current_task = self.current_task.borrow(); let mut object_map = current_task.object_map.write(); @@ -576,7 +592,10 @@ impl PerCoreScheduler { } /// Remove a IO interface, which is named by the file descriptor - pub fn remove_object(&self, fd: FileDescriptor) -> io::Result> { + pub fn remove_object( + &self, + fd: FileDescriptor, + ) -> io::Result>> { without_interrupts(|| { let current_task = self.current_task.borrow(); let mut object_map = current_task.object_map.write(); diff --git a/src/scheduler/task.rs b/src/scheduler/task.rs index 2b04425b4c..78d523b757 100644 --- a/src/scheduler/task.rs +++ b/src/scheduler/task.rs @@ -390,7 +390,11 @@ pub(crate) struct Task { /// Stack of the task pub stacks: TaskStacks, /// Mapping between file descriptor and the referenced IO interface - pub object_map: Arc, RandomState>>>, + pub object_map: Arc< + RwSpinLock< + HashMap>, RandomState>, + >, + >, /// Task Thread-Local-Storage (TLS) #[cfg(not(feature = "common-os"))] pub tls: Option>, @@ -411,7 +415,11 @@ impl Task { task_status: TaskStatus, task_prio: Priority, stacks: TaskStacks, - object_map: Arc, RandomState>>>, + object_map: Arc< + RwSpinLock< + HashMap>, RandomState>, + >, + >, ) -> Task { debug!("Creating new task {tid} on core {core_id}"); @@ -437,14 +445,22 @@ impl Task { /// All cores use the same mapping between file descriptor and the referenced object static OBJECT_MAP: OnceCell< - Arc, RandomState>>>, + Arc< + RwSpinLock< + HashMap< + FileDescriptor, + Arc>, + RandomState, + >, + >, + >, > = OnceCell::new(); if core_id == 0 { OBJECT_MAP .set(Arc::new(RwSpinLock::new(HashMap::< FileDescriptor, - Arc, + Arc>, RandomState, >::with_hasher( RandomState::with_seeds(0, 0, 0, 0), @@ -455,23 +471,41 @@ impl Task { let mut guard = objmap.write(); if env::is_uhyve() { guard - .try_insert(STDIN_FILENO, Arc::new(UhyveStdin::new())) + .try_insert( + STDIN_FILENO, + Arc::new(async_lock::RwLock::new(UhyveStdin::new())), + ) .map_err(|_| Errno::Io)?; guard - .try_insert(STDOUT_FILENO, Arc::new(UhyveStdout::new())) + .try_insert( + STDOUT_FILENO, + Arc::new(async_lock::RwLock::new(UhyveStdout::new())), + ) .map_err(|_| Errno::Io)?; guard - .try_insert(STDERR_FILENO, Arc::new(UhyveStderr::new())) + .try_insert( + STDERR_FILENO, + Arc::new(async_lock::RwLock::new(UhyveStderr::new())), + ) .map_err(|_| Errno::Io)?; } else { guard - .try_insert(STDIN_FILENO, Arc::new(GenericStdin::new())) + .try_insert( + STDIN_FILENO, + Arc::new(async_lock::RwLock::new(GenericStdin::new())), + ) .map_err(|_| Errno::Io)?; guard - .try_insert(STDOUT_FILENO, Arc::new(GenericStdout::new())) + .try_insert( + STDOUT_FILENO, + Arc::new(async_lock::RwLock::new(GenericStdout::new())), + ) .map_err(|_| Errno::Io)?; guard - .try_insert(STDERR_FILENO, Arc::new(GenericStderr::new())) + .try_insert( + STDERR_FILENO, + Arc::new(async_lock::RwLock::new(GenericStderr::new())), + ) .map_err(|_| Errno::Io)?; } diff --git a/src/syscalls/mod.rs b/src/syscalls/mod.rs index 538316ad69..7a2204509d 100644 --- a/src/syscalls/mod.rs +++ b/src/syscalls/mod.rs @@ -641,8 +641,11 @@ pub unsafe extern "C" fn sys_ioctl( obj.map_or_else( |e| -i32::from(e), |v| { - block_on((*v).set_status_flags(status_flags), None) - .map_or_else(|e| -i32::from(e), |()| 0) + block_on( + async { v.write().await.set_status_flags(status_flags).await }, + None, + ) + .map_or_else(|e| -i32::from(e), |()| 0) }, ) } else { @@ -666,7 +669,7 @@ pub extern "C" fn sys_fcntl(fd: i32, cmd: i32, arg: i32) -> i32 { obj.map_or_else( |e| -i32::from(e), |v| { - block_on((*v).status_flags(), None) + block_on(async { v.read().await.status_flags().await }, None) .map_or_else(|e| -i32::from(e), |status_flags| status_flags.bits()) }, ) @@ -676,7 +679,12 @@ pub extern "C" fn sys_fcntl(fd: i32, cmd: i32, arg: i32) -> i32 { |e| -i32::from(e), |v| { block_on( - (*v).set_status_flags(fd::StatusFlags::from_bits_retain(arg)), + async { + v.write() + .await + .set_status_flags(fd::StatusFlags::from_bits_retain(arg)) + .await + }, None, ) .map_or_else(|e| -i32::from(e), |()| 0) @@ -787,7 +795,7 @@ pub unsafe extern "C" fn sys_getdents64( obj.map_or_else( |_| (-i32::from(Errno::Inval)).into(), |v| { - block_on((*v).getdents(slice), None) + block_on(async { v.read().await.getdents(slice).await }, None) .map_or_else(|e| (-i32::from(e)).into(), |cnt| cnt as i64) }, ) diff --git a/src/syscalls/socket/mod.rs b/src/syscalls/socket/mod.rs index 37c6d538c9..ef112c04ad 100644 --- a/src/syscalls/socket/mod.rs +++ b/src/syscalls/socket/mod.rs @@ -588,12 +588,13 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 { #[cfg(feature = "vsock")] if domain == Af::Vsock && sock == Sock::Stream { - let socket = Arc::new(async_lock::RwLock::new(vsock::Socket::new())); + let mut socket = vsock::Socket::new(); if sock_flags.contains(SockFlags::SOCK_NONBLOCK) { block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap(); } + let socket = Arc::new(async_lock::RwLock::new(socket)); let fd = insert_object(socket).expect("FD is already used"); return fd; @@ -609,12 +610,13 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 { if sock == Sock::Dgram { let handle = nic.create_udp_handle().unwrap(); drop(guard); - let socket = Arc::new(async_lock::RwLock::new(udp::Socket::new(handle, domain))); + let mut socket = udp::Socket::new(handle, domain); if sock_flags.contains(SockFlags::SOCK_NONBLOCK) { block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap(); } + let socket = Arc::new(async_lock::RwLock::new(socket)); let fd = insert_object(socket).expect("FD is already used"); return fd; @@ -624,12 +626,13 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 { if sock == Sock::Stream { let handle = nic.create_tcp_handle().unwrap(); drop(guard); - let socket = Arc::new(async_lock::RwLock::new(tcp::Socket::new(handle, domain))); + let mut socket = tcp::Socket::new(handle, domain); if sock_flags.contains(SockFlags::SOCK_NONBLOCK) { block_on(socket.set_status_flags(fd::StatusFlags::O_NONBLOCK), None).unwrap(); } + let socket = Arc::new(async_lock::RwLock::new(socket)); let fd = insert_object(socket).expect("FD is already used"); return fd; @@ -647,7 +650,7 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut obj.map_or_else( |e| -i32::from(e), |v| { - block_on((*v).accept(), None).map_or_else( + block_on(async { v.write().await.accept().await }, None).map_or_else( |e| -i32::from(e), #[cfg_attr(not(feature = "net"), expect(unused_variables))] |(obj, endpoint)| match endpoint { @@ -708,7 +711,10 @@ pub extern "C" fn sys_listen(fd: i32, backlog: i32) -> i32 { let obj = get_object(fd); obj.map_or_else( |e| -i32::from(e), - |v| block_on((*v).listen(backlog), None).map_or_else(|e| -i32::from(e), |()| 0), + |v| { + block_on(async { v.write().await.listen(backlog).await }, None) + .map_or_else(|e| -i32::from(e), |()| 0) + }, ) } @@ -733,8 +739,11 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl return -i32::from(Errno::Inval); } let endpoint = IpListenEndpoint::from(unsafe { *name.cast::() }); - block_on((*v).bind(ListenEndpoint::Ip(endpoint)), None) - .map_or_else(|e| -i32::from(e), |()| 0) + block_on( + async { v.write().await.bind(ListenEndpoint::Ip(endpoint)).await }, + None, + ) + .map_or_else(|e| -i32::from(e), |()| 0) } #[cfg(feature = "net")] Af::Inet6 => { @@ -742,8 +751,11 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl return -i32::from(Errno::Inval); } let endpoint = IpListenEndpoint::from(unsafe { *name.cast::() }); - block_on((*v).bind(ListenEndpoint::Ip(endpoint)), None) - .map_or_else(|e| -i32::from(e), |()| 0) + block_on( + async { v.write().await.bind(ListenEndpoint::Ip(endpoint)).await }, + None, + ) + .map_or_else(|e| -i32::from(e), |()| 0) } #[cfg(feature = "vsock")] Af::Vsock => { @@ -751,8 +763,11 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl return -i32::from(Errno::Inval); } let endpoint = VsockListenEndpoint::from(unsafe { *name.cast::() }); - block_on((*v).bind(ListenEndpoint::Vsock(endpoint)), None) - .map_or_else(|e| -i32::from(e), |()| 0) + block_on( + async { v.write().await.bind(ListenEndpoint::Vsock(endpoint)).await }, + None, + ) + .map_or_else(|e| -i32::from(e), |()| 0) } _ => -i32::from(Errno::Inval), }, @@ -800,7 +815,10 @@ pub unsafe extern "C" fn sys_connect(fd: i32, name: *const sockaddr, namelen: so let obj = get_object(fd); obj.map_or_else( |e| -i32::from(e), - |v| block_on((*v).connect(endpoint), None).map_or_else(|e| -i32::from(e), |()| 0), + |v| { + block_on(async { v.write().await.connect(endpoint).await }, None) + .map_or_else(|e| -i32::from(e), |()| 0) + }, ) } @@ -815,7 +833,8 @@ pub unsafe extern "C" fn sys_getsockname( obj.map_or_else( |e| -i32::from(e), |v| { - if let Ok(Some(endpoint)) = block_on((*v).getsockname(), None) { + if let Ok(Some(endpoint)) = block_on(async { v.read().await.getsockname().await }, None) + { if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; @@ -898,8 +917,16 @@ pub unsafe extern "C" fn sys_setsockopt( obj.map_or_else( |e| -i32::from(e), |v| { - block_on((*v).setsockopt(SocketOption::TcpNoDelay, value != 0), None) - .map_or_else(|e| -i32::from(e), |()| 0) + block_on( + async { + v.read() + .await + .setsockopt(SocketOption::TcpNoDelay, value != 0) + .await + }, + None, + ) + .map_or_else(|e| -i32::from(e), |()| 0) }, ) } else { @@ -933,7 +960,11 @@ pub unsafe extern "C" fn sys_getsockopt( obj.map_or_else( |e| -i32::from(e), |v| { - block_on((*v).getsockopt(SocketOption::TcpNoDelay), None).map_or_else( + block_on( + async { v.read().await.getsockopt(SocketOption::TcpNoDelay).await }, + None, + ) + .map_or_else( |e| -i32::from(e), |value| { if value { @@ -964,7 +995,8 @@ pub unsafe extern "C" fn sys_getpeername( obj.map_or_else( |e| -i32::from(e), |v| { - if let Ok(Some(endpoint)) = block_on((*v).getpeername(), None) { + if let Ok(Some(endpoint)) = block_on(async { v.read().await.getpeername().await }, None) + { if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; @@ -1019,7 +1051,10 @@ fn shutdown(sockfd: i32, how: i32) -> i32 { let obj = get_object(sockfd); obj.map_or_else( |e| -i32::from(e), - |v| block_on((*v).shutdown(how), None).map_or_else(|e| -i32::from(e), |()| 0), + |v| { + block_on(async { v.read().await.shutdown(how).await }, None) + .map_or_else(|e| -i32::from(e), |()| 0) + }, ) } @@ -1098,7 +1133,7 @@ pub unsafe extern "C" fn sys_sendto( obj.map_or_else( |e| isize::try_from(-i32::from(e)).unwrap(), |v| { - block_on((*v).sendto(slice, endpoint), None).map_or_else( + block_on(async { v.read().await.sendto(slice, endpoint).await }, None).map_or_else( |e| isize::try_from(-i32::from(e)).unwrap(), |v| v.try_into().unwrap(), ) @@ -1124,7 +1159,7 @@ pub unsafe extern "C" fn sys_recvfrom( obj.map_or_else( |e| isize::try_from(-i32::from(e)).unwrap(), |v| { - block_on((*v).recvfrom(slice), None).map_or_else( + block_on(async { v.read().await.recvfrom(slice).await }, None).map_or_else( |e| isize::try_from(-i32::from(e)).unwrap(), |(len, endpoint)| { if !addr.is_null() && !addrlen.is_null() {