Skip to content

Commit 53d1d2c

Browse files
authored
Fix corner-case handling of cancellation exception in ForEachAsync (#59065)
* Fix corner-case handling of cancellation exception in ForEachAsync If code in Parallel.ForEachAsync throws OperationCanceledExceptions containing the CancellationToken passed to the iteration and that token has _not_ had cancellation requested (so why are they throwing with it) and there are no other exceptions, the ForEachAsync will effectively hang after failing to complete the task returned from it. The issue stems from how we treat cancellation. If the user-supplied token hasn't been canceled but we have OperationCanceledExceptions for the token passed into the iteration (the "internal" token), it can only have been canceled because an exception occurred. We filter out these cancellation exceptions, leaving just the exceptions that are deemed to have caused the failure in the first place. But the code doesn't currently account for the possibility that the developer is (arguably erroneously) throwing such an OperationCanceledException with the internal cancellation token as that root failure. The fix is to only filter out these OCEs if there are other exceptions besides them. * Stop filtering out cancellation exceptions in Parallel.ForEachAsync
1 parent 3409f73 commit 53d1d2c

2 files changed

Lines changed: 62 additions & 16 deletions

File tree

src/libraries/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.ForEachAsync.cs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -480,18 +480,10 @@ public void Complete()
480480
}
481481
else
482482
{
483-
// Fault with all of the received exceptions, but filter out those due to inner cancellation,
484-
// as they're effectively an implementation detail and stem from the original exception.
485-
Debug.Assert(_exceptions.Count > 0, "If _exceptions was created, it should have also been populated.");
486-
for (int i = 0; i < _exceptions.Count; i++)
487-
{
488-
if (_exceptions[i] is OperationCanceledException oce && oce.CancellationToken == Cancellation.Token)
489-
{
490-
_exceptions[i] = null!;
491-
}
492-
}
493-
_exceptions.RemoveAll(e => e is null);
494-
Debug.Assert(_exceptions.Count > 0, "Since external cancellation wasn't requested, there should have been a non-cancellation exception that triggered internal cancellation.");
483+
// Fail the task with the resulting exceptions. The first should be the initial
484+
// exception that triggered the operation to shut down. The others, if any, may
485+
// include cancellation exceptions from other concurrent operations being canceled
486+
// in response to the primary exception.
495487
taskSet = TrySetException(_exceptions);
496488
}
497489

src/libraries/System.Threading.Tasks.Parallel/tests/ParallelForEachAsyncTests.cs

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,64 @@ static async IAsyncEnumerable<int> Iterate()
618618
Assert.True(t.IsCanceled);
619619
}
620620

621+
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
622+
[InlineData(false)]
623+
[InlineData(true)]
624+
public async Task Cancellation_FaultsForOceForNonCancellation(bool internalToken)
625+
{
626+
static async IAsyncEnumerable<int> Iterate()
627+
{
628+
int counter = 0;
629+
while (true)
630+
{
631+
await Task.Yield();
632+
yield return counter++;
633+
}
634+
}
635+
636+
var cts = new CancellationTokenSource();
637+
638+
Task t = Parallel.ForEachAsync(Iterate(), new ParallelOptions { CancellationToken = cts.Token }, (item, cancellationToken) =>
639+
{
640+
throw new OperationCanceledException(internalToken ? cancellationToken : cts.Token);
641+
});
642+
643+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
644+
Assert.True(t.IsFaulted);
645+
}
646+
647+
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
648+
[InlineData(0, 4)]
649+
[InlineData(1, 4)]
650+
[InlineData(2, 4)]
651+
[InlineData(3, 4)]
652+
[InlineData(4, 4)]
653+
public async Task Cancellation_InternalCancellationExceptionsArentFilteredOut(int numThrowingNonCanceledOce, int total)
654+
{
655+
var cts = new CancellationTokenSource();
656+
657+
var barrier = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
658+
int remainingCount = total;
659+
660+
Task t = Parallel.ForEachAsync(Enumerable.Range(0, total), new ParallelOptions { CancellationToken = cts.Token, MaxDegreeOfParallelism = total }, async (item, cancellationToken) =>
661+
{
662+
// Wait for all operations to be started
663+
if (Interlocked.Decrement(ref remainingCount) == 0)
664+
{
665+
barrier.SetResult();
666+
}
667+
await barrier.Task;
668+
669+
throw item < numThrowingNonCanceledOce ?
670+
new OperationCanceledException(cancellationToken) :
671+
throw new FormatException();
672+
});
673+
674+
await Assert.ThrowsAnyAsync<Exception>(() => t);
675+
Assert.Equal(total, t.Exception.InnerExceptions.Count);
676+
Assert.Equal(numThrowingNonCanceledOce, t.Exception.InnerExceptions.Count(e => e is OperationCanceledException));
677+
}
678+
621679
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
622680
public void Exception_FromGetEnumerator_Sync()
623681
{
@@ -672,7 +730,6 @@ static IEnumerable<int> Iterate()
672730
Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default);
673731
await Assert.ThrowsAsync<FormatException>(() => t);
674732
Assert.True(t.IsFaulted);
675-
Assert.Equal(1, t.Exception.InnerExceptions.Count);
676733
}
677734

678735
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -694,7 +751,6 @@ static async IAsyncEnumerable<int> Iterate()
694751
Task t = Parallel.ForEachAsync(Iterate(), (item, cancellationToken) => default);
695752
await Assert.ThrowsAsync<FormatException>(() => t);
696753
Assert.True(t.IsFaulted);
697-
Assert.Equal(1, t.Exception.InnerExceptions.Count);
698754
}
699755

700756
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -771,7 +827,6 @@ public async Task Exception_FromDispose_Sync()
771827
Task t = Parallel.ForEachAsync((IEnumerable<int>)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default);
772828
await Assert.ThrowsAsync<FormatException>(() => t);
773829
Assert.True(t.IsFaulted);
774-
Assert.Equal(1, t.Exception.InnerExceptions.Count);
775830
}
776831

777832
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
@@ -780,7 +835,6 @@ public async Task Exception_FromDispose_Async()
780835
Task t = Parallel.ForEachAsync((IAsyncEnumerable<int>)new ThrowsExceptionFromDispose(), (item, cancellationToken) => default);
781836
await Assert.ThrowsAsync<DivideByZeroException>(() => t);
782837
Assert.True(t.IsFaulted);
783-
Assert.Equal(1, t.Exception.InnerExceptions.Count);
784838
}
785839

786840
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]

0 commit comments

Comments
 (0)