@@ -12,28 +12,41 @@ import (
1212
1313 openai "github.com/gptscript-ai/chat-completion-client"
1414 "github.com/gptscript-ai/gptscript/pkg/cache"
15+ "github.com/gptscript-ai/gptscript/pkg/config"
1516 "github.com/gptscript-ai/gptscript/pkg/counter"
17+ "github.com/gptscript-ai/gptscript/pkg/credentials"
1618 "github.com/gptscript-ai/gptscript/pkg/hash"
19+ "github.com/gptscript-ai/gptscript/pkg/prompt"
1720 "github.com/gptscript-ai/gptscript/pkg/system"
1821 "github.com/gptscript-ai/gptscript/pkg/types"
1922)
2023
2124const (
22- DefaultModel = openai .GPT4o
25+ DefaultModel = openai .GPT4o
26+ BuiltinCredName = "sys.openai"
2327)
2428
2529var (
2630 key = os .Getenv ("OPENAI_API_KEY" )
2731 url = os .Getenv ("OPENAI_URL" )
2832)
2933
34+ type InvalidAuthError struct {}
35+
36+ func (InvalidAuthError ) Error () string {
37+ return "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable"
38+ }
39+
3040type Client struct {
3141 defaultModel string
3242 c * openai.Client
3343 cache * cache.Client
3444 invalidAuth bool
3545 cacheKeyBase string
3646 setSeed bool
47+ cliCfg * config.CLIConfig
48+ credCtx string
49+ envs []string
3750}
3851
3952type Options struct {
@@ -75,12 +88,28 @@ func complete(opts ...Options) (result Options, err error) {
7588 return result , err
7689}
7790
78- func NewClient (opts ... Options ) (* Client , error ) {
91+ func NewClient (cliCfg * config. CLIConfig , credCtx string , opts ... Options ) (* Client , error ) {
7992 opt , err := complete (opts ... )
8093 if err != nil {
8194 return nil , err
8295 }
8396
97+ // If the API key is not set, try to get it from the cred store
98+ if opt .APIKey == "" && opt .BaseURL == "" {
99+ store , err := credentials .NewStore (cliCfg , credCtx )
100+ if err != nil {
101+ return nil , err
102+ }
103+
104+ cred , exists , err := store .Get (BuiltinCredName )
105+ if err != nil {
106+ return nil , err
107+ }
108+ if exists {
109+ opt .APIKey = cred .Env ["OPENAI_API_KEY" ]
110+ }
111+ }
112+
84113 cfg := openai .DefaultConfig (opt .APIKey )
85114 cfg .BaseURL = types .FirstSet (opt .BaseURL , cfg .BaseURL )
86115 cfg .OrgID = types .FirstSet (opt .OrgID , cfg .OrgID )
@@ -97,21 +126,33 @@ func NewClient(opts ...Options) (*Client, error) {
97126 cacheKeyBase : cacheKeyBase ,
98127 invalidAuth : opt .APIKey == "" && opt .BaseURL == "" ,
99128 setSeed : opt .SetSeed ,
129+ cliCfg : cliCfg ,
130+ credCtx : credCtx ,
100131 }, nil
101132}
102133
103134func (c * Client ) ValidAuth () error {
104135 if c .invalidAuth {
105- return fmt . Errorf ( "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable" )
136+ return InvalidAuthError {}
106137 }
107138 return nil
108139}
109140
141+ func (c * Client ) SetEnvs (env []string ) {
142+ c .envs = env
143+ }
144+
110145func (c * Client ) Supports (ctx context.Context , modelName string ) (bool , error ) {
111146 models , err := c .ListModels (ctx )
112147 if err != nil {
113148 return false , err
114149 }
150+
151+ if len (models ) == 0 {
152+ // We got no models back, which means our auth is invalid.
153+ return false , InvalidAuthError {}
154+ }
155+
115156 return slices .Contains (models , modelName ), nil
116157}
117158
@@ -121,8 +162,9 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
121162 return nil , nil
122163 }
123164
165+ // If auth is invalid, we just want to return nothing.
124166 if err := c .ValidAuth (); err != nil {
125- return nil , err
167+ return nil , nil
126168 }
127169
128170 models , err := c .c .ListModels (ctx )
@@ -251,7 +293,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
251293
252294func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
253295 if err := c .ValidAuth (); err != nil {
254- return nil , err
296+ if err := c .RetrieveAPIKey (ctx ); err != nil {
297+ return nil , err
298+ }
255299 }
256300
257301 if messageRequest .Model == "" {
@@ -499,6 +543,17 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
499543 }
500544}
501545
546+ func (c * Client ) RetrieveAPIKey (ctx context.Context ) error {
547+ k , err := prompt .GetModelProviderCredential (ctx , BuiltinCredName , "OPENAI_API_KEY" , "Please provide your OpenAI API key:" , c .credCtx , c .envs , c .cliCfg )
548+ if err != nil {
549+ return err
550+ }
551+
552+ c .c .SetAPIKey (k )
553+ c .invalidAuth = false
554+ return nil
555+ }
556+
502557func ptr [T any ](v T ) * T {
503558 return & v
504559}
0 commit comments