From da979d4441d265a9332422c039d9bae8f341e89c Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Sun, 11 Jul 2021 22:51:58 +0200 Subject: [PATCH 1/5] Update read process and add tests --- .../Quic/Implementations/Mock/MockStream.cs | 11 +- .../Implementations/MsQuic/MsQuicStream.cs | 258 ++++++++++++------ .../tests/FunctionalTests/QuicStreamTests.cs | 118 ++++++-- .../tests/FunctionalTests/QuicTestBase.cs | 41 ++- 4 files changed, 322 insertions(+), 106 deletions(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs index bd814f690d952c..3537db4d0370c2 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs @@ -171,7 +171,16 @@ internal override Task FlushAsync(CancellationToken cancellationToken) internal override void AbortRead(long errorCode) { - throw new NotImplementedException(); + if (_isInitiator) + { + _streamState._outboundErrorCode = errorCode; + } + else + { + _streamState._inboundErrorCode = errorCode; + } + + ReadStreamBuffer?.AbortRead(); } internal override void AbortWrite(long errorCode) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 8260dfdd68fa74..1864360ac317fa 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -35,11 +35,20 @@ private sealed class State public MsQuicConnection.State ConnectionState = null!; // set in ctor. public ReadState ReadState; + + // set when ReadState.Aborted: public long ReadErrorCode = -1; - public readonly List ReceiveQuicBuffers = new List(); - // Resettable completions to be used for multiple calls to receive. - public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); + // filled when ReadState.BuffersAvailable: + public QuicBuffer[] ReceiveQuicBuffers = Array.Empty(); + public int ReceiveQuicBuffersCount; + public int ReceiveQuicBuffersTotalBytes; + + // set when ReadState.PendingRead: + public Memory ReceiveUserBuffer; + public CancellationTokenRegistration ReceiveCancellationRegistration; + public MsQuicStream? RootedReceiveStream; // roots the stream in the pinned state to prevent GC during an async read I/O. + public readonly ResettableCompletionSource ReceiveResettableCompletionSource = new ResettableCompletionSource(); public SendState SendState; public long SendErrorCode = -1; @@ -325,7 +334,7 @@ private void HandleWriteFailedState() } } - internal override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + internal override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) { ThrowIfDisposed(); @@ -352,80 +361,120 @@ internal override async ValueTask ReadAsync(Memory destination, Cance NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] reading into Memory of '{destination.Length}' bytes."); } + ReadState readState; + long abortError = -1; + bool canceledSynchronously = false; + lock (_state) { - if (_state.ReadState == ReadState.ReadsCompleted) - { - return 0; - } - else if (_state.ReadState == ReadState.Aborted) + readState = _state.ReadState; + abortError = _state.ReadErrorCode; + + if (readState != ReadState.PendingRead && cancellationToken.IsCancellationRequested) { - throw ThrowHelper.GetStreamAbortedException(_state.ReadErrorCode); + readState = ReadState.Aborted; + _state.ReadState = ReadState.Aborted; + canceledSynchronously = true; } - else if (_state.ReadState == ReadState.ConnectionClosed) + else if (readState == ReadState.None) { - throw GetConnectionAbortedException(_state); - } - } + Debug.Assert(_state.RootedReceiveStream is null); - using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) => - { - var state = (State)s!; - bool shouldComplete = false; - lock (state) - { - if (state.ReadState == ReadState.None) + _state.ReceiveUserBuffer = destination; + _state.RootedReceiveStream = this; + _state.ReadState = ReadState.PendingRead; + + if (cancellationToken.CanBeCanceled) { - shouldComplete = true; + _state.ReceiveCancellationRegistration = cancellationToken.UnsafeRegister(static (obj, token) => + { + var state = (State)obj!; + bool completePendingRead; + + lock (state) + { + completePendingRead = state.ReadState == ReadState.PendingRead; + state.RootedReceiveStream = null; + state.ReadState = ReadState.Aborted; + } + + if (completePendingRead) + { + state.ReceiveResettableCompletionSource.CompleteException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(token))); + } + }, _state); + } + else + { + _state.ReceiveCancellationRegistration = default; } - state.ReadState = ReadState.Aborted; - } - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.CompleteException( - ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Read was canceled", token))); + return _state.ReceiveResettableCompletionSource.GetValueTask(); } - }, _state); - - // TODO there could potentially be a perf gain by storing the buffer from the initial read - // This reduces the amount of async calls, however it makes it so MsQuic holds onto the buffers - // longer than it needs to. We will need to benchmark this. - int length = (int)await _state.ReceiveResettableCompletionSource.GetValueTask().ConfigureAwait(false); + else if (readState == ReadState.IndividualReadComplete) + { + _state.ReadState = ReadState.None; - int actual = Math.Min(length, destination.Length); + int taken = CopyMsQuicBuffersToUserBuffer(_state.ReceiveQuicBuffers.AsSpan(0, _state.ReceiveQuicBuffersCount), destination.Span); + ReceiveComplete(taken); - static unsafe void CopyToBuffer(Span destinationBuffer, List sourceBuffers) - { - Span slicedBuffer = destinationBuffer; - for (int i = 0; i < sourceBuffers.Count; i++) - { - QuicBuffer nativeBuffer = sourceBuffers[i]; - int length = Math.Min((int)nativeBuffer.Length, slicedBuffer.Length); - new Span(nativeBuffer.Buffer, length).CopyTo(slicedBuffer); - if (length < nativeBuffer.Length) + if (taken != _state.ReceiveQuicBuffersTotalBytes) { - // The buffer passed in was larger that the received data, return - return; + // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer. + EnableReceive(); } - slicedBuffer = slicedBuffer.Slice(length); + + return new ValueTask(taken); } } - CopyToBuffer(destination.Span, _state.ReceiveQuicBuffers); + Exception? ex = null; - lock (_state) + switch (readState) { - if (_state.ReadState == ReadState.IndividualReadComplete) - { - _state.ReceiveQuicBuffers.Clear(); - ReceiveComplete(actual); - EnableReceive(); - _state.ReadState = ReadState.None; - } + case ReadState.ReadsCompleted: + return new ValueTask(0); + case ReadState.PendingRead: + ex = new InvalidOperationException("Only one read is supported at a time."); + break; + case ReadState.Aborted: + ex = + canceledSynchronously ? new OperationCanceledException(cancellationToken) : // aborted by token being canceled before the async op started. + abortError == -1 ? new QuicOperationAbortedException() : // aborted by user via some other operation. + new QuicStreamAbortedException(abortError); // aborted by peer. + + break; + case ReadState.ConnectionClosed: + default: + Debug.Assert(readState == ReadState.ConnectionClosed, $"{nameof(ReadState)} of '{readState}' is unaccounted for in {nameof(ReadAsync)}."); + ex = GetConnectionAbortedException(_state); + break; } - return actual; + return ValueTask.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(ex!)); + } + + /// The number of bytes copied. + private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan sourceBuffers, Span destinationBuffer) + { + Debug.Assert(sourceBuffers.Length != 0); + + int originalDestinationLength = destinationBuffer.Length; + QuicBuffer nativeBuffer; + int takeLength = 0; + int i = 0; + + do + { + nativeBuffer = sourceBuffers[i]; + takeLength = Math.Min((int)nativeBuffer.Length, destinationBuffer.Length); + + new Span(nativeBuffer.Buffer, takeLength).CopyTo(destinationBuffer); + destinationBuffer = destinationBuffer.Slice(takeLength); + } + while (destinationBuffer.Length != 0 && ++i < sourceBuffers.Length); + + return originalDestinationLength - destinationBuffer.Length; } // TODO do we want this to be a synchronization mechanism to cancel a pending read @@ -687,7 +736,8 @@ private void Dispose(bool disposing) private void EnableReceive() { - MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + uint status = MsQuicApi.Api.StreamReceiveSetEnabledDelegate(_state.Handle, enabled: true); + QuicExceptionHelpers.ThrowIfFailed(status, "StreamReceiveSetEnabled failed."); } private static uint NativeCallbackHandler( @@ -763,31 +813,75 @@ private static uint HandleEvent(State state, ref StreamEvent evt) private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) { - StreamEventDataReceive receiveEvent = evt.Data.Receive; - for (int i = 0; i < receiveEvent.BufferCount; i++) + ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive; + + if (receiveEvent.BufferCount == 0) { - state.ReceiveQuicBuffers.Add(receiveEvent.Buffers[i]); + // This is a 0-length receive that happens once reads are finished (via abort or otherwise). + // State changes for this are handled elsewhere. + return MsQuicStatusCodes.Success; } - bool shouldComplete = false; + int readLength; + lock (state) { - if (state.ReadState == ReadState.None) + switch (state.ReadState) { - shouldComplete = true; - } - if (state.ReadState != ReadState.ConnectionClosed) - { - state.ReadState = ReadState.IndividualReadComplete; + case ReadState.None: + // ReadAsync() hasn't been called yet. Stash the buffer so the next ReadAsync call completes synchronously. + + if ((uint)state.ReceiveQuicBuffers.Length < receiveEvent.BufferCount) + { + QuicBuffer[] oldReceiveBuffers = state.ReceiveQuicBuffers; + state.ReceiveQuicBuffers = ArrayPool.Shared.Rent((int)receiveEvent.BufferCount); + + if (oldReceiveBuffers.Length != 0) // don't return Array.Empty. + { + ArrayPool.Shared.Return(oldReceiveBuffers); + } + } + + for (uint i = 0; i < receiveEvent.BufferCount; ++i) + { + state.ReceiveQuicBuffers[i] = receiveEvent.Buffers[i]; + } + + state.ReceiveQuicBuffersCount = (int)receiveEvent.BufferCount; + state.ReceiveQuicBuffersTotalBytes = checked((int)receiveEvent.TotalBufferLength); + state.ReadState = ReadState.IndividualReadComplete; + return MsQuicStatusCodes.Pending; + case ReadState.PendingRead: + // There is a pending ReadAsync(). + + state.ReceiveCancellationRegistration.Unregister(); + state.RootedReceiveStream = null; + state.ReadState = ReadState.None; + + readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); + break; + default: + Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); + + // There was a race between a user aborting the read stream and the callback being ran. + // This will eat any received data. + return MsQuicStatusCodes.Success; } } - if (shouldComplete) - { - state.ReceiveResettableCompletionSource.Complete((uint)receiveEvent.TotalBufferLength); - } + // We're completing a pending read. + // TODO: only if ReadState.PendingRead?? + state.ReceiveResettableCompletionSource.Complete(readLength); + + // Returning Success when the entire buffer hasn't been consumed will cause MsQuic to disable further receive events until EnableReceive() is called. + // Returning Continue will cause a second receive event to fire immediately after this returns, but allows MsQuic to clean up its buffers. - return MsQuicStatusCodes.Pending; + uint ret = (uint)readLength == receiveEvent.TotalBufferLength + ? MsQuicStatusCodes.Success + : MsQuicStatusCodes.Continue; + + receiveEvent.TotalBufferLength = (uint)readLength; + return ret; } private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt) @@ -949,12 +1043,13 @@ private static uint HandleEventPeerSendShutdown(State state) // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source."); - if (state.ReadState == ReadState.None) + if (state.ReadState == ReadState.PendingRead) { shouldComplete = true; + state.RootedReceiveStream = null; + state.ReadState = ReadState.ReadsCompleted; } - - if (state.ReadState != ReadState.ConnectionClosed) + else if (state.ReadState == ReadState.None) { state.ReadState = ReadState.ReadsCompleted; } @@ -1248,11 +1343,11 @@ private static uint HandleEventConnectionClose(State state) lock (state) { - if (state.ReadState == ReadState.None) + shouldCompleteRead = state.ReadState == ReadState.PendingRead; + if (state.ReadState < ReadState.ReadsCompleted) { - shouldCompleteRead = true; + state.ReadState = ReadState.ConnectionClosed; } - state.ReadState = ReadState.ConnectionClosed; if (state.SendState == SendState.None || state.SendState == SendState.Pending) { @@ -1315,6 +1410,11 @@ private enum ReadState /// IndividualReadComplete, + /// + /// User called ReadAsync() + /// + PendingRead, + /// /// The peer has gracefully shutdown their sends / our receives; the stream's reads are complete. /// diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index dacd04a448cbe4..daa71f1e22229d 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -408,32 +408,116 @@ from writeSize in sizes } [Fact] - public async Task Read_StreamAborted_Throws() + public async Task Read_WriteAborted_Throws() { const int ExpectedErrorCode = 0xfffffff; - await Task.Run(async () => - { - using QuicListener listener = CreateQuicListener(); - ValueTask serverConnectionTask = listener.AcceptConnectionAsync(); + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); + + await sem.WaitAsync(); + clientStream.AbortWrite(ExpectedErrorCode); + }, + async serverStream => + { + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + + sem.Release(); + + byte[] buffer = new byte[100]; + QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }); + } + + [Fact] + public async Task Read_SynchronousCompletion_Success() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1]); + sem.Release(); + clientStream.Shutdown(); + sem.Release(); + }, + async serverStream => + { + await sem.WaitAsync(); + await Task.Delay(1000); + + ValueTask task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + int received = await task; + Assert.Equal(1, received); + + await sem.WaitAsync(); + await Task.Delay(1000); + + task = serverStream.ReadAsync(new byte[1]); + Assert.True(task.IsCompleted); + + received = await task; + Assert.Equal(0, received); + }); + } + + [Fact] + public async Task ReadOutstanding_ReadAborted_Throws() + { + const int ExpectedErrorCode = 0xfffffff; + + using SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + Task exTask = Assert.ThrowsAsync(() => serverStream.ReadAsync(new byte[1]).AsTask()); - using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); - await clientConnection.ConnectAsync(); + Assert.False(exTask.IsCompleted); - using QuicConnection serverConnection = await serverConnectionTask; + serverStream.AbortRead(ExpectedErrorCode); - await using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); - await clientStream.WriteAsync(new byte[1]); + await exTask; - await using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); - await serverStream.ReadAsync(new byte[1]); + sem.Release(); + }); + } - clientStream.AbortWrite(ExpectedErrorCode); + [Fact] + public async Task Read_ConcurrentReads_Throws() + { + using SemaphoreSlim sem = new SemaphoreSlim(0); - byte[] buffer = new byte[100]; - QuicStreamAbortedException ex = await Assert.ThrowsAsync(() => serverStream.ReadAsync(buffer).AsTask()); - Assert.Equal(ExpectedErrorCode, ex.ErrorCode); - }).WaitAsync(TimeSpan.FromSeconds(15)); + await RunBidirectionalClientServer( + async clientStream => + { + await sem.WaitAsync(); + }, + async serverStream => + { + ValueTask readTask = serverStream.ReadAsync(new byte[1]); + Assert.False(readTask.IsCompleted); + + await Assert.ThrowsAsync(async () => await serverStream.ReadAsync(new byte[1])); + + sem.Release(); + + int res = await readTask; + Assert.Equal(0, res); + }); } [Fact] diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index ee7501868bebab..35805076eeb962 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -81,22 +81,19 @@ internal async Task RunClientServer(Func clientFunction, F { using QuicListener listener = CreateQuicListener(); - var serverFinished = new ManualResetEventSlim(); - var clientFinished = new ManualResetEventSlim(); + using var serverFinished = new SemaphoreSlim(0); + using var clientFinished = new SemaphoreSlim(0); for (int i = 0; i < iterations; ++i) { - serverFinished.Reset(); - clientFinished.Reset(); - await new[] { Task.Run(async () => { using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); await serverFunction(serverConnection); - serverFinished.Set(); - clientFinished.Wait(); + serverFinished.Release(); + await clientFinished.WaitAsync(); await serverConnection.CloseAsync(0); }), Task.Run(async () => @@ -104,14 +101,40 @@ await new[] using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); await clientConnection.ConnectAsync(); await clientFunction(clientConnection); - clientFinished.Set(); - serverFinished.Wait(); + clientFinished.Release(); + await serverFinished.WaitAsync(); await clientConnection.CloseAsync(0); }) }.WhenAllOrAnyFailed(millisecondsTimeout); } } + internal Task RunStreamClientServer(Func clientFunction, Func serverFunction, bool bidi, int iterations, int millisecondsTimeout) + { + return RunClientServer( + clientFunction: async connection => + { + await using QuicStream stream = bidi ? connection.OpenBidirectionalStream() : connection.OpenUnidirectionalStream(); + await clientFunction(stream); + stream.Shutdown(); + await stream.ShutdownCompleted(); + }, + serverFunction: async connection => + { + await using QuicStream stream = await connection.AcceptStreamAsync(); + await serverFunction(stream); + stream.Shutdown(); + await stream.ShutdownCompleted(); + } + ); + } + + internal Task RunBidirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunStreamClientServer(clientFunction, serverFunction, bidi: true, iterations, millisecondsTimeout); + + internal Task RunUnirectionalClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + => RunStreamClientServer(clientFunction, serverFunction, bidi: false, iterations, millisecondsTimeout); + internal static async Task ReadAll(QuicStream stream, byte[] buffer) { Memory memory = buffer; From 0a871bc998c28c87b236f0be352a4952586277d5 Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Mon, 12 Jul 2021 00:32:59 +0200 Subject: [PATCH 2/5] Fix quic tests --- .../Implementations/MsQuic/MsQuicStream.cs | 58 +++++++++++-------- .../tests/FunctionalTests/QuicStreamTests.cs | 6 ++ .../tests/FunctionalTests/QuicTestBase.cs | 18 +++++- 3 files changed, 55 insertions(+), 27 deletions(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 1864360ac317fa..9898b205e4490e 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -343,19 +343,6 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio throw new InvalidOperationException(SR.net_quic_reading_notallowed); } - if (cancellationToken.IsCancellationRequested) - { - lock (_state) - { - if (_state.ReadState == ReadState.None) - { - _state.ReadState = ReadState.Aborted; - } - } - - throw new System.OperationCanceledException(cancellationToken); - } - if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info(_state, $"[Stream#{_state.GetHashCode()}] reading into Memory of '{destination.Length}' bytes."); @@ -395,6 +382,7 @@ internal override ValueTask ReadAsync(Memory destination, Cancellatio { completePendingRead = state.ReadState == ReadState.PendingRead; state.RootedReceiveStream = null; + state.ReceiveUserBuffer = null; state.ReadState = ReadState.Aborted; } @@ -477,15 +465,29 @@ private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan return originalDestinationLength - destinationBuffer.Length; } - // TODO do we want this to be a synchronization mechanism to cancel a pending read - // If so, we need to complete the read here as well. internal override void AbortRead(long errorCode) { ThrowIfDisposed(); + bool shouldComplete = false; lock (_state) { - _state.ReadState = ReadState.Aborted; + if (_state.ReadState == ReadState.PendingRead) + { + shouldComplete = true; + _state.RootedReceiveStream = null; + _state.ReceiveUserBuffer = null; + } + if (_state.ReadState < ReadState.ReadsCompleted) + { + _state.ReadState = ReadState.Aborted; + } + } + + if (shouldComplete) + { + _state.ReceiveResettableCompletionSource.CompleteException( + ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Read was aborted"))); } StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, errorCode); @@ -824,6 +826,7 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) int readLength; + bool shouldComplete = false; lock (state) { switch (state.ReadState) @@ -855,10 +858,12 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) // There is a pending ReadAsync(). state.ReceiveCancellationRegistration.Unregister(); + shouldComplete = true; state.RootedReceiveStream = null; state.ReadState = ReadState.None; readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span); + state.ReceiveUserBuffer = null; break; default: Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}."); @@ -870,8 +875,10 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) } // We're completing a pending read. - // TODO: only if ReadState.PendingRead?? - state.ReceiveResettableCompletionSource.Complete(readLength); + if (shouldComplete) + { + state.ReceiveResettableCompletionSource.Complete(readLength); + } // Returning Success when the entire buffer hasn't been consumed will cause MsQuic to disable further receive events until EnableReceive() is called. // Returning Continue will cause a second receive event to fire immediately after this returns, but allows MsQuic to clean up its buffers. @@ -964,12 +971,13 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt // This event won't occur within the middle of a receive. if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"[Stream#{state.GetHashCode()}] completing resettable event source."); - if (state.ReadState == ReadState.None) + if (state.ReadState == ReadState.PendingRead) { shouldReadComplete = true; + state.RootedReceiveStream = null; + state.ReceiveUserBuffer = null; } - - if (state.ReadState != ReadState.ConnectionClosed && state.ReadState != ReadState.Aborted) + if (state.ReadState < ReadState.ReadsCompleted) { state.ReadState = ReadState.ReadsCompleted; } @@ -1017,9 +1025,11 @@ private static uint HandleEventPeerSendAborted(State state, ref StreamEvent evt) bool shouldComplete = false; lock (state) { - if (state.ReadState == ReadState.None) + if (state.ReadState == ReadState.PendingRead) { shouldComplete = true; + state.RootedReceiveStream = null; + state.ReceiveUserBuffer = null; } state.ReadState = ReadState.Aborted; state.ReadErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode; @@ -1047,9 +1057,9 @@ private static uint HandleEventPeerSendShutdown(State state) { shouldComplete = true; state.RootedReceiveStream = null; - state.ReadState = ReadState.ReadsCompleted; + state.ReceiveUserBuffer = null; } - else if (state.ReadState == ReadState.None) + if (state.ReadState < ReadState.ReadsCompleted) { state.ReadState = ReadState.ReadsCompleted; } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index daa71f1e22229d..d2ef7b68135faa 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -473,6 +473,12 @@ await RunBidirectionalClientServer( [Fact] public async Task ReadOutstanding_ReadAborted_Throws() { + // aborting doesn't work properly on mock + if (typeof(T) == typeof(MockProviderFactory)) + { + return; + } + const int ExpectedErrorCode = 0xfffffff; using SemaphoreSlim sem = new SemaphoreSlim(0); diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index 35805076eeb962..dd71190078c427 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -109,23 +109,35 @@ await new[] } } - internal Task RunStreamClientServer(Func clientFunction, Func serverFunction, bool bidi, int iterations, int millisecondsTimeout) + internal async Task RunStreamClientServer(Func clientFunction, Func serverFunction, bool bidi, int iterations, int millisecondsTimeout) { - return RunClientServer( + byte[] buffer = new byte[1] { 42 }; + + await RunClientServer( clientFunction: async connection => { await using QuicStream stream = bidi ? connection.OpenBidirectionalStream() : connection.OpenUnidirectionalStream(); + // Open(Bi|Uni)directionalStream only allocates ID. We will force stream opening + // by Writing there and receiving data on the other side. + await stream.WriteAsync(buffer); + await clientFunction(stream); + stream.Shutdown(); await stream.ShutdownCompleted(); }, serverFunction: async connection => { await using QuicStream stream = await connection.AcceptStreamAsync(); + Assert.Equal(1, await stream.ReadAsync(buffer)); + await serverFunction(stream); + stream.Shutdown(); await stream.ShutdownCompleted(); - } + }, + iterations, + millisecondsTimeout ); } From 9ffc1021401ee57ed9d5974eec4c4cbafe05b76e Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Mon, 12 Jul 2021 13:01:06 +0200 Subject: [PATCH 3/5] Merge fix --- .../System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index dd71190078c427..b54a6a7c7beee1 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -9,12 +9,15 @@ using System.Threading; using System.Threading.Tasks; using Xunit; +using System.Diagnostics.Tracing; namespace System.Net.Quic.Tests { public abstract class QuicTestBase where T : IQuicImplProviderFactory, new() { + private static readonly byte[] s_ping = Encoding.UTF8.GetBytes("PING"); + private static readonly byte[] s_pong = Encoding.UTF8.GetBytes("PONG"); private static readonly IQuicImplProviderFactory s_factory = new T(); public static QuicImplementationProvider ImplementationProvider { get; } = s_factory.GetProvider(); From b01d6a44e0e64e28f7b9d7fb045fc79800d3fa1c Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Mon, 12 Jul 2021 14:18:28 +0200 Subject: [PATCH 4/5] Add read state transitions info --- .../Net/Quic/Implementations/MsQuic/MsQuicStream.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index e42dd6e0e225da..44aef8fe49e58c 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -1411,6 +1411,16 @@ private static uint HandleEventConnectionClose(State state) private static Exception GetConnectionAbortedException(State state) => ThrowHelper.GetConnectionAbortedException(state.ConnectionState.AbortErrorCode); + // Read state transitions: + // + // None --(data arrives in event RECV)-> IndividualReadComplete --(user calls ReadAsync() & completes syncronously)-> None + // None --(user calls ReadAsync() & waits)-> PendingRead --(data arrives in event RECV & completes user's ReadAsync())-> None + // Any non-final state --(event PEER_SEND_SHUTDOWN or SHUTDOWN_COMPLETED with ConnectionClosed=false)-> ReadsCompleted + // Any non-final state --(event PEER_SEND_ABORT)-> Aborted + // Any non-final state --(user calls AbortRead())-> Aborted + // Any state --(CancellationToken's cancellation for ReadAsync())-> Aborted (TODO: should it be only for non-final as others?) + // Any non-final state --(event SHUTDOWN_COMPLETED with ConnectionClosed=true)-> ConnectionClosed + // Closed - no transitions, set for Unidirectional write-only streams private enum ReadState { /// @@ -1428,6 +1438,8 @@ private enum ReadState /// PendingRead, + // following states are final: + /// /// The peer has gracefully shutdown their sends / our receives; the stream's reads are complete. /// From 93ba9d93d50d0804971ea19f353f73a856558bcf Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Mon, 12 Jul 2021 14:22:48 +0200 Subject: [PATCH 5/5] Fix trailing whitespace --- .../src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 44aef8fe49e58c..05d08f567b62c6 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -1418,7 +1418,7 @@ private static Exception GetConnectionAbortedException(State state) => // Any non-final state --(event PEER_SEND_SHUTDOWN or SHUTDOWN_COMPLETED with ConnectionClosed=false)-> ReadsCompleted // Any non-final state --(event PEER_SEND_ABORT)-> Aborted // Any non-final state --(user calls AbortRead())-> Aborted - // Any state --(CancellationToken's cancellation for ReadAsync())-> Aborted (TODO: should it be only for non-final as others?) + // Any state --(CancellationToken's cancellation for ReadAsync())-> Aborted (TODO: should it be only for non-final as others?) // Any non-final state --(event SHUTDOWN_COMPLETED with ConnectionClosed=true)-> ConnectionClosed // Closed - no transitions, set for Unidirectional write-only streams private enum ReadState