Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions src/fd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,25 +233,27 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {

/// `accept` a connection on a socket
#[cfg(any(feature = "net", feature = "vsock"))]
async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
async fn accept(
&mut self,
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
Err(Errno::Inval)
}

/// initiate a connection on a socket
#[cfg(any(feature = "net", feature = "vsock"))]
async fn connect(&self, _endpoint: Endpoint) -> io::Result<()> {
async fn connect(&mut self, _endpoint: Endpoint) -> io::Result<()> {
Err(Errno::Inval)
}

/// `bind` a name to a socket
#[cfg(any(feature = "net", feature = "vsock"))]
async fn bind(&self, _name: ListenEndpoint) -> io::Result<()> {
async fn bind(&mut self, _name: ListenEndpoint) -> io::Result<()> {
Err(Errno::Inval)
}

/// `listen` for connections on a socket
#[cfg(any(feature = "net", feature = "vsock"))]
async fn listen(&self, _backlog: i32) -> io::Result<()> {
async fn listen(&mut self, _backlog: i32) -> io::Result<()> {
Err(Errno::Inval)
}

Expand Down Expand Up @@ -310,7 +312,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug {
}

/// Sets the file status flags.
async fn set_status_flags(&self, _status_flags: StatusFlags) -> io::Result<()> {
async fn set_status_flags(&mut self, _status_flags: StatusFlags) -> io::Result<()> {
Err(Errno::Nosys)
}

Expand All @@ -337,19 +339,19 @@ pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result<usize> {
return Ok(0);
}

block_on(obj.read(buf), None)
block_on(async { obj.read().await.read(buf).await }, None)
}

pub(crate) fn lseek(fd: FileDescriptor, offset: isize, whence: SeekWhence) -> io::Result<isize> {
let obj = get_object(fd)?;

block_on(obj.lseek(offset, whence), None)
block_on(async { obj.read().await.lseek(offset, whence).await }, None)
}

pub(crate) fn chmod(fd: FileDescriptor, mode: AccessPermission) -> io::Result<()> {
let obj = get_object(fd)?;

block_on(obj.chmod(mode), None)
block_on(async { obj.read().await.chmod(mode).await }, None)
}

pub(crate) fn write(fd: FileDescriptor, buf: &[u8]) -> io::Result<usize> {
Expand All @@ -359,12 +361,12 @@ pub(crate) fn write(fd: FileDescriptor, buf: &[u8]) -> io::Result<usize> {
return Ok(0);
}

block_on(obj.write(buf), None)
block_on(async { obj.read().await.write(buf).await }, None)
}

pub(crate) fn truncate(fd: FileDescriptor, length: usize) -> io::Result<()> {
let obj = get_object(fd)?;
block_on(obj.truncate(length), None)
block_on(async { obj.read().await.truncate(length).await }, None)
}

async fn poll_fds(fds: &mut [PollFd]) -> io::Result<u64> {
Expand All @@ -375,7 +377,7 @@ async fn poll_fds(fds: &mut [PollFd]) -> io::Result<u64> {
let fd = i.fd;
i.revents = PollEvent::empty();
if let Ok(obj) = core_scheduler().get_object(fd) {
let mut pinned = core::pin::pin!(obj.poll(i.events));
let mut pinned = core::pin::pin!(async { obj.read().await.poll(i.events).await });
if let Ready(Ok(e)) = pinned.as_mut().poll(cx)
&& !e.is_empty()
{
Expand Down Expand Up @@ -416,7 +418,7 @@ pub fn poll(fds: &mut [PollFd], timeout: Option<Duration>) -> io::Result<u64> {

pub fn fstat(fd: FileDescriptor) -> io::Result<FileAttr> {
let obj = get_object(fd)?;
block_on(obj.fstat(), None)
block_on(async { obj.read().await.fstat().await }, None)
}

/// Wait for some event on a file descriptor.
Expand All @@ -440,16 +442,20 @@ pub fn fstat(fd: FileDescriptor) -> io::Result<FileAttr> {
pub fn eventfd(initval: u64, flags: EventFlags) -> io::Result<FileDescriptor> {
let obj = self::eventfd::EventFd::new(initval, flags);

let fd = core_scheduler().insert_object(Arc::new(obj))?;
let fd = core_scheduler().insert_object(Arc::new(async_lock::RwLock::new(obj)))?;

Ok(fd)
}

pub(crate) fn get_object(fd: FileDescriptor) -> io::Result<Arc<dyn ObjectInterface>> {
pub(crate) fn get_object(
fd: FileDescriptor,
) -> io::Result<Arc<async_lock::RwLock<dyn ObjectInterface>>> {
core_scheduler().get_object(fd)
}

pub(crate) fn insert_object(obj: Arc<dyn ObjectInterface>) -> io::Result<FileDescriptor> {
pub(crate) fn insert_object(
obj: Arc<async_lock::RwLock<dyn ObjectInterface>>,
) -> io::Result<FileDescriptor> {
core_scheduler().insert_object(obj)
}

Expand All @@ -465,11 +471,13 @@ pub(crate) fn dup_object2(fd1: FileDescriptor, fd2: FileDescriptor) -> io::Resul
core_scheduler().dup_object2(fd1, fd2)
}

pub(crate) fn remove_object(fd: FileDescriptor) -> io::Result<Arc<dyn ObjectInterface>> {
pub(crate) fn remove_object(
fd: FileDescriptor,
) -> io::Result<Arc<async_lock::RwLock<dyn ObjectInterface>>> {
core_scheduler().remove_object(fd)
}

pub(crate) fn isatty(fd: FileDescriptor) -> io::Result<bool> {
let obj = get_object(fd)?;
block_on(obj.isatty(), None)
block_on(async { obj.read().await.isatty().await }, None)
}
71 changes: 8 additions & 63 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ impl Socket {
})
.await
}
}

#[async_trait]
impl ObjectInterface for Socket {
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
future::poll_fn(|cx| {
self.with(|socket| match socket.state() {
Expand Down Expand Up @@ -275,7 +278,7 @@ impl Socket {
}
}

async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> {
#[allow(irrefutable_let_patterns)]
if let Endpoint::Ip(endpoint) = endpoint {
self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port()))
Expand All @@ -298,7 +301,9 @@ impl Socket {
}
}

async fn accept(&mut self) -> io::Result<(Socket, Endpoint)> {
async fn accept(
&mut self,
) -> io::Result<(Arc<async_lock::RwLock<dyn ObjectInterface>>, Endpoint)> {
if !self.is_listen {
self.listen(DEFAULT_BACKLOG).await?;
}
Expand Down Expand Up @@ -357,7 +362,7 @@ impl Socket {
is_listen: false,
};

Ok((socket, endpoint))
Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint))
}

async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
Expand Down Expand Up @@ -473,63 +478,3 @@ impl Drop for Socket {
}
}
}

#[async_trait]
impl ObjectInterface for async_lock::RwLock<Socket> {
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
self.read().await.poll(event).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

async fn write(&self, buffer: &[u8]) -> io::Result<usize> {
self.read().await.write(buffer).await
}

async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
self.write().await.bind(endpoint).await
}

async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
self.read().await.connect(endpoint).await
}

async fn accept(&self) -> io::Result<(Arc<dyn ObjectInterface>, Endpoint)> {
let (socket, endpoint) = self.write().await.accept().await?;
Ok((Arc::new(async_lock::RwLock::new(socket)), endpoint))
}

async fn getpeername(&self) -> io::Result<Option<Endpoint>> {
self.read().await.getpeername().await
}

async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
self.read().await.getsockname().await
}

async fn listen(&self, backlog: i32) -> io::Result<()> {
self.write().await.listen(backlog).await
}

async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
self.read().await.setsockopt(opt, optval).await
}

async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
self.read().await.getsockopt(opt).await
}

async fn shutdown(&self, how: i32) -> io::Result<()> {
self.read().await.shutdown(how).await
}

async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
self.read().await.status_flags().await
}

async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
self.write().await.set_status_flags(status_flags).await
}
}
46 changes: 3 additions & 43 deletions src/fd/socket/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ impl Socket {
})
.await
}
}

#[async_trait]
impl ObjectInterface for Socket {
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
future::poll_fn(|cx| {
self.with(|socket| {
Expand Down Expand Up @@ -251,46 +254,3 @@ impl Drop for Socket {
NIC.lock().as_nic_mut().unwrap().destroy_socket(self.handle);
}
}

#[async_trait]
impl ObjectInterface for async_lock::RwLock<Socket> {
async fn poll(&self, event: PollEvent) -> io::Result<PollEvent> {
self.read().await.poll(event).await
}

async fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> {
self.write().await.bind(endpoint).await
}

async fn connect(&self, endpoint: Endpoint) -> io::Result<()> {
self.write().await.connect(endpoint).await
}

async fn sendto(&self, buffer: &[u8], endpoint: Endpoint) -> io::Result<usize> {
self.read().await.sendto(buffer, endpoint).await
}

async fn recvfrom(&self, buffer: &mut [MaybeUninit<u8>]) -> io::Result<(usize, Endpoint)> {
self.read().await.recvfrom(buffer).await
}

async fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
self.read().await.read(buffer).await
}

async fn write(&self, buf: &[u8]) -> io::Result<usize> {
self.read().await.write(buf).await
}

async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
self.read().await.getsockname().await
}

async fn status_flags(&self) -> io::Result<fd::StatusFlags> {
self.read().await.status_flags().await
}

async fn set_status_flags(&self, status_flags: fd::StatusFlags) -> io::Result<()> {
self.write().await.set_status_flags(status_flags).await
}
}
Loading