From 546c010ff44026e83d8526d76fb0bb65d0210299 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:36:48 +0100 Subject: [PATCH 1/2] Optimize argument formatting and caching mechanisms for improved performance --- TUnit.Core/Helpers/ArgumentFormatter.cs | 23 ++- TUnit.Core/TestContext.cs | 16 +- .../Discovery/ReflectionTestDataCollector.cs | 146 +++++++++++------- TUnit.Engine/Scheduling/TestScheduler.cs | 33 ++-- .../Services/TestFilterTypeExtractor.cs | 70 ++++++++- TUnit.Engine/Services/TestGroupingService.cs | 139 +++++++++++++---- 6 files changed, 313 insertions(+), 114 deletions(-) diff --git a/TUnit.Core/Helpers/ArgumentFormatter.cs b/TUnit.Core/Helpers/ArgumentFormatter.cs index eb3ef68a6c..bf674de695 100644 --- a/TUnit.Core/Helpers/ArgumentFormatter.cs +++ b/TUnit.Core/Helpers/ArgumentFormatter.cs @@ -25,7 +25,16 @@ public static string GetConstantValue(TestContext testContext, object? o) public static string FormatArguments(IEnumerable arguments) { - return string.Join(", ", arguments.Select(arg => FormatDefault(arg))); + var list = arguments as IList ?? arguments.ToList(); + if (list.Count == 0) + return string.Empty; + + var formatted = new string[list.Count]; + for (int i = 0; i < list.Count; i++) + { + formatted[i] = FormatDefault(list[i]); + } + return string.Join(", ", formatted); } private static string FormatDefault(object? o) @@ -70,15 +79,19 @@ private static string FormatDefault(object? o) private static string FormatTuple(object tuple) { var elements = TupleHelper.UnwrapTuple(tuple); - var formattedElements = elements.Select(e => FormatDefault(e)); - return $"({string.Join(", ", formattedElements)})"; + var formatted = new string[elements.Length]; + for (int i = 0; i < elements.Length; i++) + { + formatted[i] = FormatDefault(elements[i]); + } + return $"({string.Join(", ", formatted)})"; } private static string FormatEnumerable(IEnumerable enumerable) { - var elements = new List(); + const int maxElements = 10; + var elements = new List(maxElements + 1); var count = 0; - const int maxElements = 10; // Limit to prevent huge displays foreach (var element in enumerable) { diff --git a/TUnit.Core/TestContext.cs b/TUnit.Core/TestContext.cs index 416fdbbf23..bbba3b183a 100644 --- a/TUnit.Core/TestContext.cs +++ b/TUnit.Core/TestContext.cs @@ -18,6 +18,7 @@ namespace TUnit.Core; public class TestContext : Context { private readonly TestBuilderContext _testBuilderContext; + private string? _cachedDisplayName; public TestContext(string testName, IServiceProvider serviceProvider, ClassHookContext classContext, TestBuilderContext testBuilderContext, CancellationToken cancellationToken) : base(classContext) { @@ -196,15 +197,28 @@ public string GetDisplayName() return CustomDisplayName!; } + if (_cachedDisplayName != null) + { + return _cachedDisplayName; + } + + if (TestDetails.TestMethodArguments.Length == 0) + { + _cachedDisplayName = TestName; + return TestName; + } + var arguments = string.Join(", ", TestDetails.TestMethodArguments .Select(arg => ArgumentFormatter.Format(arg, ArgumentDisplayFormatters))); if (string.IsNullOrEmpty(arguments)) { + _cachedDisplayName = TestName; return TestName; } - return $"{TestName}({arguments})"; + _cachedDisplayName = $"{TestName}({arguments})"; + return _cachedDisplayName; } public Dictionary ObjectBag => _testBuilderContext.ObjectBag; diff --git a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs index e556bb71cd..902942ed42 100644 --- a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs +++ b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs @@ -22,10 +22,60 @@ internal sealed class ReflectionTestDataCollector : ITestDataCollector private static readonly Lock _resultsLock = new(); // Only for final results aggregation private static readonly ConcurrentDictionary _assemblyTypesCache = new(); private static readonly ConcurrentDictionary _typeMethodsCache = new(); + + private static Assembly[]? _cachedAssemblies; + private static readonly Lock _assemblyCacheLock = new(); + + private static Assembly[] GetCachedAssemblies() + { + lock (_assemblyCacheLock) + { + return _cachedAssemblies ??= AppDomain.CurrentDomain.GetAssemblies(); + } + } + + public static void ClearCaches() + { + _scannedAssemblies.Clear(); + while (_discoveredTests.TryTake(out _)) { } + _assemblyTypesCache.Clear(); + _typeMethodsCache.Clear(); + lock (_assemblyCacheLock) + { + _cachedAssemblies = null; + } + } + + private async Task> ProcessAssemblyAsync(Assembly assembly, SemaphoreSlim semaphore) + { + await semaphore.WaitAsync().ConfigureAwait(false); + try + { + if (!_scannedAssemblies.TryAdd(assembly, true)) + { + return []; + } + + try + { + return await DiscoverTestsInAssembly(assembly).ConfigureAwait(false); + } + catch (Exception ex) + { + // Create a failed test metadata for the assembly that couldn't be scanned + var failedTest = CreateFailedTestMetadataForAssembly(assembly, ex); + return [failedTest]; + } + } + finally + { + semaphore.Release(); + } + } public async Task> CollectTestsAsync(string testSessionId) { - var allAssemblies = AppDomain.CurrentDomain.GetAssemblies(); + var allAssemblies = GetCachedAssemblies(); var assembliesList = new List(allAssemblies.Length); foreach (var assembly in allAssemblies) { @@ -46,44 +96,14 @@ public async Task> CollectTestsAsync(string testSessio var assembly = assemblies[i]; var index = i; - tasks[index] = Task.Run(async () => - { - await semaphore.WaitAsync().ConfigureAwait(false); - try - { - if (!_scannedAssemblies.TryAdd(assembly, true)) - { - return - [ - ]; - } - - try - { - return await DiscoverTestsInAssembly(assembly).ConfigureAwait(false); - } - catch (Exception ex) - { - // Create a failed test metadata for the assembly that couldn't be scanned - var failedTest = CreateFailedTestMetadataForAssembly(assembly, ex); - return - [ - failedTest - ]; - } - } - finally - { - semaphore.Release(); - } - }); + tasks[index] = ProcessAssemblyAsync(assembly, semaphore); } // Wait for all tasks to complete var results = await Task.WhenAll(tasks).ConfigureAwait(false); - // Reassemble results in original order - var newTests = new List(); + var totalCount = results.Sum(r => r.Count); + var newTests = new List(totalCount); foreach (var tests in results) { newTests.AddRange(tests); @@ -111,7 +131,7 @@ public async IAsyncEnumerable CollectTestsStreamingAsync( [EnumeratorCancellation] CancellationToken cancellationToken = default) { // Get assemblies to scan - var allAssemblies = AppDomain.CurrentDomain.GetAssemblies(); + var allAssemblies = GetCachedAssemblies(); var assemblies = new List(allAssemblies.Length); foreach (var assembly in allAssemblies) { @@ -154,7 +174,7 @@ private static IEnumerable GetAllTestMethods(Type type) { return _typeMethodsCache.GetOrAdd(type, static t => { - var methods = new List(); + var methods = new List(20); var currentType = t; while (currentType != null && currentType != typeof(object)) @@ -276,8 +296,18 @@ private static bool ShouldScanAssembly(Assembly assembly) // Don't return false here, continue with other checks } - if (!assembly.GetReferencedAssemblies().Any(a => - a.Name != null && (a.Name.StartsWith("TUnit") || a.Name == "TUnit"))) + var referencedAssemblies = assembly.GetReferencedAssemblies(); + var hasTUnitReference = false; + foreach (var reference in referencedAssemblies) + { + if (reference.Name != null && (reference.Name.StartsWith("TUnit") || reference.Name == "TUnit")) + { + hasTUnitReference = true; + break; + } + } + + if (!hasTUnitReference) { return false; } @@ -293,7 +323,7 @@ private static bool ShouldScanAssembly(Assembly assembly) Justification = "Reflection mode requires dynamic access")] private static async Task> DiscoverTestsInAssembly(Assembly assembly) { - var discoveredTests = new List(); + var discoveredTests = new List(100); var types = _assemblyTypesCache.GetOrAdd(assembly, asm => { @@ -541,7 +571,7 @@ private static async IAsyncEnumerable DiscoverTestsInAssemblyStrea Justification = "Reflection mode requires dynamic access")] private static async Task> DiscoverGenericTests(Type genericTypeDefinition) { - var discoveredTests = new List(); + var discoveredTests = new List(100); // Extract class-level data sources that will determine the generic type arguments var classDataSources = ReflectionAttributeExtractor.ExtractDataSources(genericTypeDefinition); @@ -765,7 +795,7 @@ private static async IAsyncEnumerable DiscoverGenericTestsStreamin private static async Task> GetDataFromSourceAsync(IDataSourceAttribute dataSource, MethodMetadata methodMetadata) { - var data = new List(); + var data = new List(16); try { @@ -1469,7 +1499,7 @@ private static bool IsCovariantCompatible(Type paramType, Type argType) private async Task> DiscoverDynamicTests(string testSessionId) { - var dynamicTests = new List(); + var dynamicTests = new List(50); // First check if there are any registered dynamic test sources from source generation if (Sources.DynamicTestSources.Count > 0) @@ -1494,9 +1524,7 @@ private async Task> DiscoverDynamicTests(string testSessionId } } - // Also discover dynamic test builder methods via reflection - // Optimize: Pre-filter and allocate array instead of LINQ ToList() - var allAssemblies = AppDomain.CurrentDomain.GetAssemblies(); + var allAssemblies = GetCachedAssemblies(); var assembliesList = new List(allAssemblies.Length); foreach (var assembly in allAssemblies) { @@ -1521,11 +1549,12 @@ private async Task> DiscoverDynamicTests(string testSessionId } }); - foreach (var type in types.Where(t => t.IsClass && !IsCompilerGenerated(t))) + foreach (var type in types) { - // Optimize: Manual filtering instead of LINQ Where().ToArray() + if (!type.IsClass || IsCompilerGenerated(type)) + continue; var declaredMethods = type.GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.DeclaredOnly); - var methodsList = new List(declaredMethods.Length); + var methodsList = new List(4); foreach (var method in declaredMethods) { #pragma warning disable TUnitWIP0001 @@ -1561,9 +1590,15 @@ private async IAsyncEnumerable DiscoverDynamicTestsStreamingAsync( string testSessionId, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - var assemblies = AppDomain.CurrentDomain.GetAssemblies() - .Where(ShouldScanAssembly) - .ToList(); + var allAssemblies = GetCachedAssemblies(); + var assemblies = new List(allAssemblies.Length); + foreach (var assembly in allAssemblies) + { + if (ShouldScanAssembly(assembly)) + { + assemblies.Add(assembly); + } + } foreach (var assembly in assemblies) { @@ -1581,11 +1616,12 @@ private async IAsyncEnumerable DiscoverDynamicTestsStreamingAsync( } }); - foreach (var type in types.Where(t => t.IsClass && !IsCompilerGenerated(t))) + foreach (var type in types) { - // Optimize: Manual filtering instead of LINQ Where().ToArray() + if (!type.IsClass || IsCompilerGenerated(type)) + continue; var declaredMethods = type.GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.DeclaredOnly); - var methodsList = new List(declaredMethods.Length); + var methodsList = new List(4); foreach (var method in declaredMethods) { #pragma warning disable TUnitWIP0001 @@ -1613,7 +1649,7 @@ private async IAsyncEnumerable DiscoverDynamicTestsStreamingAsync( private async Task> ExecuteDynamicTestBuilder(Type testClass, MethodInfo builderMethod, string testSessionId) { - var dynamicTests = new List(); + var dynamicTests = new List(50); // Extract file path and line number from the DynamicTestBuilderAttribute if possible var filePath = ExtractFilePath(builderMethod) ?? "Unknown"; diff --git a/TUnit.Engine/Scheduling/TestScheduler.cs b/TUnit.Engine/Scheduling/TestScheduler.cs index 8d05ddbaa3..4080ce9665 100644 --- a/TUnit.Engine/Scheduling/TestScheduler.cs +++ b/TUnit.Engine/Scheduling/TestScheduler.cs @@ -308,6 +308,23 @@ private async Task ExecuteSequentiallyAsync( } } + private async Task ProcessTestQueueAsync( + System.Collections.Concurrent.ConcurrentQueue testQueue, + CancellationToken cancellationToken) + { + while (testQueue.TryDequeue(out var test)) + { + if (cancellationToken.IsCancellationRequested) + { + break; + } + + var task = ExecuteTestWithParallelLimitAsync(test, cancellationToken); + test.ExecutionTask = task; + await task.ConfigureAwait(false); + } + } + private async Task ExecuteParallelTestsWithLimitAsync( AbstractExecutableTest[] tests, int maxParallelism, @@ -317,23 +334,9 @@ private async Task ExecuteParallelTestsWithLimitAsync( var testQueue = new System.Collections.Concurrent.ConcurrentQueue(tests); var workers = new Task[maxParallelism]; - // Create worker tasks that will process tests from the queue for (var i = 0; i < maxParallelism; i++) { - workers[i] = Task.Run(async () => - { - while (testQueue.TryDequeue(out var test)) - { - if (cancellationToken.IsCancellationRequested) - { - break; - } - - var task = ExecuteTestWithParallelLimitAsync(test, cancellationToken); - test.ExecutionTask = task; - await task.ConfigureAwait(false); - } - }, cancellationToken); + workers[i] = ProcessTestQueueAsync(testQueue, cancellationToken); } await WaitForTasksWithFailFastHandling(workers, cancellationToken).ConfigureAwait(false); diff --git a/TUnit.Engine/Services/TestFilterTypeExtractor.cs b/TUnit.Engine/Services/TestFilterTypeExtractor.cs index 4069530c8f..3eeb881424 100644 --- a/TUnit.Engine/Services/TestFilterTypeExtractor.cs +++ b/TUnit.Engine/Services/TestFilterTypeExtractor.cs @@ -10,6 +10,55 @@ internal static class TestFilterTypeExtractor { private static readonly Regex PathFilterRegex = new(@"^/([^/]+)/([^/]+)/([^/]+)(?:/|$)", RegexOptions.Compiled); + private static readonly Lazy>> AssemblyCache = + new(() => BuildAssemblyCache()); + + private static readonly Lazy> TypeCache = + new(() => BuildTypeCache()); + + private static Dictionary> BuildAssemblyCache() + { + var cache = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) + { + var name = assembly.GetName().Name; + if (name != null) + { + if (!cache.TryGetValue(name, out var list)) + { + list = new List(); + cache[name] = list; + } + list.Add(assembly); + } + } + return cache; + } + + private static Dictionary BuildTypeCache() + { + var cache = new Dictionary(StringComparer.Ordinal); + foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) + { + try + { +#pragma warning disable IL2026 + foreach (var type in assembly.GetExportedTypes()) +#pragma warning restore IL2026 + { + if (type.FullName != null) + { + cache[type.FullName] = type; + } + } + } + catch + { + } + } + return cache; + } + public static HashSet? ExtractTypesFromFilter(ITestExecutionFilter? filter) { if (filter == null) @@ -57,17 +106,24 @@ internal static class TestFilterTypeExtractor var fullTypeName = $"{namespaceName}.{className}"; - // Try to find the type in loaded assemblies - foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) + if (TypeCache.Value.TryGetValue(fullTypeName, out var cachedType)) { - if (assembly.GetName().Name == assemblyName || assemblyName == "*") + types.Add(cachedType); + } + else if (assemblyName != "*") + { + if (AssemblyCache.Value.TryGetValue(assemblyName, out var assemblies)) { + foreach (var assembly in assemblies) + { #pragma warning disable IL2026 // Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access - var type = assembly.GetType(fullTypeName, throwOnError: false); + var type = assembly.GetType(fullTypeName, throwOnError: false); #pragma warning restore IL2026 - if (type != null) - { - types.Add(type); + if (type != null) + { + types.Add(type); + break; + } } } } diff --git a/TUnit.Engine/Services/TestGroupingService.cs b/TUnit.Engine/Services/TestGroupingService.cs index fc0fb062ed..4ac3655d5b 100644 --- a/TUnit.Engine/Services/TestGroupingService.cs +++ b/TUnit.Engine/Services/TestGroupingService.cs @@ -14,12 +14,49 @@ internal interface ITestGroupingService internal sealed class TestGroupingService : ITestGroupingService { + private struct TestSortKey + { + public int ExecutionPriority { get; init; } + public string? ClassFullName { get; init; } + public int NotInParallelOrder { get; init; } + public NotInParallelConstraint? NotInParallelConstraint { get; init; } + } + public ValueTask GroupTestsByConstraintsAsync(IEnumerable tests) { - var orderedTests = tests - .OrderByDescending(t => t.Context.ExecutionPriority) - .ThenBy(x => x.Context.ClassContext?.ClassType?.FullName ?? string.Empty) - .ThenBy(t => t.Context.ParallelConstraints.OfType().FirstOrDefault()?.Order ?? int.MaxValue); + var testsWithKeys = new List<(AbstractExecutableTest Test, TestSortKey Key)>(); + foreach (var test in tests) + { + NotInParallelConstraint? notInParallelConstraint = null; + foreach (var constraint in test.Context.ParallelConstraints) + { + if (constraint is NotInParallelConstraint nip) + { + notInParallelConstraint = nip; + break; + } + } + + var key = new TestSortKey + { + ExecutionPriority = (int)test.Context.ExecutionPriority, + ClassFullName = test.Context.ClassContext?.ClassType?.FullName, + NotInParallelOrder = notInParallelConstraint?.Order ?? int.MaxValue, + NotInParallelConstraint = notInParallelConstraint + }; + testsWithKeys.Add((test, key)); + } + + testsWithKeys.Sort((a, b) => + { + var priorityCompare = b.Key.ExecutionPriority.CompareTo(a.Key.ExecutionPriority); + if (priorityCompare != 0) return priorityCompare; + + var classCompare = string.CompareOrdinal(a.Key.ClassFullName ?? string.Empty, b.Key.ClassFullName ?? string.Empty); + if (classCompare != 0) return classCompare; + + return a.Key.NotInParallelOrder.CompareTo(b.Key.NotInParallelOrder); + }); var notInParallelList = new List<(AbstractExecutableTest Test, TestPriority Priority)>(); var keyedNotInParallelList = new List<(AbstractExecutableTest Test, IReadOnlyList ConstraintKeys, TestPriority Priority)>(); @@ -27,14 +64,19 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable>>(); var constrainedParallelGroups = new Dictionary Unconstrained, List<(AbstractExecutableTest, IReadOnlyList, TestPriority)> Keyed)>(); - // Process each class group sequentially to maintain class ordering for NotInParallel tests - foreach (var test in orderedTests) + foreach (var (test, sortKey) in testsWithKeys) { var constraints = test.Context.ParallelConstraints; - - // Handle tests with multiple constraints - var parallelGroup = constraints.OfType().FirstOrDefault(); - var notInParallel = constraints.OfType().FirstOrDefault(); + ParallelGroupConstraint? parallelGroup = null; + foreach (var constraint in constraints) + { + if (constraint is ParallelGroupConstraint pg) + { + parallelGroup = pg; + break; + } + } + var notInParallel = sortKey.NotInParallelConstraint; if (parallelGroup != null && notInParallel != null) { @@ -58,22 +100,44 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable t.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty) - .ThenByDescending(t => t.Priority.Priority) - .ThenBy(t => t.Priority.Order) - .Select(t => t.Test) - .ToArray(); + notInParallelList.Sort((a, b) => + { + var classA = a.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; + var classB = b.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; + var classCompare = string.CompareOrdinal(classA, classB); + if (classCompare != 0) return classCompare; + + var priorityCompare = b.Priority.Priority.CompareTo(a.Priority.Priority); + if (priorityCompare != 0) return priorityCompare; + + return a.Priority.Order.CompareTo(b.Priority.Order); + }); + + var sortedNotInParallel = new AbstractExecutableTest[notInParallelList.Count]; + for (int i = 0; i < notInParallelList.Count; i++) + { + sortedNotInParallel[i] = notInParallelList[i].Test; + } - // Sort keyed tests similarly - class grouping first, then priority - var keyedArrays = keyedNotInParallelList - .OrderBy(t => t.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty) - .ThenByDescending(t => t.Priority.Priority) - .ThenBy(t => t.Priority.Order) - .Select(t => (t.Test, t.ConstraintKeys, t.Priority.GetHashCode())) - .ToArray(); + keyedNotInParallelList.Sort((a, b) => + { + var classA = a.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; + var classB = b.Test.Context.ClassContext?.ClassType?.FullName ?? string.Empty; + var classCompare = string.CompareOrdinal(classA, classB); + if (classCompare != 0) return classCompare; + + var priorityCompare = b.Priority.Priority.CompareTo(a.Priority.Priority); + if (priorityCompare != 0) return priorityCompare; + + return a.Priority.Order.CompareTo(b.Priority.Order); + }); + + var keyedArrays = new (AbstractExecutableTest, IReadOnlyList, int)[keyedNotInParallelList.Count]; + for (int i = 0; i < keyedNotInParallelList.Count; i++) + { + var item = keyedNotInParallelList[i]; + keyedArrays[i] = (item.Test, item.ConstraintKeys, item.Priority.GetHashCode()); + } // Convert constrained parallel groups to the final format var finalConstrainedGroups = new Dictionary(); @@ -83,12 +147,25 @@ public ValueTask GroupTestsByConstraintsAsync(IEnumerable t.Item1.Context.ClassContext?.ClassType?.FullName ?? string.Empty) - .ThenByDescending(t => t.Item3.Priority) - .ThenBy(t => t.Item3.Order) - .Select(t => (t.Item1, t.Item2, t.Item3.GetHashCode())) - .ToArray(); + keyed.Sort((a, b) => + { + var classA = a.Item1.Context.ClassContext?.ClassType?.FullName ?? string.Empty; + var classB = b.Item1.Context.ClassContext?.ClassType?.FullName ?? string.Empty; + var classCompare = string.CompareOrdinal(classA, classB); + if (classCompare != 0) return classCompare; + + var priorityCompare = b.Item3.Priority.CompareTo(a.Item3.Priority); + if (priorityCompare != 0) return priorityCompare; + + return a.Item3.Order.CompareTo(b.Item3.Order); + }); + + var sortedKeyed = new (AbstractExecutableTest, IReadOnlyList, int)[keyed.Count]; + for (int i = 0; i < keyed.Count; i++) + { + var item = keyed[i]; + sortedKeyed[i] = (item.Item1, item.Item2, item.Item3.GetHashCode()); + } finalConstrainedGroups[groupName] = new GroupedConstrainedTests { From eecc4964a9c87ed69f1c09dbec8598bc8ff4ecf9 Mon Sep 17 00:00:00 2001 From: Tom Longhurst <30480171+thomhurst@users.noreply.github.com> Date: Sat, 20 Sep 2025 11:50:49 +0100 Subject: [PATCH 2/2] Refactor test data collection to remove filter types and simplify data collector initialization --- .../Collectors/AotTestDataCollector.cs | 7 - TUnit.Engine/Building/TestBuilderPipeline.cs | 16 +-- .../Building/TestDataCollectorFactory.cs | 4 +- .../Discovery/ReflectionTestDataCollector.cs | 2 +- .../Framework/TUnitServiceProvider.cs | 16 +-- .../Services/TestFilterTypeExtractor.cs | 135 ------------------ TUnit.Engine/TestDiscoveryService.cs | 13 +- 7 files changed, 16 insertions(+), 177 deletions(-) delete mode 100644 TUnit.Engine/Services/TestFilterTypeExtractor.cs diff --git a/TUnit.Engine/Building/Collectors/AotTestDataCollector.cs b/TUnit.Engine/Building/Collectors/AotTestDataCollector.cs index 901403a7fd..81c663447b 100644 --- a/TUnit.Engine/Building/Collectors/AotTestDataCollector.cs +++ b/TUnit.Engine/Building/Collectors/AotTestDataCollector.cs @@ -15,17 +15,10 @@ namespace TUnit.Engine.Building.Collectors; /// internal sealed class AotTestDataCollector : ITestDataCollector { - private readonly HashSet? _filterTypes; - - public AotTestDataCollector(HashSet? filterTypes) - { - _filterTypes = filterTypes; - } public async Task> CollectTestsAsync(string testSessionId) { // Stream from all test sources var testSources = Sources.TestSources - .Where(kvp => _filterTypes == null || _filterTypes.Contains(kvp.Key)) .SelectMany(kvp => kvp.Value); var standardTestMetadatas = await testSources diff --git a/TUnit.Engine/Building/TestBuilderPipeline.cs b/TUnit.Engine/Building/TestBuilderPipeline.cs index 5c20474511..39d0307a2c 100644 --- a/TUnit.Engine/Building/TestBuilderPipeline.cs +++ b/TUnit.Engine/Building/TestBuilderPipeline.cs @@ -10,18 +10,18 @@ namespace TUnit.Engine.Building; internal sealed class TestBuilderPipeline { - private readonly Func?, ITestDataCollector> _dataCollectorFactory; + private readonly ITestDataCollector _dataCollector; private readonly ITestBuilder _testBuilder; private readonly IContextProvider _contextProvider; private readonly EventReceiverOrchestrator _eventReceiverOrchestrator; public TestBuilderPipeline( - Func?, ITestDataCollector> dataCollectorFactory, + ITestDataCollector dataCollector, ITestBuilder testBuilder, IContextProvider contextBuilder, EventReceiverOrchestrator eventReceiverOrchestrator) { - _dataCollectorFactory = dataCollectorFactory ?? throw new ArgumentNullException(nameof(dataCollectorFactory)); + _dataCollector = dataCollector ?? throw new ArgumentNullException(nameof(dataCollector)); _testBuilder = testBuilder ?? throw new ArgumentNullException(nameof(testBuilder)); _contextProvider = contextBuilder; _eventReceiverOrchestrator = eventReceiverOrchestrator ?? throw new ArgumentNullException(nameof(eventReceiverOrchestrator)); @@ -54,10 +54,9 @@ private TestBuilderContext CreateTestBuilderContext(TestMetadata metadata) return testBuilderContext; } - public async Task> BuildTestsAsync(string testSessionId, HashSet? filterTypes) + public async Task> BuildTestsAsync(string testSessionId) { - var dataCollector = _dataCollectorFactory(filterTypes); - var collectedMetadata = await dataCollector.CollectTestsAsync(testSessionId).ConfigureAwait(false); + var collectedMetadata = await _dataCollector.CollectTestsAsync(testSessionId).ConfigureAwait(false); return await BuildTestsFromMetadataAsync(collectedMetadata).ConfigureAwait(false); } @@ -67,14 +66,11 @@ public async Task> BuildTestsAsync(string te /// public async Task> BuildTestsStreamingAsync( string testSessionId, - HashSet? filterTypes, CancellationToken cancellationToken = default) { - var dataCollector = _dataCollectorFactory(filterTypes); - // Get metadata streaming if supported // Fall back to non-streaming collection - var collectedMetadata = await dataCollector.CollectTestsAsync(testSessionId).ConfigureAwait(false); + var collectedMetadata = await _dataCollector.CollectTestsAsync(testSessionId).ConfigureAwait(false); return await collectedMetadata .SelectManyAsync(BuildTestsFromSingleMetadataAsync, cancellationToken: cancellationToken) diff --git a/TUnit.Engine/Building/TestDataCollectorFactory.cs b/TUnit.Engine/Building/TestDataCollectorFactory.cs index af89e8b761..d4d69e343e 100644 --- a/TUnit.Engine/Building/TestDataCollectorFactory.cs +++ b/TUnit.Engine/Building/TestDataCollectorFactory.cs @@ -17,7 +17,7 @@ public static ITestDataCollector Create(bool? useSourceGeneration = null, Assemb if (isSourceGenerationEnabled) { - return new AotTestDataCollector(filterTypes: null); + return new AotTestDataCollector(); } else { @@ -31,7 +31,7 @@ public static ITestDataCollector Create(bool? useSourceGeneration = null, Assemb public static async Task CreateAutoDetectAsync(string testSessionId, Assembly[]? assembliesToScan = null) { // Try AOT mode first (check if any tests were registered) - var aotCollector = new AotTestDataCollector(filterTypes: null); + var aotCollector = new AotTestDataCollector(); var aotTests = await aotCollector.CollectTestsAsync(testSessionId); if (aotTests.Any()) diff --git a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs index 902942ed42..ac6af632c6 100644 --- a/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs +++ b/TUnit.Engine/Discovery/ReflectionTestDataCollector.cs @@ -568,7 +568,7 @@ private static async IAsyncEnumerable DiscoverTestsInAssemblyStrea Justification = "Reflection mode requires dynamic access")] [UnconditionalSuppressMessage("Trimming", "IL2067:'type' argument does not satisfy 'DynamicallyAccessedMemberTypes.PublicParameterlessConstructor' in call to 'System.Activator.CreateInstance(Type)'", - Justification = "Reflection mode requires dynamic access")] + Justification = "Reflection mode requires dynamic access")] private static async Task> DiscoverGenericTests(Type genericTypeDefinition) { var discoveredTests = new List(100); diff --git a/TUnit.Engine/Framework/TUnitServiceProvider.cs b/TUnit.Engine/Framework/TUnitServiceProvider.cs index 8e5404ff8a..9313bba6b9 100644 --- a/TUnit.Engine/Framework/TUnitServiceProvider.cs +++ b/TUnit.Engine/Framework/TUnitServiceProvider.cs @@ -125,17 +125,9 @@ public TUnitServiceProvider(IExtension extension, var useSourceGeneration = GetUseSourceGeneration(CommandLineOptions); #pragma warning disable IL2026 // Using member which has 'RequiresUnreferencedCodeAttribute' #pragma warning disable IL3050 // Using member which has 'RequiresDynamicCodeAttribute' - Func?, ITestDataCollector> dataCollectorFactory = filterTypes => - { - if (useSourceGeneration) - { - return new AotTestDataCollector(filterTypes); - } - else - { - return new ReflectionTestDataCollector(); - } - }; + ITestDataCollector dataCollector = useSourceGeneration + ? new AotTestDataCollector() + : new ReflectionTestDataCollector(); #pragma warning restore IL3050 #pragma warning restore IL2026 @@ -144,7 +136,7 @@ public TUnitServiceProvider(IExtension extension, TestBuilderPipeline = Register( new TestBuilderPipeline( - dataCollectorFactory, + dataCollector, testBuilder, ContextProvider, EventReceiverOrchestrator)); diff --git a/TUnit.Engine/Services/TestFilterTypeExtractor.cs b/TUnit.Engine/Services/TestFilterTypeExtractor.cs deleted file mode 100644 index 3eeb881424..0000000000 --- a/TUnit.Engine/Services/TestFilterTypeExtractor.cs +++ /dev/null @@ -1,135 +0,0 @@ -using System.Text.RegularExpressions; -using Microsoft.Testing.Platform.Requests; - -namespace TUnit.Engine.Services; - -/// -/// Extracts test class types from test filters to enable selective test discovery -/// -internal static class TestFilterTypeExtractor -{ - private static readonly Regex PathFilterRegex = new(@"^/([^/]+)/([^/]+)/([^/]+)(?:/|$)", RegexOptions.Compiled); - - private static readonly Lazy>> AssemblyCache = - new(() => BuildAssemblyCache()); - - private static readonly Lazy> TypeCache = - new(() => BuildTypeCache()); - - private static Dictionary> BuildAssemblyCache() - { - var cache = new Dictionary>(StringComparer.OrdinalIgnoreCase); - foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) - { - var name = assembly.GetName().Name; - if (name != null) - { - if (!cache.TryGetValue(name, out var list)) - { - list = new List(); - cache[name] = list; - } - list.Add(assembly); - } - } - return cache; - } - - private static Dictionary BuildTypeCache() - { - var cache = new Dictionary(StringComparer.Ordinal); - foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) - { - try - { -#pragma warning disable IL2026 - foreach (var type in assembly.GetExportedTypes()) -#pragma warning restore IL2026 - { - if (type.FullName != null) - { - cache[type.FullName] = type; - } - } - } - catch - { - } - } - return cache; - } - - public static HashSet? ExtractTypesFromFilter(ITestExecutionFilter? filter) - { - if (filter == null) - { - return null; - } - - return filter switch - { -#pragma warning disable TPEXP - NopFilter => null, - TestNodeUidListFilter => null, // UIDs don't contain type info in a parseable way - TreeNodeFilter treeNodeFilter => ExtractTypesFromTreeFilter(treeNodeFilter.Filter), -#pragma warning restore TPEXP - _ => null - }; - } - - private static HashSet? ExtractTypesFromTreeFilter(string? filterExpression) - { - if (string.IsNullOrWhiteSpace(filterExpression)) - { - return null; - } - - var types = new HashSet(); - - // TreeNodeFilter uses path-based filtering like: /AssemblyName/Namespace/ClassName/MethodName - // Extract class names from the filter - var matches = PathFilterRegex.Matches(filterExpression); - - foreach (Match match in matches) - { - if (match.Success && match.Groups.Count >= 4) - { - var assemblyName = match.Groups[1].Value; - var namespaceName = match.Groups[2].Value; - var className = match.Groups[3].Value; - - // Skip wildcards - if (assemblyName == "*" || namespaceName == "*" || className == "*") - { - continue; - } - - var fullTypeName = $"{namespaceName}.{className}"; - - if (TypeCache.Value.TryGetValue(fullTypeName, out var cachedType)) - { - types.Add(cachedType); - } - else if (assemblyName != "*") - { - if (AssemblyCache.Value.TryGetValue(assemblyName, out var assemblies)) - { - foreach (var assembly in assemblies) - { -#pragma warning disable IL2026 // Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access - var type = assembly.GetType(fullTypeName, throwOnError: false); -#pragma warning restore IL2026 - if (type != null) - { - types.Add(type); - break; - } - } - } - } - } - } - - return types.Count > 0 ? types : null; - } -} \ No newline at end of file diff --git a/TUnit.Engine/TestDiscoveryService.cs b/TUnit.Engine/TestDiscoveryService.cs index 05e0cfb199..bb707e77e8 100644 --- a/TUnit.Engine/TestDiscoveryService.cs +++ b/TUnit.Engine/TestDiscoveryService.cs @@ -56,15 +56,12 @@ public async Task DiscoverTests(string testSessionId, ITest contextProvider.BeforeTestDiscoveryContext.RestoreExecutionContext(); - // Extract types from filter for optimized discovery - var filterTypes = TestFilterTypeExtractor.ExtractTypesFromFilter(filter); - // Stage 1: Stream independent tests immediately while buffering dependent tests var independentTests = new List(); var dependentTests = new List(); var allTests = new List(); - await foreach (var test in DiscoverTestsStreamAsync(testSessionId, filterTypes, cancellationToken).ConfigureAwait(false)) + await foreach (var test in DiscoverTestsStreamAsync(testSessionId, cancellationToken).ConfigureAwait(false)) { allTests.Add(test); @@ -135,7 +132,6 @@ public async Task DiscoverTests(string testSessionId, ITest /// Streams test discovery for parallel discovery and execution private async IAsyncEnumerable DiscoverTestsStreamAsync( string testSessionId, - HashSet? filterTypes, [EnumeratorCancellation] CancellationToken cancellationToken = default) { using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); @@ -147,7 +143,7 @@ private async IAsyncEnumerable DiscoverTestsStreamAsync( cts.CancelAfter(DiscoveryConfiguration.DiscoveryTimeout); } - var tests = await _testBuilderPipeline.BuildTestsStreamingAsync(testSessionId, filterTypes, cancellationToken).ConfigureAwait(false); + var tests = await _testBuilderPipeline.BuildTestsStreamingAsync(testSessionId, cancellationToken).ConfigureAwait(false); foreach (var test in tests) { @@ -170,12 +166,9 @@ public async IAsyncEnumerable DiscoverTestsFullyStreamin { await _testExecutor.ExecuteBeforeTestDiscoveryHooksAsync(cancellationToken).ConfigureAwait(false); - // Extract types from filter for optimized discovery - var filterTypes = TestFilterTypeExtractor.ExtractTypesFromFilter(filter); - // Collect all tests first (like source generation mode does) var allTests = new List(); - await foreach (var test in DiscoverTestsStreamAsync(testSessionId, filterTypes, cancellationToken).ConfigureAwait(false)) + await foreach (var test in DiscoverTestsStreamAsync(testSessionId, cancellationToken).ConfigureAwait(false)) { allTests.Add(test); }