Skip to content

Commit e79014e

Browse files
rekram1-nodedemostanis
authored andcommitted
feat: integrate support for multi step auth flows for providers that require additional questions (anomalyco#18035)
1 parent 591e948 commit e79014e

File tree

11 files changed

+344
-8
lines changed

11 files changed

+344
-8
lines changed

packages/opencode/src/cli/cmd/providers.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@ async function handlePluginAuth(plugin: { auth: PluginAuth }, provider: string,
4646
const inputs: Record<string, string> = {}
4747
if (method.prompts) {
4848
for (const prompt of method.prompts) {
49-
if (prompt.condition && !prompt.condition(inputs)) {
50-
continue
49+
if (prompt.when) {
50+
const value = inputs[prompt.when.key]
51+
if (value === undefined) continue
52+
const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
53+
if (!matches) continue
5154
}
55+
if (prompt.condition && !prompt.condition(inputs)) continue
5256
if (prompt.type === "select") {
5357
const value = await prompts.select({
5458
message: prompt.message,

packages/opencode/src/cli/cmd/tui/component/dialog-provider.tsx

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { DialogPrompt } from "../ui/dialog-prompt"
88
import { Link } from "../ui/link"
99
import { useTheme } from "../context/theme"
1010
import { TextAttributes } from "@opentui/core"
11-
import type { ProviderAuthAuthorization } from "@opencode-ai/sdk/v2"
11+
import type { ProviderAuthAuthorization, ProviderAuthMethod } from "@opencode-ai/sdk/v2"
1212
import { DialogModel } from "./dialog-model"
1313
import { useKeyboard } from "@opentui/solid"
1414
import { Clipboard } from "@tui/util/clipboard"
@@ -27,6 +27,7 @@ export function createDialogProviderOptions() {
2727
const sync = useSync()
2828
const dialog = useDialog()
2929
const sdk = useSDK()
30+
const toast = useToast()
3031
const options = createMemo(() => {
3132
return pipe(
3233
sync.data.provider_next.all,
@@ -69,10 +70,29 @@ export function createDialogProviderOptions() {
6970
if (index == null) return
7071
const method = methods[index]
7172
if (method.type === "oauth") {
73+
let inputs: Record<string, string> | undefined
74+
if (method.prompts?.length) {
75+
const value = await PromptsMethod({
76+
dialog,
77+
prompts: method.prompts,
78+
})
79+
if (!value) return
80+
inputs = value
81+
}
82+
7283
const result = await sdk.client.provider.oauth.authorize({
7384
providerID: provider.id,
7485
method: index,
86+
inputs,
7587
})
88+
if (result.error) {
89+
toast.show({
90+
variant: "error",
91+
message: JSON.stringify(result.error),
92+
})
93+
dialog.clear()
94+
return
95+
}
7696
if (result.data?.method === "code") {
7797
dialog.replace(() => (
7898
<CodeMethod providerID={provider.id} title={method.label} index={index} authorization={result.data!} />
@@ -257,3 +277,53 @@ function ApiMethod(props: ApiMethodProps) {
257277
/>
258278
)
259279
}
280+
281+
interface PromptsMethodProps {
282+
dialog: ReturnType<typeof useDialog>
283+
prompts: NonNullable<ProviderAuthMethod["prompts"]>[number][]
284+
}
285+
async function PromptsMethod(props: PromptsMethodProps) {
286+
const inputs: Record<string, string> = {}
287+
for (const prompt of props.prompts) {
288+
if (prompt.when) {
289+
const value = inputs[prompt.when.key]
290+
if (value === undefined) continue
291+
const matches = prompt.when.op === "eq" ? value === prompt.when.value : value !== prompt.when.value
292+
if (!matches) continue
293+
}
294+
295+
if (prompt.type === "select") {
296+
const value = await new Promise<string | null>((resolve) => {
297+
props.dialog.replace(
298+
() => (
299+
<DialogSelect
300+
title={prompt.message}
301+
options={prompt.options.map((x) => ({
302+
title: x.label,
303+
value: x.value,
304+
description: x.hint,
305+
}))}
306+
onSelect={(option) => resolve(option.value)}
307+
/>
308+
),
309+
() => resolve(null),
310+
)
311+
})
312+
if (value === null) return null
313+
inputs[prompt.key] = value
314+
continue
315+
}
316+
317+
const value = await new Promise<string | null>((resolve) => {
318+
props.dialog.replace(
319+
() => (
320+
<DialogPrompt title={prompt.message} placeholder={prompt.placeholder} onConfirm={(value) => resolve(value)} />
321+
),
322+
() => resolve(null),
323+
)
324+
})
325+
if (value === null) return null
326+
inputs[prompt.key] = value
327+
}
328+
return inputs
329+
}

packages/opencode/src/plugin/copilot.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ export async function CopilotAuthPlugin(input: PluginInput): Promise<Hooks> {
168168
key: "enterpriseUrl",
169169
message: "Enter your GitHub Enterprise URL or domain",
170170
placeholder: "company.ghe.com or https://company.ghe.com",
171-
condition: (inputs) => inputs.deploymentType === "enterprise",
171+
when: { key: "deploymentType", op: "eq", value: "enterprise" },
172172
validate: (value) => {
173173
if (!value) return "URL or domain is required"
174174
try {

packages/opencode/src/provider/auth-service.ts

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,44 @@ export const Method = z
1010
.object({
1111
type: z.union([z.literal("oauth"), z.literal("api")]),
1212
label: z.string(),
13+
prompts: z
14+
.array(
15+
z.union([
16+
z.object({
17+
type: z.literal("text"),
18+
key: z.string(),
19+
message: z.string(),
20+
placeholder: z.string().optional(),
21+
when: z
22+
.object({
23+
key: z.string(),
24+
op: z.union([z.literal("eq"), z.literal("neq")]),
25+
value: z.string(),
26+
})
27+
.optional(),
28+
}),
29+
z.object({
30+
type: z.literal("select"),
31+
key: z.string(),
32+
message: z.string(),
33+
options: z.array(
34+
z.object({
35+
label: z.string(),
36+
value: z.string(),
37+
hint: z.string().optional(),
38+
}),
39+
),
40+
when: z
41+
.object({
42+
key: z.string(),
43+
op: z.union([z.literal("eq"), z.literal("neq")]),
44+
value: z.string(),
45+
})
46+
.optional(),
47+
}),
48+
]),
49+
)
50+
.optional(),
1351
})
1452
.meta({
1553
ref: "ProviderAuthMethod",
@@ -43,16 +81,29 @@ export const OauthCodeMissing = NamedError.create(
4381

4482
export const OauthCallbackFailed = NamedError.create("ProviderAuthOauthCallbackFailed", z.object({}))
4583

84+
export const ValidationFailed = NamedError.create(
85+
"ProviderAuthValidationFailed",
86+
z.object({
87+
field: z.string(),
88+
message: z.string(),
89+
}),
90+
)
91+
4692
export type ProviderAuthError =
4793
| Auth.AuthServiceError
4894
| InstanceType<typeof OauthMissing>
4995
| InstanceType<typeof OauthCodeMissing>
5096
| InstanceType<typeof OauthCallbackFailed>
97+
| InstanceType<typeof ValidationFailed>
5198

5299
export namespace ProviderAuthService {
53100
export interface Service {
54101
readonly methods: () => Effect.Effect<Record<string, Method[]>>
55-
readonly authorize: (input: { providerID: ProviderID; method: number }) => Effect.Effect<Authorization | undefined>
102+
readonly authorize: (input: {
103+
providerID: ProviderID
104+
method: number
105+
inputs?: Record<string, string>
106+
}) => Effect.Effect<Authorization | undefined, ProviderAuthError>
56107
readonly callback: (input: {
57108
providerID: ProviderID
58109
method: number
@@ -80,16 +131,52 @@ export class ProviderAuthService extends ServiceMap.Service<ProviderAuthService,
80131
const pending = new Map<ProviderID, AuthOuathResult>()
81132

82133
const methods = Effect.fn("ProviderAuthService.methods")(function* () {
83-
return Record.map(hooks, (item) => item.methods.map((method): Method => Struct.pick(method, ["type", "label"])))
134+
return Record.map(hooks, (item) =>
135+
item.methods.map(
136+
(method): Method => ({
137+
type: method.type,
138+
label: method.label,
139+
prompts: method.prompts?.map((prompt) => {
140+
if (prompt.type === "select") {
141+
return {
142+
type: "select" as const,
143+
key: prompt.key,
144+
message: prompt.message,
145+
options: prompt.options,
146+
when: prompt.when,
147+
}
148+
}
149+
return {
150+
type: "text" as const,
151+
key: prompt.key,
152+
message: prompt.message,
153+
placeholder: prompt.placeholder,
154+
when: prompt.when,
155+
}
156+
}),
157+
}),
158+
),
159+
)
84160
})
85161

86162
const authorize = Effect.fn("ProviderAuthService.authorize")(function* (input: {
87163
providerID: ProviderID
88164
method: number
165+
inputs?: Record<string, string>
89166
}) {
90167
const method = hooks[input.providerID].methods[input.method]
91168
if (method.type !== "oauth") return
92-
const result = yield* Effect.promise(() => method.authorize())
169+
170+
if (method.prompts && input.inputs) {
171+
for (const prompt of method.prompts) {
172+
if (prompt.type === "text" && prompt.validate && input.inputs[prompt.key] !== undefined) {
173+
const error = prompt.validate(input.inputs[prompt.key])
174+
if (error) return yield* Effect.fail(new ValidationFailed({ field: prompt.key, message: error }))
175+
}
176+
}
177+
}
178+
179+
const result = yield* Effect.promise(() => method.authorize(input.inputs))
93180
pending.set(input.providerID, result)
94181
return {
95182
url: result.url,

packages/opencode/src/provider/auth.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ export namespace ProviderAuth {
2020
z.object({
2121
providerID: ProviderID.zod,
2222
method: z.number(),
23+
inputs: z.record(z.string(), z.string()).optional(),
2324
}),
2425
async (input): Promise<Authorization | undefined> =>
2526
runPromiseInstance(S.ProviderAuthService.use((service) => service.authorize(input))),
@@ -37,4 +38,5 @@ export namespace ProviderAuth {
3738
export import OauthMissing = S.OauthMissing
3839
export import OauthCodeMissing = S.OauthCodeMissing
3940
export import OauthCallbackFailed = S.OauthCallbackFailed
41+
export import ValidationFailed = S.ValidationFailed
4042
}

packages/opencode/src/server/routes/provider.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,16 @@ export const ProviderRoutes = lazy(() =>
109109
"json",
110110
z.object({
111111
method: z.number().meta({ description: "Auth method index" }),
112+
inputs: z.record(z.string(), z.string()).optional().meta({ description: "Prompt inputs" }),
112113
}),
113114
),
114115
async (c) => {
115116
const providerID = c.req.valid("param").providerID
116-
const { method } = c.req.valid("json")
117+
const { method, inputs } = c.req.valid("json")
117118
const result = await ProviderAuth.authorize({
118119
providerID,
119120
method,
121+
inputs,
120122
})
121123
return c.json(result)
122124
},

packages/opencode/src/server/server.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export namespace Server {
6666
let status: ContentfulStatusCode
6767
if (err instanceof NotFoundError) status = 404
6868
else if (err instanceof Provider.ModelNotFoundError) status = 400
69+
else if (err.name === "ProviderAuthValidationFailed") status = 400
6970
else if (err.name.startsWith("Worktree")) status = 400
7071
else status = 500
7172
return c.json(err.toObject(), { status })

packages/plugin/src/index.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ export type PluginInput = {
3434

3535
export type Plugin = (input: PluginInput) => Promise<Hooks>
3636

37+
type Rule = {
38+
key: string
39+
op: "eq" | "neq"
40+
value: string
41+
}
42+
3743
export type AuthHook = {
3844
provider: string
3945
loader?: (auth: () => Promise<Auth>, provider: Provider) => Promise<Record<string, any>>
@@ -48,7 +54,9 @@ export type AuthHook = {
4854
message: string
4955
placeholder?: string
5056
validate?: (value: string) => string | undefined
57+
/** @deprecated Use `when` instead */
5158
condition?: (inputs: Record<string, string>) => boolean
59+
when?: Rule
5260
}
5361
| {
5462
type: "select"
@@ -59,7 +67,9 @@ export type AuthHook = {
5967
value: string
6068
hint?: string
6169
}>
70+
/** @deprecated Use `when` instead */
6271
condition?: (inputs: Record<string, string>) => boolean
72+
when?: Rule
6373
}
6474
>
6575
authorize(inputs?: Record<string, string>): Promise<AuthOuathResult>
@@ -74,7 +84,9 @@ export type AuthHook = {
7484
message: string
7585
placeholder?: string
7686
validate?: (value: string) => string | undefined
87+
/** @deprecated Use `when` instead */
7788
condition?: (inputs: Record<string, string>) => boolean
89+
when?: Rule
7890
}
7991
| {
8092
type: "select"
@@ -85,7 +97,9 @@ export type AuthHook = {
8597
value: string
8698
hint?: string
8799
}>
100+
/** @deprecated Use `when` instead */
88101
condition?: (inputs: Record<string, string>) => boolean
102+
when?: Rule
89103
}
90104
>
91105
authorize?(inputs?: Record<string, string>): Promise<

packages/sdk/js/src/v2/gen/sdk.gen.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2605,6 +2605,9 @@ export class Oauth extends HeyApiClient {
26052605
directory?: string
26062606
workspace?: string
26072607
method?: number
2608+
inputs?: {
2609+
[key: string]: string
2610+
}
26082611
},
26092612
options?: Options<never, ThrowOnError>,
26102613
) {
@@ -2617,6 +2620,7 @@ export class Oauth extends HeyApiClient {
26172620
{ in: "query", key: "directory" },
26182621
{ in: "query", key: "workspace" },
26192622
{ in: "body", key: "method" },
2623+
{ in: "body", key: "inputs" },
26202624
],
26212625
},
26222626
],

0 commit comments

Comments
 (0)