Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/Analyzers/MSTest.Analyzers.CodeFixes/PreferAsyncAssertionFixer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
return;
}

if (invocationExpression.Ancestors().OfType<MethodDeclarationSyntax>().FirstOrDefault() is { } methodDeclaration &&
ContainsReturnExpressionInUnawaitableContext(methodDeclaration))
{
SemanticModel? semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
if (semanticModel is not null &&
!methodDeclaration.Modifiers.Any(static modifier => modifier.IsKind(SyntaxKind.AsyncKeyword)) &&
IsTaskOrValueTaskReturnType(methodDeclaration, semanticModel, context.CancellationToken))
{
return;
}
}

context.RegisterCodeFix(
CodeAction.Create(
title: CodeFixResources.UseAsyncAssertionFix,
Expand All @@ -54,6 +66,21 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
context.Diagnostics);
}

private static bool ContainsReturnExpressionInUnawaitableContext(MethodDeclarationSyntax methodDeclaration)
{
var walker = new UnawaitableReturnDetector();
if (methodDeclaration.Body is { } body)
{
walker.Visit(body);
}
else if (methodDeclaration.ExpressionBody is { } expressionBody)
{
walker.Visit(expressionBody);
}

return walker.Found;
}

private static async Task<Document> UseAsyncAssertionAsync(Document document, InvocationExpressionSyntax invocationExpression, CancellationToken cancellationToken)
{
DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -200,6 +227,22 @@ private static bool TryReplaceActionExpression(ExpressionSyntax expression, [Not
return true;
}

if (expression is ObjectCreationExpressionSyntax objectCreationExpression &&
objectCreationExpression.ArgumentList is { Arguments.Count: 1 } objectCreationArgumentList &&
TryReplaceActionExpression(objectCreationArgumentList.Arguments[0].Expression, out ExpressionSyntax? objectCreationNewExpression))
{
newExpression = objectCreationNewExpression.WithTriviaFrom(expression);
Comment thread
Evangelink marked this conversation as resolved.
return true;
}

if (expression is ImplicitObjectCreationExpressionSyntax implicitObjectCreationExpression &&
implicitObjectCreationExpression.ArgumentList is { Arguments.Count: 1 } implicitObjectCreationArgumentList &&
TryReplaceActionExpression(implicitObjectCreationArgumentList.Arguments[0].Expression, out ExpressionSyntax? implicitObjectCreationNewExpression))
{
newExpression = implicitObjectCreationNewExpression.WithTriviaFrom(expression);
return true;
}

if (expression is not LambdaExpressionSyntax lambdaExpression ||
!TryGetBlockedTaskExpressionFromLambda(lambdaExpression, out ExpressionSyntax? asyncExpression))
{
Expand Down Expand Up @@ -437,4 +480,70 @@ private static StatementSyntax[] CreateAwaitAndReturnStatements(ReturnStatementS
return includeReturn ? [awaitStatement, newReturnStatement] : [awaitStatement];
}
}

private sealed class UnawaitableReturnDetector : CSharpSyntaxWalker
{
private int _unawaitableDepth;

public bool Found { get; private set; }

public override void Visit(SyntaxNode? node)
{
if (Found)
{
return;
}

base.Visit(node);
}

public override void VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax node)
{
}

public override void VisitParenthesizedLambdaExpression(ParenthesizedLambdaExpressionSyntax node)
{
}

public override void VisitAnonymousMethodExpression(AnonymousMethodExpressionSyntax node)
{
}

public override void VisitLocalFunctionStatement(LocalFunctionStatementSyntax node)
{
}

public override void VisitLockStatement(LockStatementSyntax node)
=> VisitInUnawaitableContext(node.Statement);

public override void VisitUnsafeStatement(UnsafeStatementSyntax node)
=> VisitInUnawaitableContext(node.Block);

public override void VisitFixedStatement(FixedStatementSyntax node)
=> VisitInUnawaitableContext(node.Statement);

public override void VisitReturnStatement(ReturnStatementSyntax node)
{
if (_unawaitableDepth > 0 && node.Expression is not null)
{
Found = true;
return;
}

base.VisitReturnStatement(node);
}

private void VisitInUnawaitableContext(SyntaxNode? node)
{
_unawaitableDepth++;
try
{
Visit(node);
}
finally
{
_unawaitableDepth--;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -820,4 +820,253 @@ public async Task MyTestMethod()

await VerifyCS.VerifyAnalyzerAsync(code);
}

[TestMethod]
public async Task WhenAssertionActionIsExplicitDelegateCreation_CodeFixUnwrapsDelegateCreation()
{
string code = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
[|Assert.ThrowsExactly<InvalidOperationException>(new Action(() => BarAsync().GetAwaiter().GetResult()))|];
}

private Task BarAsync() => Task.CompletedTask;
}
""";

string fixedCode = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public async Task MyTestMethod()
{
await Assert.ThrowsExactlyAsync<InvalidOperationException>(() => BarAsync());
}

private Task BarAsync() => Task.CompletedTask;
}
""";

await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[TestMethod]
public async Task WhenAssertionActionIsExplicitDelegateCreationWithAnonymousMethod_CodeFixUnwrapsDelegateCreation()
{
string code = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
[|Assert.ThrowsExactly<InvalidOperationException>(new Action(delegate { BarAsync().GetAwaiter().GetResult(); }))|];
}

private Task BarAsync() => Task.CompletedTask;
}
""";

string fixedCode = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public async Task MyTestMethod()
{
await Assert.ThrowsExactlyAsync<InvalidOperationException>(() => BarAsync());
}

private Task BarAsync() => Task.CompletedTask;
}
""";

await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[TestMethod]
public async Task WhenAssertionActionIsTargetTypedDelegateCreation_CodeFixUnwrapsDelegateCreation()
{
string code = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void MyTestMethod()
{
[|Assert.ThrowsExactly<InvalidOperationException>((Action)new(() => BarAsync().GetAwaiter().GetResult()))|];
}

private Task BarAsync() => Task.CompletedTask;
}
""";

string fixedCode = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public async Task MyTestMethod()
{
await Assert.ThrowsExactlyAsync<InvalidOperationException>(() => BarAsync());
}

private Task BarAsync() => Task.CompletedTask;
}
""";

await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

[TestMethod]
public async Task WhenNonAsyncTaskMethodHasReturnInsideLockBlock_DiagnosticReportedButNoCodeFixOffered()
{
string code = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
private readonly object _gate = new();

[TestMethod]
public Task MyTestMethod()
{
[|Assert.ThrowsExactly<InvalidOperationException>(() => BarAsync().GetAwaiter().GetResult())|];
lock (_gate)
{
return Task.CompletedTask;
}
}

private Task BarAsync() => Task.CompletedTask;
}
""";

// The diagnostic is still reported, but the fixer cannot safely transform the
// method because it would emit an 'await' inside the lock body. No code fix is
// offered, so the expected fixed code is identical to the original.
await VerifyCS.VerifyCodeFixAsync(code, code);
}

[TestMethod]
public async Task WhenNonAsyncValueTaskMethodHasReturnInsideUnsafeBlock_DiagnosticReportedButNoCodeFixOffered()
{
string code = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public ValueTask MyTestMethod()
{
[|Assert.ThrowsExactly<InvalidOperationException>(() => BarAsync().GetAwaiter().GetResult())|];
unsafe
{
return ValueTask.CompletedTask;
}
}

private Task BarAsync() => Task.CompletedTask;
}
""";

// The diagnostic is still reported, but the fixer cannot safely transform the
// method because it would emit an 'await' inside the unsafe block. No code fix is
// offered, so the expected fixed code is identical to the original.
var test = new VerifyCS.Test
{
TestCode = code,
FixedCode = code,
};

test.SolutionTransforms.Add((solution, projectId) =>
{
var compilationOptions = (CSharpCompilationOptions)solution.GetProject(projectId)!.CompilationOptions!;
return solution.WithProjectCompilationOptions(projectId, compilationOptions.WithAllowUnsafe(true));
});

await test.RunAsync(CancellationToken.None);
}

[TestMethod]
public async Task WhenNonAsyncTaskMethodHasReturnInsideFixedBlock_DiagnosticReportedButNoCodeFixOffered()
{
string code = """
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public Task MyTestMethod()
{
[|Assert.ThrowsExactly<InvalidOperationException>(() => BarAsync().GetAwaiter().GetResult())|];
int[] data = { 1, 2, 3 };
unsafe
{
fixed (int* p = data)
{
return Task.CompletedTask;
}
}
}

private Task BarAsync() => Task.CompletedTask;
}
""";

// The diagnostic is still reported, but the fixer cannot safely transform the
// method because it would emit an 'await' inside the fixed statement body. No code fix is
// offered, so the expected fixed code is identical to the original.
var test = new VerifyCS.Test
{
TestCode = code,
FixedCode = code,
};

test.SolutionTransforms.Add((solution, projectId) =>
{
var compilationOptions = (CSharpCompilationOptions)solution.GetProject(projectId)!.CompilationOptions!;
return solution.WithProjectCompilationOptions(projectId, compilationOptions.WithAllowUnsafe(true));
});

await test.RunAsync(CancellationToken.None);
}
}
Comment thread
Evangelink marked this conversation as resolved.