diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index f66f05fdb1cef7..05d08f567b62c6 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -35,11 +35,20 @@ private sealed class State public MsQuicConnection.State ConnectionState = null!; // set in ctor. public ReadState ReadState; + + // set when ReadState.Aborted: public long ReadErrorCode = -1; - public readonly List ReceiveQuicBuffers = new List(); - // Resettable completions to be used for multiple calls to receive. - public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); + // filled when ReadState.BuffersAvailable: + public QuicBuffer[] ReceiveQuicBuffers = Array.Empty(); + public int ReceiveQuicBuffersCount; + public int ReceiveQuicBuffersTotalBytes; + + // set when ReadState.PendingRead: + public Memory ReceiveUserBuffer; + public CancellationTokenRegistration ReceiveCancellationRegistration; + public MsQuicStream? RootedReceiveStream; // roots the stream in the pinned state to prevent GC during an async read I/O. + public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); public SendState SendState; public long SendErrorCode = -1; @@ -342,7 +351,7 @@ private void HandleWriteFailedState() } } - internal override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + internal override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) { ThrowIfDisposed(); @@ -351,109 +360,151 @@ internal override async ValueTask ReadAsync(Memory destination, Cance throw new InvalidOperationException(SR.net_quic_reading_notallowed); } - if (cancellationToken.IsCancellationRequested) - { - lock (_state) - { - if (_state.ReadState == ReadState.None) - { - _state.ReadState = ReadState.Aborted; - } - } - - throw new OperationCanceledException(cancellationToken); - } - if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] reading into Memory of '{destination.Length}' bytes."); } + ReadState readState; + long abortError = -1; + bool canceledSynchronously = false; + lock (_state) { - if (_state.ReadState == ReadState.ReadsCompleted) - { - return 0; - } - else if (_state.ReadState == ReadState.Aborted) + readState = _state.ReadState; + abortError = _state.ReadErrorCode; + + if (readState != ReadState.PendingRead && cancellationToken.IsCancellationRequested) { - throw ThrowHelper.GetStreamAbortedException(_state.ReadErrorCode); + readState = ReadState.Aborted; + _state.ReadState = ReadState.Aborted; + canceledSynchronously = true; } - else if (_state.ReadState == ReadState.ConnectionClosed) + else if (readState == ReadState.None) { - throw GetConnectionAbortedException(_state); - } - } + Debug.Assert(_state.RootedReceiveStream is null); - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => - { - var state = (State)s!; - bool shouldComplete = false; - lock (state) - { - if (state.ReadState == ReadState.None) + _state.ReceiveUserBuffer = destination; + _state.RootedReceiveStream = this; + _state.ReadState = ReadState.PendingRead; + + if (cancellationToken.CanBeCanceled) { - shouldComplete = true; + _state.ReceiveCancellationRegistration = cancellationToken.UnsafeRegister(static (obj, token) => + { + var state = (State)obj!; + bool completePendingRead; + + lock (state) + { + completePendingRead = state.ReadState == ReadState.PendingRead; + state.RootedReceiveStream = null; + state.ReceiveUserBuffer = null; + state.ReadState = ReadState.Aborted; + } + + if (completePendingRead) + { + state.ReceiveResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(token))); + } + }, _state); + } + else + { + _state.ReceiveCancellationRegistration = default; } - state.ReadState = ReadState.Aborted; - } - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Read was canceled", token))); + return _state.ReceiveResettableCompletionSource.GetValueTask(); } - }, _state); - - // TODO there could potentially be a perf gain by storing the buffer from the initial read - // This reduces the amount of async calls, however it makes it so MsQuic holds onto the buffers - // longer than it needs to. We will need to benchmark this. - int length = (int)await _state.ReceiveResettableCompletionSource.GetValueTask().ConfigureAwait(false); + else if (readState == ReadState.IndividualReadComplete) + { + _state.ReadState = ReadState.None; - int actual = Math.Min(length, destination.Length); + int taken = CopyMsQuicBuffersToUserBuffer(_state.ReceiveQuicBuffers.AsSpan(0, _state.ReceiveQuicBuffersCount), destination.Span); + ReceiveComplete(taken); - static unsafe void CopyToBuffer(Span destinationBuffer, List sourceBuffers) - { - Span slicedBuffer = destinationBuffer; - for (int i = 0; i < sourceBuffers.Count; i++) - { - QuicBuffer nativeBuffer = sourceBuffers[i]; - int length = Math.Min((int)nativeBuffer.Length, slicedBuffer.Length); - new Span(nativeBuffer.Buffer, length).CopyTo(slicedBuffer); - if (length < nativeBuffer.Length) + if (taken != _state.ReceiveQuicBuffersTotalBytes) { - // The buffer passed in was larger that the received data, return - return; + // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer. + EnableReceive(); } - slicedBuffer = slicedBuffer.Slice(length); + + return new ValueTask(taken); } } - CopyToBuffer(destination.Span, _state.ReceiveQuicBuffers); + Exception? ex = null; - lock (_state) + switch (readState) { - if (_state.ReadState == ReadState.IndividualReadComplete) - { - _state.ReceiveQuicBuffers.Clear(); - ReceiveComplete(actual); - EnableReceive(); - _state.ReadState = ReadState.None; - } + case ReadState.ReadsCompleted: + return new ValueTask(0); + case ReadState.PendingRead: + ex = new InvalidOperationException("Only one read is supported at a time."); + break; + case ReadState.Aborted: + ex = + canceledSynchronously ? new OperationCanceledException(cancellationToken) : // aborted by token being canceled before the async op started. + abortError == -1 ? new QuicOperationAbortedException() : // aborted by user via some other operation. + new QuicStreamAbortedException(abortError); // aborted by peer. + + break; + case ReadState.ConnectionClosed: + default: + Debug.Assert(readState == ReadState.ConnectionClosed, $"{nameof(ReadState)} of '{readState}' is unaccounted for in {nameof(ReadAsync)}."); + ex = GetConnectionAbortedException(_state); + break; } - return actual; + return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(ex!)); + } + + /// The number of bytes copied. + private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan sourceBuffers, Span destinationBuffer) + { + Debug.Assert(sourceBuffers.Length != 0); + + int originalDestinationLength = destinationBuffer.Length; + QuicBuffer nativeBuffer; + int takeLength = 0; + int i = 0; + + do + { + nativeBuffer = sourceBuffers[i]; + takeLength = Math.Min((int)nativeBuffer.Length, destinationBuffer.Length); + + new Span(nativeBuffer.Buffer, takeLength).CopyTo(destinationBuffer); + destinationBuffer = destinationBuffer.Slice(takeLength); + } + while (destinationBuffer.Length != 0 && ++i < sourceBuffers.Length); + + return originalDestinationLength - destinationBuffer.Length; } - // TODO do we want this to be a synchronization mechanism to cancel a pending read - // If so, we need to complete the read here as well. internal override void AbortRead(long errorCode) { ThrowIfDisposed(); + bool shouldComplete = false; lock (_state) { - _state.ReadState = ReadState.Aborted; + if (_state.ReadState == ReadState.PendingRead) + { + shouldComplete = true; + _state.RootedReceiveStream = null; + _state.ReceiveUserBuffer = null; + } + if (_state.ReadState < ReadState.ReadsCompleted) + { + _state.ReadState = ReadState.Aborted; + } + } + + if (shouldComplete) + { + _state.ReceiveResettableCompletionSource.CompleteException( + ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Read was aborted"))); } StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, errorCode); @@ -704,7 +755,8 @@ private void Dispose(bool disposing) private void EnableReceive() { - MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + uint status = MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + QuicExceptionHelpers.ThrowIfFailed(status, "StreamReceiveSetEnabled failed."); } private static uint NativeCallbackHandler( @@ -778,31 +830,80 @@ private static uint HandleEvent(State state, ref StreamEvent evt) private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) { - StreamEventDataReceive receiveEvent = evt.Data.Receive; - for (int i = 0; i < receiveEvent.BufferCount; i++) + ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive; + + if (receiveEvent.BufferCount == 0) { - state.ReceiveQuicBuffers.Add(receiveEvent.Buffers[i]); + // This is a 0-length receive that happens once reads are finished (via abort or otherwise). + // State changes for this are handled elsewhere. + return MsQuicStatusCodes.Success; } + int readLength; + bool shouldComplete = false; lock (state) { - if (state.ReadState == ReadState.None) + switch (state.ReadState) { - shouldComplete = true; - } - if (state.ReadState != ReadState.ConnectionClosed) - { - state.ReadState = ReadState.IndividualReadComplete; + case ReadState.None: + // ReadAsync() hasn't been called yet. Stash the buffer so the next ReadAsync call completes synchronously. + + if ((uint)state.ReceiveQuicBuffers.Length < receiveEvent.BufferCount) + { + QuicBuffer[] oldReceiveBuffers = state.ReceiveQuicBuffers; + state.ReceiveQuicBuffers = ArrayPool.Shared.Rent((int)receiveEvent.BufferCount); + + if (oldReceiveBuffers.Length != 0) // don't return Array.Empty. + { + ArrayPool.Shared.Return(oldReceiveBuffers); + } + } + + for (uint i = 0; i < receiveEvent.BufferCount; ++i) + { + state.ReceiveQuicBuffers[i] = receiveEvent.Buffers[i]; + } + + state.ReceiveQuicBuffersCount = (int)receiveEvent.BufferCount; + state.ReceiveQuicBuffersTotalBytes = checked((int)receiveEvent.TotalBufferLength); + state.ReadState = ReadState.IndividualReadComplete; + return MsQuicStatusCodes.Pending; + case ReadState.PendingRead: + // There is a pending ReadAsync(). + + state.ReceiveCancellationRegistration.Unregister(); + shouldComplete = true; + state.RootedReceiveStream = null; + state.ReadState = ReadState.None; + + readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); + state.ReceiveUserBuffer = null; + break; + default: + Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); + + // There was a race between a user aborting the read stream and the callback being ran. + // This will eat any received data. + return MsQuicStatusCodes.Success; } } + // We're completing a pending read. if (shouldComplete) { - state.ReceiveResettableCompletionSource.Complete((uint)receiveEvent.TotalBufferLength); + state.ReceiveResettableCompletionSource.Complete(readLength); } - return MsQuicStatusCodes.Pending; + // Returning Success when the entire buffer hasn't been consumed will cause MsQuic to disable further receive events until EnableReceive() is called. + // Returning Continue will cause a second receive event to fire immediately after this returns, but allows MsQuic to clean up its buffers. + + uint ret = (uint)readLength == receiveEvent.TotalBufferLength + ? MsQuicStatusCodes.Success + : MsQuicStatusCodes.Continue; + + receiveEvent.TotalBufferLength = (uint)readLength; + return ret; } private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt) @@ -873,12 +974,13 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source."); - if (state.ReadState == ReadState.None) + if (state.ReadState == ReadState.PendingRead) { shouldReadComplete = true; + state.RootedReceiveStream = null; + state.ReceiveUserBuffer = null; } - - if (state.ReadState != ReadState.ConnectionClosed && state.ReadState != ReadState.Aborted) + if (state.ReadState < ReadState.ReadsCompleted) { state.ReadState = ReadState.ReadsCompleted; } @@ -926,9 +1028,11 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt) bool shouldComplete = false; lock (state) { - if (state.ReadState == ReadState.None) + if (state.ReadState == ReadState.PendingRead) { shouldComplete = true; + state.RootedReceiveStream = null; + state.ReceiveUserBuffer = null; } state.ReadState = ReadState.Aborted; state.ReadErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode; @@ -952,12 +1056,13 @@ private static uint HandleEventPeerSendShutdown(State state) // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source."); - if (state.ReadState == ReadState.None) + if (state.ReadState == ReadState.PendingRead) { shouldComplete = true; + state.RootedReceiveStream = null; + state.ReceiveUserBuffer = null; } - - if (state.ReadState != ReadState.ConnectionClosed) + if (state.ReadState < ReadState.ReadsCompleted) { state.ReadState = ReadState.ReadsCompleted; } @@ -1251,11 +1356,11 @@ private static uint HandleEventConnectionClose(State state) lock (state) { - if (state.ReadState == ReadState.None) + shouldCompleteRead = state.ReadState == ReadState.PendingRead; + if (state.ReadState < ReadState.ReadsCompleted) { - shouldCompleteRead = true; + state.ReadState = ReadState.ConnectionClosed; } - state.ReadState = ReadState.ConnectionClosed; if (state.SendState == SendState.None || state.SendState == SendState.Pending) { @@ -1306,6 +1411,16 @@ private static uint HandleEventConnectionClose(State state) private static Exception GetConnectionAbortedException(State state) => ThrowHelper.GetConnectionAbortedException(state.ConnectionState.AbortErrorCode); + // Read state transitions: + // + // None --(data arrives in event RECV)-> IndividualReadComplete --(user calls ReadAsync() & completes syncronously)-> None + // None --(user calls ReadAsync() & waits)-> PendingRead --(data arrives in event RECV & completes user's ReadAsync())-> None + // Any non-final state --(event PEER_SEND_SHUTDOWN or SHUTDOWN_COMPLETED with ConnectionClosed=false)-> ReadsCompleted + // Any non-final state --(event PEER_SEND_ABORT)-> Aborted + // Any non-final state --(user calls AbortRead())-> Aborted + // Any state --(CancellationToken's cancellation for ReadAsync())-> Aborted (TODO: should it be only for non-final as others?) + // Any non-final state --(event SHUTDOWN_COMPLETED with ConnectionClosed=true)-> ConnectionClosed + // Closed - no transitions, set for Unidirectional write-only streams private enum ReadState { /// @@ -1318,6 +1433,13 @@ private enum ReadState /// IndividualReadComplete, + /// + /// User called ReadAsync() + /// + PendingRead, + + // following states are final: + /// /// The peer has gracefully shutdown their sends / our receives; the stream's reads are complete. /// diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index b21135000aca65..4890ebea5f352d 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -408,32 +408,122 @@ from writeSize in sizes } [Fact] - public async Task Read_StreamAborted_Throws() + public async Task Read_WriteAborted_Throws() { const int ExpectedErrorCode = 0xfffffff; - await Task.Run(async () => + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); + + await sem.WaitAsync(); + clientStream.AbortWrite(ExpectedErrorCode); + }, + async serverStream => + { + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + + sem.Release(); + + byte[] buffer = new byte[100]; + QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }); + } + + [Fact] + public async Task Read_SynchronousCompletion_Success() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); + sem.Release(); + clientStream.Shutdown(); + sem.Release(); + }, + async serverStream => + { + await sem.WaitAsync(); + await Task.Delay(1000); + + ValueTask task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + int received = await task; + Assert.Equal(1, received); + + await sem.WaitAsync(); + await Task.Delay(1000); + + task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + received = await task; + Assert.Equal(0, received); + }); + } + + [Fact] + public async Task ReadOutstanding_ReadAborted_Throws() + { + // aborting doesn't work properly on mock + if (typeof(T) == typeof(MockProviderFactory)) { - using QuicListener listener = CreateQuicListener(); - ValueTask serverConnectionTask = listener.AcceptConnectionAsync(); + return; + } + + const int ExpectedErrorCode = 0xfffffff; + + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + Task exTask = Assert.ThrowsAsync(() => serverStream.ReadAsync(new byte[1]).AsTask()); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - await clientConnection.ConnectAsync(); + Assert.False(exTask.IsCompleted); - using QuicConnection serverConnection = await serverConnectionTask; + serverStream.AbortRead(ExpectedErrorCode); - await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); - await clientStream.WriteAsync(new byte[1]); + await exTask; - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); - await serverStream.ReadAsync(new byte[1]); + sem.Release(); + }); + } - clientStream.AbortWrite(ExpectedErrorCode); + [Fact] + public async Task Read_ConcurrentReads_Throws() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); - byte[] buffer = new byte[100]; - QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); - Assert.Equal(ExpectedErrorCode, ex.ErrorCode); - }).WaitAsync(TimeSpan.FromSeconds(15)); + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + ValueTask readTask = serverStream.ReadAsync(new byte[1]); + Assert.False(readTask.IsCompleted); + + await Assert.ThrowsAsync(async () => await serverStream.ReadAsync(new byte[1])); + + sem.Release(); + + int res = await readTask; + Assert.Equal(0, res); + }); } [Fact] diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index 803b2d40705921..da2cfb37f412cd 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Xunit; +using System.Diagnostics.Tracing; namespace System.Net.Quic.Tests { @@ -114,22 +115,19 @@ internal async Task RunClientServer(Func clientFunction, F { using QuicListener listener = CreateQuicListener(); - var serverFinished = new ManualResetEventSlim(); - var clientFinished = new ManualResetEventSlim(); + using var serverFinished = new SemaphoreSlim(0); + using var clientFinished = new SemaphoreSlim(0); for (int i = 0; i < iterations; ++i) { - serverFinished.Reset(); - clientFinished.Reset(); - await new[] { Task.Run(async () => { using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); await serverFunction(serverConnection); - serverFinished.Set(); - clientFinished.Wait(); + serverFinished.Release(); + await clientFinished.WaitAsync(); await serverConnection.CloseAsync(0); }), Task.Run(async () => @@ -137,14 +135,52 @@ await new[] using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); await clientConnection.ConnectAsync(); await clientFunction(clientConnection); - clientFinished.Set(); - serverFinished.Wait(); + clientFinished.Release(); + await serverFinished.WaitAsync(); await clientConnection.CloseAsync(0); }) }.WhenAllOrAnyFailed(millisecondsTimeout); } } + internal async Task RunStreamClientServer(Func clientFunction, Func serverFunction, bool bidi, int iterations, int millisecondsTimeout) + { + byte[] buffer = new byte[1] { 42 }; + + await RunClientServer( + clientFunction: async connection => + { + await using QuicStream stream = bidi ? connection.OpenBidirectionalStream() : connection.OpenUnidirectionalStream(); + // Open(Bi|Uni)directionalStream only allocates ID. We will force stream opening + // by Writing there and receiving data on the other side. + await stream.WriteAsync(buffer); + + await clientFunction(stream); + + stream.Shutdown(); + await stream.ShutdownCompleted(); + }, + serverFunction: async connection => + { + await using QuicStream stream = await connection.AcceptStreamAsync(); + Assert.Equal(1, await stream.ReadAsync(buffer)); + + await serverFunction(stream); + + stream.Shutdown(); + await stream.ShutdownCompleted(); + }, + iterations, + millisecondsTimeout + ); + } + + internal Task RunBidirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunStreamClientServer(clientFunction, serverFunction, bidi: true, iterations, millisecondsTimeout); + + internal Task RunUnirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunStreamClientServer(clientFunction, serverFunction, bidi: false, iterations, millisecondsTimeout); + internal static async Task ReadAll(QuicStream stream, byte[] buffer) { Memory memory = buffer;