Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 143 additions & 143 deletions packages/opencode/src/mcp/oauth-callback.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,177 +56,177 @@ interface PendingAuth {
timeout: ReturnType<typeof setTimeout>
}

export namespace McpOAuthCallback {
let server: ReturnType<typeof createServer> | undefined
const pendingAuths = new Map<string, PendingAuth>()
// Reverse index: mcpName → oauthState, so cancelPending(mcpName) can
// find the right entry in pendingAuths (which is keyed by oauthState).
const mcpNameToState = new Map<string, string>()

const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes

function cleanupStateIndex(oauthState: string) {
for (const [name, state] of mcpNameToState) {
if (state === oauthState) {
mcpNameToState.delete(name)
break
}
let server: ReturnType<typeof createServer> | undefined
const pendingAuths = new Map<string, PendingAuth>()
// Reverse index: mcpName → oauthState, so cancelPending(mcpName) can
// find the right entry in pendingAuths (which is keyed by oauthState).
const mcpNameToState = new Map<string, string>()

const CALLBACK_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes

function cleanupStateIndex(oauthState: string) {
for (const [name, state] of mcpNameToState) {
if (state === oauthState) {
mcpNameToState.delete(name)
break
}
}
}

function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
const url = new URL(req.url || "/", `http://localhost:${currentPort}`)
function handleRequest(req: import("http").IncomingMessage, res: import("http").ServerResponse) {
const url = new URL(req.url || "/", `http://localhost:${currentPort}`)

if (url.pathname !== currentPath) {
res.writeHead(404)
res.end("Not found")
return
}
if (url.pathname !== currentPath) {
res.writeHead(404)
res.end("Not found")
return
}

const code = url.searchParams.get("code")
const state = url.searchParams.get("state")
const error = url.searchParams.get("error")
const errorDescription = url.searchParams.get("error_description")
const code = url.searchParams.get("code")
const state = url.searchParams.get("state")
const error = url.searchParams.get("error")
const errorDescription = url.searchParams.get("error_description")

log.info("received oauth callback", { hasCode: !!code, state, error })
log.info("received oauth callback", { hasCode: !!code, state, error })

// Enforce state parameter presence
if (!state) {
const errorMsg = "Missing required state parameter - potential CSRF attack"
log.error("oauth callback missing state parameter", { url: url.toString() })
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}
// Enforce state parameter presence
if (!state) {
const errorMsg = "Missing required state parameter - potential CSRF attack"
log.error("oauth callback missing state parameter", { url: url.toString() })
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}

if (error) {
const errorMsg = errorDescription || error
if (pendingAuths.has(state)) {
const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout)
pendingAuths.delete(state)
cleanupStateIndex(state)
pending.reject(new Error(errorMsg))
}
res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
if (error) {
const errorMsg = errorDescription || error
if (pendingAuths.has(state)) {
const pending = pendingAuths.get(state)!
clearTimeout(pending.timeout)
pendingAuths.delete(state)
cleanupStateIndex(state)
pending.reject(new Error(errorMsg))
}
res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}

if (!code) {
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR("No authorization code provided"))
return
}
if (!code) {
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR("No authorization code provided"))
return
}

// Validate state parameter
if (!pendingAuths.has(state)) {
const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}
// Validate state parameter
if (!pendingAuths.has(state)) {
const errorMsg = "Invalid or expired state parameter - potential CSRF attack"
log.error("oauth callback with invalid state", { state, pendingStates: Array.from(pendingAuths.keys()) })
res.writeHead(400, { "Content-Type": "text/html" })
res.end(HTML_ERROR(errorMsg))
return
}

const pending = pendingAuths.get(state)!
const pending = pendingAuths.get(state)!

clearTimeout(pending.timeout)
pendingAuths.delete(state)
cleanupStateIndex(state)
pending.resolve(code)
clearTimeout(pending.timeout)
pendingAuths.delete(state)
cleanupStateIndex(state)
pending.resolve(code)

res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_SUCCESS)
}
res.writeHead(200, { "Content-Type": "text/html" })
res.end(HTML_SUCCESS)
}

export async function ensureRunning(redirectUri?: string): Promise<void> {
// Parse the redirect URI to get port and path (uses defaults if not provided)
const { port, path } = parseRedirectUri(redirectUri)
export async function ensureRunning(redirectUri?: string): Promise<void> {
// Parse the redirect URI to get port and path (uses defaults if not provided)
const { port, path } = parseRedirectUri(redirectUri)

// If server is running on a different port/path, stop it first
if (server && (currentPort !== port || currentPath !== path)) {
log.info("stopping oauth callback server to reconfigure", { oldPort: currentPort, newPort: port })
await stop()
}
// If server is running on a different port/path, stop it first
if (server && (currentPort !== port || currentPath !== path)) {
log.info("stopping oauth callback server to reconfigure", { oldPort: currentPort, newPort: port })
await stop()
}

if (server) return
if (server) return

const running = await isPortInUse(port)
if (running) {
log.info("oauth callback server already running on another instance", { port })
return
}
const running = await isPortInUse(port)
if (running) {
log.info("oauth callback server already running on another instance", { port })
return
}

currentPort = port
currentPath = path
currentPort = port
currentPath = path

server = createServer(handleRequest)
await new Promise<void>((resolve, reject) => {
server!.listen(currentPort, () => {
log.info("oauth callback server started", { port: currentPort, path: currentPath })
resolve()
})
server!.on("error", reject)
server = createServer(handleRequest)
await new Promise<void>((resolve, reject) => {
server!.listen(currentPort, () => {
log.info("oauth callback server started", { port: currentPort, path: currentPath })
resolve()
})
}
server!.on("error", reject)
})
}

export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> {
if (mcpName) mcpNameToState.set(mcpName, oauthState)
return new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
if (pendingAuths.has(oauthState)) {
pendingAuths.delete(oauthState)
if (mcpName) mcpNameToState.delete(mcpName)
reject(new Error("OAuth callback timeout - authorization took too long"))
}
}, CALLBACK_TIMEOUT_MS)

pendingAuths.set(oauthState, { resolve, reject, timeout })
})
}
export function waitForCallback(oauthState: string, mcpName?: string): Promise<string> {
if (mcpName) mcpNameToState.set(mcpName, oauthState)
return new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
if (pendingAuths.has(oauthState)) {
pendingAuths.delete(oauthState)
if (mcpName) mcpNameToState.delete(mcpName)
reject(new Error("OAuth callback timeout - authorization took too long"))
}
}, CALLBACK_TIMEOUT_MS)

export function cancelPending(mcpName: string): void {
// Look up the oauthState for this mcpName via the reverse index
const oauthState = mcpNameToState.get(mcpName)
const key = oauthState ?? mcpName
const pending = pendingAuths.get(key)
if (pending) {
clearTimeout(pending.timeout)
pendingAuths.delete(key)
mcpNameToState.delete(mcpName)
pending.reject(new Error("Authorization cancelled"))
}
}
pendingAuths.set(oauthState, { resolve, reject, timeout })
})
}

export async function isPortInUse(port: number = OAUTH_CALLBACK_PORT): Promise<boolean> {
return new Promise((resolve) => {
const socket = createConnection(port, "127.0.0.1")
socket.on("connect", () => {
socket.destroy()
resolve(true)
})
socket.on("error", () => {
resolve(false)
})
})
export function cancelPending(mcpName: string): void {
// Look up the oauthState for this mcpName via the reverse index
const oauthState = mcpNameToState.get(mcpName)
const key = oauthState ?? mcpName
const pending = pendingAuths.get(key)
if (pending) {
clearTimeout(pending.timeout)
pendingAuths.delete(key)
mcpNameToState.delete(mcpName)
pending.reject(new Error("Authorization cancelled"))
}
}

export async function stop(): Promise<void> {
if (server) {
await new Promise<void>((resolve) => server!.close(() => resolve()))
server = undefined
log.info("oauth callback server stopped")
}
export async function isPortInUse(port: number = OAUTH_CALLBACK_PORT): Promise<boolean> {
return new Promise((resolve) => {
const socket = createConnection(port, "127.0.0.1")
socket.on("connect", () => {
socket.destroy()
resolve(true)
})
socket.on("error", () => {
resolve(false)
})
})
}

for (const [_name, pending] of pendingAuths) {
clearTimeout(pending.timeout)
pending.reject(new Error("OAuth callback server stopped"))
}
pendingAuths.clear()
mcpNameToState.clear()
export async function stop(): Promise<void> {
if (server) {
await new Promise<void>((resolve) => server!.close(() => resolve()))
server = undefined
log.info("oauth callback server stopped")
}

export function isRunning(): boolean {
return server !== undefined
for (const [_name, pending] of pendingAuths) {
clearTimeout(pending.timeout)
pending.reject(new Error("OAuth callback server stopped"))
}
pendingAuths.clear()
mcpNameToState.clear()
}

export function isRunning(): boolean {
return server !== undefined
}

export * as McpOAuthCallback from "./oauth-callback"
Loading