diff --git a/src/coreclr/jit/codegen.h b/src/coreclr/jit/codegen.h index 96954703599c93..8472dfe8c0233d 100644 --- a/src/coreclr/jit/codegen.h +++ b/src/coreclr/jit/codegen.h @@ -304,6 +304,7 @@ class CodeGen final : public CodeGenInterface #if defined(TARGET_WASM) void genJumpToThrowHlpBlk(SpecialCodeKind codeKind); + void genCodeForBinaryOverflow(GenTreeOp* node); #else void genJumpToThrowHlpBlk(emitJumpKind jumpKind, SpecialCodeKind codeKind, BasicBlock* failBlk = nullptr); #endif diff --git a/src/coreclr/jit/codegenwasm.cpp b/src/coreclr/jit/codegenwasm.cpp index 10af9be99b6097..566c7ca5bd57bb 100644 --- a/src/coreclr/jit/codegenwasm.cpp +++ b/src/coreclr/jit/codegenwasm.cpp @@ -1040,19 +1040,21 @@ void CodeGen::genFloatToFloatCast(GenTree* tree) // void CodeGen::genCodeForBinary(GenTreeOp* treeNode) { + if (treeNode->gtOverflow()) + { + genCodeForBinaryOverflow(treeNode); + return; + } + genConsumeOperands(treeNode); instruction ins; switch (PackOperAndType(treeNode->OperGet(), treeNode->TypeGet())) { case PackOperAndType(GT_ADD, TYP_INT): - if (treeNode->gtOverflow()) - NYI_WASM("Overflow checks"); ins = INS_i32_add; break; case PackOperAndType(GT_ADD, TYP_LONG): - if (treeNode->gtOverflow()) - NYI_WASM("Overflow checks"); ins = INS_i64_add; break; case PackOperAndType(GT_ADD, TYP_FLOAT): @@ -1063,13 +1065,9 @@ void CodeGen::genCodeForBinary(GenTreeOp* treeNode) break; case PackOperAndType(GT_SUB, TYP_INT): - if (treeNode->gtOverflow()) - NYI_WASM("Overflow checks"); ins = INS_i32_sub; break; case PackOperAndType(GT_SUB, TYP_LONG): - if (treeNode->gtOverflow()) - NYI_WASM("Overflow checks"); ins = INS_i64_sub; break; case PackOperAndType(GT_SUB, TYP_FLOAT): @@ -1080,13 +1078,9 @@ void CodeGen::genCodeForBinary(GenTreeOp* treeNode) break; case PackOperAndType(GT_MUL, TYP_INT): - if (treeNode->gtOverflow()) - NYI_WASM("Overflow checks"); ins = INS_i32_mul; break; case PackOperAndType(GT_MUL, TYP_LONG): - if (treeNode->gtOverflow()) - NYI_WASM("Overflow 
checks"); ins = INS_i64_mul; break; case PackOperAndType(GT_MUL, TYP_FLOAT): @@ -1127,6 +1121,152 @@ void CodeGen::genCodeForBinary(GenTreeOp* treeNode) WasmProduceReg(treeNode); } +//------------------------------------------------------------------------ +// genCodeForBinaryOverflow: Generate code for a binary arithmetic operator +// with overflow checking +// +// Arguments: +// treeNode - The binary operation for which we are generating code. +// +void CodeGen::genCodeForBinaryOverflow(GenTreeOp* treeNode) +{ + assert(treeNode->gtOverflow()); + assert(varTypeIsIntegral(treeNode->TypeGet())); + + // TODO-WASM-CQ: consider using helper calls for all these cases + + genConsumeOperands(treeNode); + + const bool is64BitOp = treeNode->TypeIs(TYP_LONG); + InternalRegs* regs = internalRegisters.GetAll(treeNode); + regNumber op1Reg = GetMultiUseOperandReg(treeNode->gtGetOp1()); + regNumber op2Reg = GetMultiUseOperandReg(treeNode->gtGetOp2()); + + switch (treeNode->OperGet()) + { + case GT_ADD: + { + // We require an internal register. + assert(regs->Count() == 1); + regNumber resultReg = regs->Extract(); + assert(WasmRegToType(resultReg) == TypeToWasmValueType(treeNode->TypeGet())); + + // Add and save the sum + GetEmitter()->emitIns(is64BitOp ? INS_i64_add : INS_i32_add); + GetEmitter()->emitIns_I(INS_local_set, emitActualTypeSize(treeNode), WasmRegToIndex(resultReg)); + // See if addends had the same sign. XOR leaves a non-negative result if they had the same sign. + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(op1Reg)); + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(op2Reg)); + GetEmitter()->emitIns(is64BitOp ? INS_i64_xor : INS_i32_xor); + + // TODO-WASM-CQ: consider branchless alternative here (and for sub) + GetEmitter()->emitIns_I(is64BitOp ? INS_i64_const : INS_i32_const, emitActualTypeSize(treeNode), 0); + GetEmitter()->emitIns(is64BitOp ? 
INS_i64_ge_s : INS_i32_ge_s); + GetEmitter()->emitIns(INS_if); + { + // Operands have the same sign. If the sum has a different sign, then the add overflowed. + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(resultReg)); + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(op1Reg)); + GetEmitter()->emitIns(is64BitOp ? INS_i64_xor : INS_i32_xor); + GetEmitter()->emitIns_I(is64BitOp ? INS_i64_const : INS_i32_const, emitActualTypeSize(treeNode), 0); + GetEmitter()->emitIns(is64BitOp ? INS_i64_lt_s : INS_i32_lt_s); + genJumpToThrowHlpBlk(SCK_OVERFLOW); + } + GetEmitter()->emitIns(INS_end); + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(resultReg)); + break; + } + + case GT_SUB: + { + // We require an internal register. + assert(regs->Count() == 1); + regNumber resultReg = regs->Extract(); + assert(WasmRegToType(resultReg) == TypeToWasmValueType(treeNode->TypeGet())); + + // Subtract and save the difference + GetEmitter()->emitIns(is64BitOp ? INS_i64_sub : INS_i32_sub); + GetEmitter()->emitIns_I(INS_local_set, emitActualTypeSize(treeNode), WasmRegToIndex(resultReg)); + // See if operands had a different sign. XOR leaves a negative result if they had different signs. + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(op1Reg)); + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(op2Reg)); + GetEmitter()->emitIns(is64BitOp ? INS_i64_xor : INS_i32_xor); + GetEmitter()->emitIns_I(is64BitOp ? INS_i64_const : INS_i32_const, emitActualTypeSize(treeNode), 0); + GetEmitter()->emitIns(is64BitOp ? INS_i64_lt_s : INS_i32_lt_s); + GetEmitter()->emitIns(INS_if); + { + // Operands have different signs. If the difference has a different sign than op1, then the subtraction + // overflowed. 
+ GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(resultReg)); + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(op1Reg)); + GetEmitter()->emitIns(is64BitOp ? INS_i64_xor : INS_i32_xor); + GetEmitter()->emitIns_I(is64BitOp ? INS_i64_const : INS_i32_const, emitActualTypeSize(treeNode), 0); + GetEmitter()->emitIns(is64BitOp ? INS_i64_lt_s : INS_i32_lt_s); + genJumpToThrowHlpBlk(SCK_OVERFLOW); + } + GetEmitter()->emitIns(INS_end); + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(resultReg)); + break; + } + + case GT_MUL: + { + if (is64BitOp) + { + assert(!"64 bit multiply with overflow should have been transformed into a helper call by morph"); + } + + // We require an I64 internal register + assert(regs->Count() == 1); + regNumber wideReg = regs->Extract(); + assert(WasmRegToType(wideReg) == WasmValueType::I64); + + // 32 bit multiply... check by doing a 64 bit multiply and then range-checking the result + const bool isUnsigned = treeNode->IsUnsigned(); + // Both operands are on the stack as I32. Drop the second, extend the first, then extend the second. + // + // TODO-WASM-CQ: consider transforming this to a (u)long multiply plus a checked cast, either in morph or + // lower. + GetEmitter()->emitIns(INS_drop); + GetEmitter()->emitIns(isUnsigned ? INS_i64_extend_u_i32 : INS_i64_extend_s_i32); + GetEmitter()->emitIns_I(INS_local_get, emitActualTypeSize(treeNode), WasmRegToIndex(op2Reg)); + GetEmitter()->emitIns(isUnsigned ? INS_i64_extend_u_i32 : INS_i64_extend_s_i32); + GetEmitter()->emitIns(INS_i64_mul); + + // Save the wide result, and then overflow check it. + GetEmitter()->emitIns_I(INS_local_tee, EA_8BYTE, WasmRegToIndex(wideReg)); + + if (isUnsigned) + { + // For unsigned multiply, we just need to check if the result is greater than UINT32_MAX. 
+ GetEmitter()->emitIns_I(INS_i64_const, EA_8BYTE, UINT32_MAX); + GetEmitter()->emitIns(INS_i64_gt_u); + genJumpToThrowHlpBlk(SCK_OVERFLOW); + } + else + { + GetEmitter()->emitIns(INS_i64_extend32_s); + GetEmitter()->emitIns_I(INS_local_get, EA_8BYTE, WasmRegToIndex(wideReg)); + GetEmitter()->emitIns(INS_i64_ne); + genJumpToThrowHlpBlk(SCK_OVERFLOW); + } + + // If the check succeeds, the multiplication result is in range for a 32-bit int. + // We just need to return the low 32 bits of the result. + GetEmitter()->emitIns_I(INS_local_get, EA_8BYTE, WasmRegToIndex(wideReg)); + GetEmitter()->emitIns(INS_i32_wrap_i64); + + break; + } + + default: + unreached(); + break; + } + + WasmProduceReg(treeNode); +} + //------------------------------------------------------------------------ // genCodeForDivMod: Generate code for a division or modulus operator // diff --git a/src/coreclr/jit/lowerwasm.cpp b/src/coreclr/jit/lowerwasm.cpp index 35a40c48ee9a4c..e241e9a3a02700 100644 --- a/src/coreclr/jit/lowerwasm.cpp +++ b/src/coreclr/jit/lowerwasm.cpp @@ -165,6 +165,13 @@ GenTree* Lowering::LowerJTrue(GenTreeOp* jtrue) GenTree* Lowering::LowerBinaryArithmetic(GenTreeOp* binOp) { ContainCheckBinary(binOp); + + if (binOp->gtOverflow()) + { + binOp->gtGetOp1()->gtLIRFlags |= LIR::Flags::MultiplyUsed; + binOp->gtGetOp2()->gtLIRFlags |= LIR::Flags::MultiplyUsed; + } + return binOp->gtNext; } diff --git a/src/coreclr/jit/morph.cpp b/src/coreclr/jit/morph.cpp index 0018e4442855f8..d6efc7ca7f89dc 100644 --- a/src/coreclr/jit/morph.cpp +++ b/src/coreclr/jit/morph.cpp @@ -7131,6 +7131,22 @@ GenTree* Compiler::fgMorphSmpOp(GenTree* tree, MorphAddrContext* mac, bool* optA } } #endif // !defined(TARGET_64BIT) && !defined(TARGET_WASM) + +#if defined(TARGET_WASM) + if (tree->gtOverflow()) + { + // For long multiply with overflow, call the helper. + if (tree->TypeIs(TYP_LONG)) + { + helper = tree->IsUnsigned() ? 
CORINFO_HELP_ULMUL_OVF : CORINFO_HELP_LMUL_OVF; + goto USE_HELPER_FOR_ARITH; + } + else + { + // TODO-WASM-CQ: Transform to a long multiply and then a checked cast? + } + } +#endif break; case GT_ARR_LENGTH: diff --git a/src/coreclr/jit/regallocwasm.cpp b/src/coreclr/jit/regallocwasm.cpp index e03db3a6910585..88efaee43d7284 100644 --- a/src/coreclr/jit/regallocwasm.cpp +++ b/src/coreclr/jit/regallocwasm.cpp @@ -234,7 +234,21 @@ regNumber WasmRegAlloc::AllocateTemporaryRegister(var_types type) regNumber WasmRegAlloc::ReleaseTemporaryRegister(var_types type) { WasmValueType wasmType = TypeToWasmValueType(type); - unsigned index = m_temporaryRegs[static_cast<unsigned>(wasmType)].Pop(); + return ReleaseTemporaryRegister(wasmType); } + +//------------------------------------------------------------------------ +// ReleaseTemporaryRegister: Release the most recently allocated temporary register. +// +// Arguments: +// wasmType - The register's wasm type +// +// Return Value: +// The released register. +// +regNumber WasmRegAlloc::ReleaseTemporaryRegister(WasmValueType wasmType) +{ + unsigned index = m_temporaryRegs[static_cast<unsigned>(wasmType)].Pop(); return MakeWasmReg(index, wasmType); } @@ -310,6 +324,12 @@ void WasmRegAlloc::CollectReferencesForNode(GenTree* node) CollectReferencesForCast(node->AsOp()); break; + case GT_ADD: + case GT_SUB: + case GT_MUL: + CollectReferencesForBinop(node->AsOp()); + break; + default: assert(!node->OperIsLocalStore()); break; @@ -363,6 +383,40 @@ void WasmRegAlloc::CollectReferencesForCast(GenTreeOp* castNode) ConsumeTemporaryRegForOperand(castNode->gtGetOp1() DEBUGARG("cast overflow check")); } +//------------------------------------------------------------------------ +// CollectReferencesForBinop: Collect virtual register references for a binary operation. +// +// Consumes temporary registers for a binary operation. 
+// +// Arguments: +// binopNode - The binary operation node +// +void WasmRegAlloc::CollectReferencesForBinop(GenTreeOp* binopNode) +{ + regNumber internalReg = REG_NA; + if (binopNode->gtOverflow()) + { + if (binopNode->OperIs(GT_ADD) || binopNode->OperIs(GT_SUB)) + { + internalReg = RequestInternalRegister(binopNode, binopNode->TypeGet()); + } + else if (binopNode->OperIs(GT_MUL)) + { + assert(binopNode->TypeIs(TYP_INT)); + internalReg = RequestInternalRegister(binopNode, TYP_LONG); + } + } + + if (internalReg != REG_NA) + { + regNumber releasedReg = ReleaseTemporaryRegister(WasmRegToType(internalReg)); + assert(releasedReg == internalReg); + } + + ConsumeTemporaryRegForOperand(binopNode->gtGetOp2() DEBUGARG("binop overflow check")); + ConsumeTemporaryRegForOperand(binopNode->gtGetOp1() DEBUGARG("binop overflow check")); +} + //------------------------------------------------------------------------ // CollectReferencesForLclVar: Collect virtual register references for a LCL_VAR. // @@ -514,6 +568,25 @@ void WasmRegAlloc::ConsumeTemporaryRegForOperand(GenTree* operand DEBUGARG(const JITDUMP("Consumed a temporary reg for [%06u]: %s\n", Compiler::dspTreeID(operand), reason); } +//------------------------------------------------------------------------ +// RequestInternalRegister: request an internal register for a node with specific type. +// +// To be later assigned a physical register. +// +// Arguments: +// node - node whose codegen will need an internal register +// type - type of the internal register +// +// Returns: +// reg number of internal register. +// +regNumber WasmRegAlloc::RequestInternalRegister(GenTree* node, var_types type) +{ + regNumber reg = AllocateTemporaryRegister(genActualType(type)); + m_codeGen->internalRegisters.Add(node, reg); + return reg; +} + //------------------------------------------------------------------------ // ResolveReferences: Translate virtual registers to physical ones (WASM locals). 
// diff --git a/src/coreclr/jit/regallocwasm.h b/src/coreclr/jit/regallocwasm.h index 4b1073590a983b..a24f07df1fa73d 100644 --- a/src/coreclr/jit/regallocwasm.h +++ b/src/coreclr/jit/regallocwasm.h @@ -115,18 +115,21 @@ class WasmRegAlloc : public RegAllocInterface regNumber AllocateVirtualRegister(WasmValueType type); regNumber AllocateTemporaryRegister(var_types type); regNumber ReleaseTemporaryRegister(var_types type); - - void CollectReferences(); - void CollectReferencesForBlock(BasicBlock* block); - void CollectReferencesForNode(GenTree* node); - void CollectReferencesForDivMod(GenTreeOp* divModNode); - void CollectReferencesForCall(GenTreeCall* callNode); - void CollectReferencesForCast(GenTreeOp* castNode); - void CollectReferencesForLclVar(GenTreeLclVar* lclVar); - void RewriteLocalStackStore(GenTreeLclVarCommon* node); - void CollectReference(GenTree* node); - void RequestTemporaryRegisterForMultiplyUsedNode(GenTree* node); - void ConsumeTemporaryRegForOperand(GenTree* operand DEBUGARG(const char* reason)); + regNumber ReleaseTemporaryRegister(WasmValueType wasmType); + + void CollectReferences(); + void CollectReferencesForBlock(BasicBlock* block); + void CollectReferencesForNode(GenTree* node); + void CollectReferencesForDivMod(GenTreeOp* divModNode); + void CollectReferencesForCall(GenTreeCall* callNode); + void CollectReferencesForCast(GenTreeOp* castNode); + void CollectReferencesForBinop(GenTreeOp* binOpNode); + void CollectReferencesForLclVar(GenTreeLclVar* lclVar); + void RewriteLocalStackStore(GenTreeLclVarCommon* node); + void CollectReference(GenTree* node); + void RequestTemporaryRegisterForMultiplyUsedNode(GenTree* node); + regNumber RequestInternalRegister(GenTree* node, var_types type); + void ConsumeTemporaryRegForOperand(GenTree* operand DEBUGARG(const char* reason)); void ResolveReferences();