Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 113 additions & 52 deletions src/coreclr/jit/async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1035,43 +1035,39 @@ PhaseStatus Compiler::TransformAsync()
PhaseStatus AsyncTransformation::Run()
{
PhaseStatus result = PhaseStatus::MODIFIED_NOTHING;
ArrayStack<BasicBlock*> worklist(m_compiler->getAllocator(CMK_Async));
ArrayStack<BasicBlock*> blocksWithNormalAwaits(m_compiler->getAllocator(CMK_Async));
ArrayStack<BasicBlock*> blocksWithTailAwaits(m_compiler->getAllocator(CMK_Async));
int numNormalAwaits = 0;
int numTailAwaits = 0;
FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits, &numNormalAwaits, &numTailAwaits);

// First find all basic blocks with awaits in them. We'll have to track
// liveness in these basic blocks, so it does not help to record the calls
// ahead of time.
BasicBlock* nextBlock;
for (BasicBlock* block = m_compiler->fgFirstBB; block != nullptr; block = nextBlock)
if (numNormalAwaits + numTailAwaits > 1)
{
bool hasAwait = false;
nextBlock = block->Next();
for (GenTree* tree : LIR::AsRange(block))
{
if (!tree->IsCall() || !tree->AsCall()->IsAsync() || tree->AsCall()->IsTailCall())
{
continue;
}

if (tree->AsCall()->GetAsyncInfo().IsTailAwait)
{
TransformTailAwait(block, tree->AsCall(), &nextBlock);
result = PhaseStatus::MODIFIED_EVERYTHING;
break;
}

JITDUMP(FMT_BB " contains await(s)\n", block->bbNum);
hasAwait = true;
}
CreateSharedReturnBB();
}

if (hasAwait)
// Transform all tail awaits first. They will not require running all of
// our analyses.
if (numTailAwaits > 0)
{
JITDUMP("Found %d tail awaits in %d blocks\n", numTailAwaits, blocksWithTailAwaits.Height());
TransformTailAwaits(blocksWithTailAwaits);
if (numNormalAwaits > 0)
{
worklist.Push(block);
blocksWithNormalAwaits.Reset();
blocksWithTailAwaits.Reset();
numNormalAwaits = 0;
numTailAwaits = 0;
FindAwaits(blocksWithNormalAwaits, blocksWithTailAwaits, &numNormalAwaits, &numTailAwaits);
assert((numTailAwaits == 0) && (blocksWithTailAwaits.Height() == 0));
}

result = PhaseStatus::MODIFIED_EVERYTHING;
}

JITDUMP("Found %d blocks with awaits\n", worklist.Height());
JITDUMP("Found %d awaits in %d blocks\n", numNormalAwaits, blocksWithNormalAwaits.Height());

if (worklist.Height() <= 0)
if (numNormalAwaits <= 0)
{
return result;
}
Expand All @@ -1083,9 +1079,6 @@ PhaseStatus AsyncTransformation::Run()

m_asyncInfo = m_compiler->eeGetAsyncInfo();

// Create the shared return BB now to put it in the right place in the block order.
GetSharedReturnBB();

// Compute liveness to be used for determining what must be captured on
// suspension.
if (m_compiler->m_dfsTree == nullptr)
Expand All @@ -1107,11 +1100,11 @@ PhaseStatus AsyncTransformation::Run()
// async calls are additional live variables that must be spilled.
jitstd::vector<GenTree*> defs(m_compiler->getAllocator(CMK_Async));

for (int i = 0; i < worklist.Height(); i++)
for (int i = 0; i < blocksWithNormalAwaits.Height(); i++)
{
assert(defs.size() == 0);

BasicBlock* block = worklist.Bottom(i);
BasicBlock* block = blocksWithNormalAwaits.Bottom(i);
liveness.StartBlock(block);

bool any;
Expand Down Expand Up @@ -1193,6 +1186,88 @@ PhaseStatus AsyncTransformation::Run()
return PhaseStatus::MODIFIED_EVERYTHING;
}

//------------------------------------------------------------------------
// AsyncTransformation::FindAwaits:
// Find the blocks that have awaits in them and do some accounting of how
// many awaits there are.
//
// Parameters:
// blocksWithNormalAwaits - [out] Blocks with normal awaits are pushed onto this stack
// blocksWithTailAwaits - [out] Blocks with tail awaits are pushed onto this stack
// numNormalAwaits - [out] Number of normal awaits found
// numTailAwaits - [out] Number of tail awaits found
//
void AsyncTransformation::FindAwaits(ArrayStack<BasicBlock*>& blocksWithNormalAwaits,
ArrayStack<BasicBlock*>& blocksWithTailAwaits,
int* numNormalAwaits,
int* numTailAwaits)
{
for (BasicBlock* block : m_compiler->Blocks())
{
bool hasNormalAwait = false;
bool hasTailAwait = false;
for (GenTree* tree : LIR::AsRange(block))
{
if (!tree->IsCall() || !tree->AsCall()->IsAsync() || tree->AsCall()->IsTailCall())
{
continue;
}

if (tree->AsCall()->GetAsyncInfo().IsTailAwait)
{
hasTailAwait = true;
(*numTailAwaits)++;
}
else
{
hasNormalAwait = true;
(*numNormalAwaits)++;
}
}

if (hasNormalAwait)
{
blocksWithNormalAwaits.Push(block);
}

if (hasTailAwait)
{
blocksWithTailAwaits.Push(block);
}
}
}

//------------------------------------------------------------------------
// AsyncTransformation::TransformTailAwaits:
// Transform all tail awaits in the specified blocks.
//
// Parameters:
// blocksWithTailAwaits - Blocks containing tail awaits
//
void AsyncTransformation::TransformTailAwaits(ArrayStack<BasicBlock*>& blocksWithTailAwaits)
{
for (int i = 0; i < blocksWithTailAwaits.Height(); i++)
{
BasicBlock* block = blocksWithTailAwaits.Bottom(i);

bool any;
do
{
any = false;
for (GenTree* tree : LIR::AsRange(block))
{
if (tree->IsCall() && tree->AsCall()->IsAsync() && !tree->AsCall()->IsTailCall() &&
tree->AsCall()->GetAsyncInfo().IsTailAwait)
{
TransformTailAwait(block, tree->AsCall(), &block);
any = true;
break;
}
}
} while (any);
}
}

//------------------------------------------------------------------------
// AsyncTransformation::TransformTailAwait:
// Transform an await that was marked as a tail await.
Expand Down Expand Up @@ -1227,7 +1302,7 @@ void AsyncTransformation::TransformTailAwait(BasicBlock* block, GenTreeCall* cal
//
BasicBlock* AsyncTransformation::CreateTailAwaitSuspension(BasicBlock* block, GenTreeCall* call)
{
BasicBlock* sharedReturnBB = GetSharedReturnBB();
BasicBlock* sharedReturnBB = m_sharedReturnBB;

if (m_lastSuspensionBB == nullptr)
{
Expand Down Expand Up @@ -2821,21 +2896,11 @@ unsigned AsyncTransformation::GetExceptionVar()
}

//------------------------------------------------------------------------
// AsyncTransformation::GetSharedReturnBB:
// Create the shared return BB, if one is needed.
// AsyncTransformation::CreateSharedReturnBB:
// Create the shared return BB.
//
// Returns:
// Basic block or nullptr.
//
BasicBlock* AsyncTransformation::GetSharedReturnBB()
void AsyncTransformation::CreateSharedReturnBB()
{
#ifdef JIT32_GCENCODER
if (m_sharedReturnBB != nullptr)
{
return m_sharedReturnBB;
}

// Due to a hard cap on epilogs we need a shared return here.
m_sharedReturnBB = m_compiler->fgNewBBafter(BBJ_RETURN, m_compiler->fgLastBBInMainFunction(), false);
m_sharedReturnBB->bbSetRunRarely();
m_sharedReturnBB->clearTryIndex();
Expand All @@ -2855,10 +2920,6 @@ BasicBlock* AsyncTransformation::GetSharedReturnBB()
JITDUMP("Created shared return BB " FMT_BB "\n", m_sharedReturnBB->bbNum);

DISPRANGE(LIR::AsRange(m_sharedReturnBB));
return m_sharedReturnBB;
#else
return nullptr;
#endif
}

//------------------------------------------------------------------------
Expand Down
18 changes: 12 additions & 6 deletions src/coreclr/jit/async.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ class AsyncTransformation
BasicBlock* m_lastResumptionBB = nullptr;
BasicBlock* m_sharedReturnBB = nullptr;

void FindAwaits(ArrayStack<BasicBlock*>& blocksWithNormalAwaits,
ArrayStack<BasicBlock*>& blocksWithTailAwaits,
int* numNormalAwaits,
int* numTailAwaits);

void TransformTailAwaits(ArrayStack<BasicBlock*>& blocksWithTailAwaits);
void TransformTailAwait(BasicBlock* block, GenTreeCall* call, BasicBlock** remainder);
BasicBlock* CreateTailAwaitSuspension(BasicBlock* block, GenTreeCall* call);

Expand Down Expand Up @@ -139,13 +145,13 @@ class AsyncTransformation
var_types storeType,
GenTreeFlags indirFlags = GTF_IND_NONFAULTING);

void CreateDebugInfoForSuspensionPoint(const ContinuationLayout& layout);
unsigned GetReturnedContinuationVar();
unsigned GetNewContinuationVar();
unsigned GetResultBaseVar();
unsigned GetExceptionVar();
BasicBlock* GetSharedReturnBB();
void CreateDebugInfoForSuspensionPoint(const ContinuationLayout& layout);
unsigned GetReturnedContinuationVar();
unsigned GetNewContinuationVar();
unsigned GetResultBaseVar();
unsigned GetExceptionVar();

void CreateSharedReturnBB();
void CreateResumptionSwitch();

public:
Expand Down
Loading