diff --git a/src/Analyzers/MSTest.Analyzers.CodeFixes/PreferAsyncAssertionFixer.cs b/src/Analyzers/MSTest.Analyzers.CodeFixes/PreferAsyncAssertionFixer.cs index 3ff4437e19..e6922b8aac 100644 --- a/src/Analyzers/MSTest.Analyzers.CodeFixes/PreferAsyncAssertionFixer.cs +++ b/src/Analyzers/MSTest.Analyzers.CodeFixes/PreferAsyncAssertionFixer.cs @@ -46,6 +46,18 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) return; } + if (invocationExpression.Ancestors().OfType().FirstOrDefault() is { } methodDeclaration && + ContainsReturnExpressionInUnawaitableContext(methodDeclaration)) + { + SemanticModel? semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false); + if (semanticModel is not null && + !methodDeclaration.Modifiers.Any(static modifier => modifier.IsKind(SyntaxKind.AsyncKeyword)) && + IsTaskOrValueTaskReturnType(methodDeclaration, semanticModel, context.CancellationToken)) + { + return; + } + } + context.RegisterCodeFix( CodeAction.Create( title: CodeFixResources.UseAsyncAssertionFix, @@ -54,6 +66,21 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) context.Diagnostics); } + private static bool ContainsReturnExpressionInUnawaitableContext(MethodDeclarationSyntax methodDeclaration) + { + var walker = new UnawaitableReturnDetector(); + if (methodDeclaration.Body is { } body) + { + walker.Visit(body); + } + else if (methodDeclaration.ExpressionBody is { } expressionBody) + { + walker.Visit(expressionBody); + } + + return walker.Found; + } + private static async Task UseAsyncAssertionAsync(Document document, InvocationExpressionSyntax invocationExpression, CancellationToken cancellationToken) { DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false); @@ -200,6 +227,22 @@ private static bool TryReplaceActionExpression(ExpressionSyntax expression, [Not return true; } + if (expression is ObjectCreationExpressionSyntax objectCreationExpression && + objectCreationExpression.ArgumentList is { Arguments.Count: 1 } objectCreationArgumentList && + TryReplaceActionExpression(objectCreationArgumentList.Arguments[0].Expression, out ExpressionSyntax? objectCreationNewExpression)) + { + newExpression = objectCreationNewExpression.WithTriviaFrom(expression); + return true; + } + + if (expression is ImplicitObjectCreationExpressionSyntax implicitObjectCreationExpression && + implicitObjectCreationExpression.ArgumentList is { Arguments.Count: 1 } implicitObjectCreationArgumentList && + TryReplaceActionExpression(implicitObjectCreationArgumentList.Arguments[0].Expression, out ExpressionSyntax? implicitObjectCreationNewExpression)) + { + newExpression = implicitObjectCreationNewExpression.WithTriviaFrom(expression); + return true; + } + if (expression is not LambdaExpressionSyntax lambdaExpression || !TryGetBlockedTaskExpressionFromLambda(lambdaExpression, out ExpressionSyntax? asyncExpression)) { @@ -437,4 +480,70 @@ private static StatementSyntax[] CreateAwaitAndReturnStatements(ReturnStatementS return includeReturn ? [awaitStatement, newReturnStatement] : [awaitStatement]; } } + + private sealed class UnawaitableReturnDetector : CSharpSyntaxWalker + { + private int _unawaitableDepth; + + public bool Found { get; private set; } + + public override void Visit(SyntaxNode? node) + { + if (Found) + { + return; + } + + base.Visit(node); + } + + public override void VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax node) + { + } + + public override void VisitParenthesizedLambdaExpression(ParenthesizedLambdaExpressionSyntax node) + { + } + + public override void VisitAnonymousMethodExpression(AnonymousMethodExpressionSyntax node) + { + } + + public override void VisitLocalFunctionStatement(LocalFunctionStatementSyntax node) + { + } + + public override void VisitLockStatement(LockStatementSyntax node) + => VisitInUnawaitableContext(node.Statement); + + public override void VisitUnsafeStatement(UnsafeStatementSyntax node) + => VisitInUnawaitableContext(node.Block); + + public override void VisitFixedStatement(FixedStatementSyntax node) + => VisitInUnawaitableContext(node.Statement); + + public override void VisitReturnStatement(ReturnStatementSyntax node) + { + if (_unawaitableDepth > 0 && node.Expression is not null) + { + Found = true; + return; + } + + base.VisitReturnStatement(node); + } + + private void VisitInUnawaitableContext(SyntaxNode? node) + { + _unawaitableDepth++; + try + { + Visit(node); + } + finally + { + _unawaitableDepth--; + } + } + } } diff --git a/test/UnitTests/MSTest.Analyzers.UnitTests/PreferAsyncAssertionAnalyzerTests.cs b/test/UnitTests/MSTest.Analyzers.UnitTests/PreferAsyncAssertionAnalyzerTests.cs index d6daf569d5..66fd49586a 100644 --- a/test/UnitTests/MSTest.Analyzers.UnitTests/PreferAsyncAssertionAnalyzerTests.cs +++ b/test/UnitTests/MSTest.Analyzers.UnitTests/PreferAsyncAssertionAnalyzerTests.cs @@ -820,4 +820,253 @@ public async Task MyTestMethod() await VerifyCS.VerifyAnalyzerAsync(code); } + + [TestMethod] + public async Task WhenAssertionActionIsExplicitDelegateCreation_CodeFixUnwrapsDelegateCreation() + { + string code = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + [|Assert.ThrowsExactly(new Action(() => BarAsync().GetAwaiter().GetResult()))|]; + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + string fixedCode = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public async Task MyTestMethod() + { + await Assert.ThrowsExactlyAsync(() => BarAsync()); + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenAssertionActionIsExplicitDelegateCreationWithAnonymousMethod_CodeFixUnwrapsDelegateCreation() + { + string code = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + [|Assert.ThrowsExactly(new Action(delegate { BarAsync().GetAwaiter().GetResult(); }))|]; + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + string fixedCode = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public async Task MyTestMethod() + { + await Assert.ThrowsExactlyAsync(() => BarAsync()); + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenAssertionActionIsTargetTypedDelegateCreation_CodeFixUnwrapsDelegateCreation() + { + string code = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public void MyTestMethod() + { + [|Assert.ThrowsExactly((Action)new(() => BarAsync().GetAwaiter().GetResult()))|]; + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + string fixedCode = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public async Task MyTestMethod() + { + await Assert.ThrowsExactlyAsync(() => BarAsync()); + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + await VerifyCS.VerifyCodeFixAsync(code, fixedCode); + } + + [TestMethod] + public async Task WhenNonAsyncTaskMethodHasReturnInsideLockBlock_DiagnosticReportedButNoCodeFixOffered() + { + string code = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + private readonly object _gate = new(); + + [TestMethod] + public Task MyTestMethod() + { + [|Assert.ThrowsExactly(() => BarAsync().GetAwaiter().GetResult())|]; + lock (_gate) + { + return Task.CompletedTask; + } + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + // The diagnostic is still reported, but the fixer cannot safely transform the + // method because it would emit an 'await' inside the lock body. No code fix is + // offered, so the expected fixed code is identical to the original. + await VerifyCS.VerifyCodeFixAsync(code, code); + } + + [TestMethod] + public async Task WhenNonAsyncValueTaskMethodHasReturnInsideUnsafeBlock_DiagnosticReportedButNoCodeFixOffered() + { + string code = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public ValueTask MyTestMethod() + { + [|Assert.ThrowsExactly(() => BarAsync().GetAwaiter().GetResult())|]; + unsafe + { + return ValueTask.CompletedTask; + } + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + // The diagnostic is still reported, but the fixer cannot safely transform the + // method because it would emit an 'await' inside the unsafe block. No code fix is + // offered, so the expected fixed code is identical to the original. + var test = new VerifyCS.Test + { + TestCode = code, + FixedCode = code, + }; + + test.SolutionTransforms.Add((solution, projectId) => + { + var compilationOptions = (CSharpCompilationOptions)solution.GetProject(projectId)!.CompilationOptions!; + return solution.WithProjectCompilationOptions(projectId, compilationOptions.WithAllowUnsafe(true)); + }); + + await test.RunAsync(CancellationToken.None); + } + + [TestMethod] + public async Task WhenNonAsyncTaskMethodHasReturnInsideFixedBlock_DiagnosticReportedButNoCodeFixOffered() + { + string code = """ + using System; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class MyTestClass + { + [TestMethod] + public Task MyTestMethod() + { + [|Assert.ThrowsExactly(() => BarAsync().GetAwaiter().GetResult())|]; + int[] data = { 1, 2, 3 }; + unsafe + { + fixed (int* p = data) + { + return Task.CompletedTask; + } + } + } + + private Task BarAsync() => Task.CompletedTask; + } + """; + + // The diagnostic is still reported, but the fixer cannot safely transform the + // method because it would emit an 'await' inside the fixed statement body. No code fix is + // offered, so the expected fixed code is identical to the original. + var test = new VerifyCS.Test + { + TestCode = code, + FixedCode = code, + }; + + test.SolutionTransforms.Add((solution, projectId) => + { + var compilationOptions = (CSharpCompilationOptions)solution.GetProject(projectId)!.CompilationOptions!; + return solution.WithProjectCompilationOptions(projectId, compilationOptions.WithAllowUnsafe(true)); + }); + + await test.RunAsync(CancellationToken.None); + } }