diff --git a/src/coreclr/jit/async.cpp b/src/coreclr/jit/async.cpp
index 9f79a026b6c5ab..d803ee48a31dec 100644
--- a/src/coreclr/jit/async.cpp
+++ b/src/coreclr/jit/async.cpp
@@ -1035,43 +1035,39 @@ PhaseStatus Compiler::TransformAsync()
 PhaseStatus AsyncTransformation::Run()
 {
     PhaseStatus             result = PhaseStatus::MODIFIED_NOTHING;
-    ArrayStack<BasicBlock*> worklist(m_compiler->getAllocator(CMK_Async));
+    ArrayStack<BasicBlock*> blocksWithNormalAwaits(m_compiler->getAllocator(CMK_Async));
+    ArrayStack<BasicBlock*> blocksWithTailAwaits(m_compiler->getAllocator(CMK_Async));
+    int                     numNormalAwaits = 0;
+    int                     numTailAwaits   = 0;
+    FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits, &numNormalAwaits, &numTailAwaits);
 
-    // First find all basic blocks with awaits in them. We'll have to track
-    // liveness in these basic blocks, so it does not help to record the calls
-    // ahead of time.
-    BasicBlock* nextBlock;
-    for (BasicBlock* block = m_compiler->fgFirstBB; block != nullptr; block = nextBlock)
+    if (numNormalAwaits + numTailAwaits > 1)
     {
-        bool hasAwait = false;
-        nextBlock     = block->Next();
-        for (GenTree* tree : LIR::AsRange(block))
-        {
-            if (!tree->IsCall() || !tree->AsCall()->IsAsync() || tree->AsCall()->IsTailCall())
-            {
-                continue;
-            }
-
-            if (tree->AsCall()->GetAsyncInfo().IsTailAwait)
-            {
-                TransformTailAwait(block, tree->AsCall(), &nextBlock);
-                result = PhaseStatus::MODIFIED_EVERYTHING;
-                break;
-            }
-
-            JITDUMP(FMT_BB " contains await(s)\n", block->bbNum);
-            hasAwait = true;
-        }
+        CreateSharedReturnBB();
+    }
 
-        if (hasAwait)
+    // Transform all tail awaits first. They will not require running all of
+    // our analyses.
+    if (numTailAwaits > 0)
+    {
+        JITDUMP("Found %d tail awaits in %d blocks\n", numTailAwaits, blocksWithTailAwaits.Height());
+        TransformTailAwaits(blocksWithTailAwaits);
+        if (numNormalAwaits > 0)
         {
-            worklist.Push(block);
+            blocksWithNormalAwaits.Reset();
+            blocksWithTailAwaits.Reset();
+            numNormalAwaits = 0;
+            numTailAwaits   = 0;
+            FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits, &numNormalAwaits, &numTailAwaits);
+            assert((numTailAwaits == 0) && (blocksWithTailAwaits.Height() == 0));
         }
+
+        result = PhaseStatus::MODIFIED_EVERYTHING;
     }
 
-    JITDUMP("Found %d blocks with awaits\n", worklist.Height());
+    JITDUMP("Found %d awaits in %d blocks\n", numNormalAwaits, blocksWithNormalAwaits.Height());
 
-    if (worklist.Height() <= 0)
+    if (numNormalAwaits <= 0)
     {
         return result;
     }
@@ -1083,9 +1079,6 @@ PhaseStatus AsyncTransformation::Run()
 
     m_asyncInfo = m_compiler->eeGetAsyncInfo();
 
-    // Create the shared return BB now to put it in the right place in the block order.
-    GetSharedReturnBB();
-
     // Compute liveness to be used for determining what must be captured on
     // suspension.
     if (m_compiler->m_dfsTree == nullptr)
@@ -1107,11 +1100,11 @@ PhaseStatus AsyncTransformation::Run()
     // async calls are additional live variables that must be spilled.
     jitstd::vector defs(m_compiler->getAllocator(CMK_Async));
 
-    for (int i = 0; i < worklist.Height(); i++)
+    for (int i = 0; i < blocksWithNormalAwaits.Height(); i++)
     {
         assert(defs.size() == 0);
 
-        BasicBlock* block = worklist.Bottom(i);
+        BasicBlock* block = blocksWithNormalAwaits.Bottom(i);
         liveness.StartBlock(block);
 
         bool any;
@@ -1193,6 +1186,88 @@ PhaseStatus AsyncTransformation::Run()
     return PhaseStatus::MODIFIED_EVERYTHING;
 }
 
+//------------------------------------------------------------------------
+// AsyncTransformation::FindAwaits:
+//   Find the blocks that have awaits in them and do some accounting of how
+//   many awaits there are.
+//
+// Parameters:
+//   blocksWithNormalAwaits - [in, out] Each block containing at least one normal await is pushed onto this stack
+//   blocksWithTailAwaits   - [in, out] Each block containing at least one tail await is pushed onto this stack
+//   numNormalAwaits        - [in, out] Incremented once per normal await found; not reset here, so the caller zeroes it
+//   numTailAwaits          - [in, out] Incremented once per tail await found; not reset here, so the caller zeroes it
+//
+void AsyncTransformation::FindAwaits(ArrayStack<BasicBlock*>& blocksWithNormalAwaits,
+                                     ArrayStack<BasicBlock*>& blocksWithTailAwaits,
+                                     int*                     numNormalAwaits,
+                                     int*                     numTailAwaits)
+{
+    for (BasicBlock* block : m_compiler->Blocks())
+    {
+        bool hasNormalAwait = false;
+        bool hasTailAwait   = false;
+        for (GenTree* tree : LIR::AsRange(block))
+        {
+            if (!tree->IsCall() || !tree->AsCall()->IsAsync() || tree->AsCall()->IsTailCall())
+            {
+                continue;
+            }
+
+            if (tree->AsCall()->GetAsyncInfo().IsTailAwait)
+            {
+                hasTailAwait = true;
+                (*numTailAwaits)++;
+            }
+            else
+            {
+                hasNormalAwait = true;
+                (*numNormalAwaits)++;
+            }
+        }
+
+        if (hasNormalAwait)
+        {
+            blocksWithNormalAwaits.Push(block);
+        }
+
+        if (hasTailAwait)
+        {
+            blocksWithTailAwaits.Push(block);
+        }
+    }
+}
+
+//------------------------------------------------------------------------
+// AsyncTransformation::TransformTailAwaits:
+//   Transform all tail awaits in the specified blocks.
+//
+// Parameters:
+//   blocksWithTailAwaits - Blocks containing tail awaits; scanning repeats until a pass finds no tail await
+//
+void AsyncTransformation::TransformTailAwaits(ArrayStack<BasicBlock*>& blocksWithTailAwaits)
+{
+    for (int i = 0; i < blocksWithTailAwaits.Height(); i++)
+    {
+        BasicBlock* block = blocksWithTailAwaits.Bottom(i);
+
+        bool any;
+        do
+        {
+            any = false;
+            for (GenTree* tree : LIR::AsRange(block))
+            {
+                if (tree->IsCall() && tree->AsCall()->IsAsync() && !tree->AsCall()->IsTailCall() &&
+                    tree->AsCall()->GetAsyncInfo().IsTailAwait)
+                {
+                    TransformTailAwait(block, tree->AsCall(), &block);
+                    any = true;
+                    break;
+                }
+            }
+        } while (any);
+    }
+}
+
 //------------------------------------------------------------------------
 // AsyncTransformation::TransformTailAwait:
 //   Transform an await that was marked as a tail await.
@@ -1227,7 +1302,7 @@ void AsyncTransformation::TransformTailAwait(BasicBlock* block, GenTreeCall* cal
 //
 BasicBlock* AsyncTransformation::CreateTailAwaitSuspension(BasicBlock* block, GenTreeCall* call)
 {
-    BasicBlock* sharedReturnBB = GetSharedReturnBB();
+    BasicBlock* sharedReturnBB = m_sharedReturnBB;
 
     if (m_lastSuspensionBB == nullptr)
     {
@@ -2821,21 +2896,11 @@ unsigned AsyncTransformation::GetExceptionVar()
 }
 
 //------------------------------------------------------------------------
-// AsyncTransformation::GetSharedReturnBB:
-//   Create the shared return BB, if one is needed.
+// AsyncTransformation::CreateSharedReturnBB:
+//   Create the shared return BB.
 //
-// Returns:
-//   Basic block or nullptr.
-//
-BasicBlock* AsyncTransformation::GetSharedReturnBB()
+void AsyncTransformation::CreateSharedReturnBB()
 {
-#ifdef JIT32_GCENCODER
-    if (m_sharedReturnBB != nullptr)
-    {
-        return m_sharedReturnBB;
-    }
-
-    // Due to a hard cap on epilogs we need a shared return here.
     m_sharedReturnBB = m_compiler->fgNewBBafter(BBJ_RETURN, m_compiler->fgLastBBInMainFunction(), false);
     m_sharedReturnBB->bbSetRunRarely();
     m_sharedReturnBB->clearTryIndex();
@@ -2855,10 +2920,6 @@ BasicBlock* AsyncTransformation::GetSharedReturnBB()
 
     JITDUMP("Created shared return BB " FMT_BB "\n", m_sharedReturnBB->bbNum);
     DISPRANGE(LIR::AsRange(m_sharedReturnBB));
-    return m_sharedReturnBB;
-#else
-    return nullptr;
-#endif
 }
 
 //------------------------------------------------------------------------
diff --git a/src/coreclr/jit/async.h b/src/coreclr/jit/async.h
index b07348a0e42141..15fd13dc7e4746 100644
--- a/src/coreclr/jit/async.h
+++ b/src/coreclr/jit/async.h
@@ -72,6 +72,12 @@ class AsyncTransformation
     BasicBlock* m_lastResumptionBB = nullptr;
     BasicBlock* m_sharedReturnBB   = nullptr;
 
+    void FindAwaits(ArrayStack<BasicBlock*>& blocksWithNormalAwaits,
+                    ArrayStack<BasicBlock*>& blocksWithTailAwaits,
+                    int*                     numNormalAwaits,
+                    int*                     numTailAwaits);
+
+    void        TransformTailAwaits(ArrayStack<BasicBlock*>& blocksWithTailAwaits);
     void        TransformTailAwait(BasicBlock* block, GenTreeCall* call, BasicBlock** remainder);
     BasicBlock* CreateTailAwaitSuspension(BasicBlock* block, GenTreeCall* call);
 
@@ -139,13 +145,13 @@ class AsyncTransformation
                                    var_types    storeType,
                                    GenTreeFlags indirFlags = GTF_IND_NONFAULTING);
 
-    void        CreateDebugInfoForSuspensionPoint(const ContinuationLayout& layout);
-    unsigned    GetReturnedContinuationVar();
-    unsigned    GetNewContinuationVar();
-    unsigned    GetResultBaseVar();
-    unsigned    GetExceptionVar();
-    BasicBlock* GetSharedReturnBB();
+    void     CreateDebugInfoForSuspensionPoint(const ContinuationLayout& layout);
+    unsigned GetReturnedContinuationVar();
+    unsigned GetNewContinuationVar();
+    unsigned GetResultBaseVar();
+    unsigned GetExceptionVar();
 
+    void CreateSharedReturnBB();
     void CreateResumptionSwitch();
 
 public: