@@ -134,21 +134,25 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
134134 }()
135135
136136 callCtx := engine .NewContext (ctx , & prg )
137- if state == nil {
138- startResult , err := r .start (callCtx , monitor , env , input )
137+ if state == nil || state .StartContinuation {
138+ if state != nil {
139+ state = state .WithResumeInput (& input )
140+ input = state .InputContextContinuationInput
141+ }
142+ state , err = r .start (callCtx , state , monitor , env , input )
139143 if err != nil {
140144 return resp , err
141145 }
142- state = & State {
143- Continuation : startResult ,
144- }
145146 } else {
147+ state = state .WithResumeInput (& input )
146148 state .ResumeInput = & input
147149 }
148150
149- state , err = r .resume (callCtx , monitor , env , state )
150- if err != nil {
151- return resp , err
151+ if ! state .StartContinuation {
152+ state , err = r .resume (callCtx , monitor , env , state )
153+ if err != nil {
154+ return resp , err
155+ }
152156 }
153157
154158 if state .Result != nil {
@@ -286,44 +290,79 @@ func getContextInput(prg *types.Program, ref types.ToolReference, input string)
286290 return string (output ), err
287291}
288292
289- func (r * Runner ) getContext (callCtx engine.Context , monitor Monitor , env []string , input string ) (result []engine.InputContext , _ error ) {
293+ func (r * Runner ) getContext (callCtx engine.Context , state * State , monitor Monitor , env []string , input string ) (result []engine.InputContext , _ * State , _ error ) {
290294 toolRefs , err := callCtx .Program .GetContextToolRefs (callCtx .Tool .ID )
291295 if err != nil {
292- return nil , err
296+ return nil , nil , err
293297 }
294298
295- for _ , toolRef := range toolRefs {
299+ var newState * State
300+ if state != nil {
301+ cp := * state
302+ newState = & cp
303+ if newState .InputContextContinuation != nil {
304+ newState .InputContexts = nil
305+ newState .InputContextContinuation = nil
306+ newState .InputContextContinuationInput = ""
307+ newState .ResumeInput = state .InputContextContinuationResumeInput
308+
309+ input = state .InputContextContinuationInput
310+ }
311+ }
312+
313+ for i , toolRef := range toolRefs {
314+ if state != nil && i < len (state .InputContexts ) {
315+ result = append (result , state .InputContexts [i ])
316+ continue
317+ }
318+
296319 contextInput , err := getContextInput (callCtx .Program , toolRef , input )
297320 if err != nil {
298- return nil , err
321+ return nil , nil , err
299322 }
300323
301- content , err := r .subCall (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , contextInput , "" , engine .ContextToolCategory )
324+ var content * State
325+ if state != nil && state .InputContextContinuation != nil {
326+ content , err = r .subCallResume (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , "" , state .InputContextContinuation .WithResumeInput (state .ResumeInput ), engine .ContextToolCategory )
327+ } else {
328+ content , err = r .subCall (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , contextInput , "" , engine .ContextToolCategory )
329+ }
302330 if err != nil {
303- return nil , err
331+ return nil , nil , err
304332 }
305- if content .Result == nil {
306- return nil , fmt .Errorf ("context tool can not result in a chat continuation" )
333+ if content .Continuation != nil {
334+ if newState == nil {
335+ newState = & State {}
336+ }
337+ newState .InputContexts = result
338+ newState .InputContextContinuation = content
339+ newState .InputContextContinuationInput = input
340+ if state != nil {
341+ newState .InputContextContinuationResumeInput = state .ResumeInput
342+ }
343+ return nil , newState , nil
307344 }
308345 result = append (result , engine.InputContext {
309346 ToolID : toolRef .ToolID ,
310347 Content : * content .Result ,
311348 })
312349 }
313- return result , nil
350+
351+ return result , newState , nil
314352}
315353
316354func (r * Runner ) call (callCtx engine.Context , monitor Monitor , env []string , input string ) (* State , error ) {
317- result , err := r .start (callCtx , monitor , env , input )
355+ result , err := r .start (callCtx , nil , monitor , env , input )
318356 if err != nil {
319357 return nil , err
320358 }
321- return r .resume (callCtx , monitor , env , & State {
322- Continuation : result ,
323- })
359+ if result .StartContinuation {
360+ return result , nil
361+ }
362+ return r .resume (callCtx , monitor , env , result )
324363}
325364
326- func (r * Runner ) start (callCtx engine.Context , monitor Monitor , env []string , input string ) (* engine. Return , error ) {
365+ func (r * Runner ) start (callCtx engine.Context , state * State , monitor Monitor , env []string , input string ) (* State , error ) {
327366 progress , progressClose := streamProgress (& callCtx , monitor )
328367 defer progressClose ()
329368
@@ -335,11 +374,18 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
335374 }
336375 }
337376
338- var err error
339- callCtx .InputContext , err = r .getContext (callCtx , monitor , env , input )
377+ var (
378+ err error
379+ newState * State
380+ )
381+ callCtx .InputContext , newState , err = r .getContext (callCtx , state , monitor , env , input )
340382 if err != nil {
341383 return nil , err
342384 }
385+ if newState != nil && newState .InputContextContinuation != nil {
386+ newState .StartContinuation = true
387+ return newState , nil
388+ }
343389
344390 e := engine.Engine {
345391 Model : r .c ,
@@ -358,7 +404,14 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
358404
359405 callCtx .Ctx = context2 .AddPauseFuncToCtx (callCtx .Ctx , monitor .Pause )
360406
361- return e .Start (callCtx , input )
407+ ret , err := e .Start (callCtx , input )
408+ if err != nil {
409+ return nil , err
410+ }
411+
412+ return & State {
413+ Continuation : ret ,
414+ }, nil
362415}
363416
364417type State struct {
@@ -369,18 +422,28 @@ type State struct {
369422 ResumeInput * string `json:"resumeInput,omitempty"`
370423 SubCalls []SubCallResult `json:"subCalls,omitempty"`
371424 SubCallID string `json:"subCallID,omitempty"`
425+
426+ InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
427+ InputContextContinuation * State `json:"inputContextContinuation,omitempty"`
428+ InputContextContinuationInput string `json:"inputContextContinuationInput,omitempty"`
429+ InputContextContinuationResumeInput * string `json:"inputContextContinuationResumeInput,omitempty"`
430+ StartContinuation bool `json:"startContinuation,omitempty"`
372431}
373432
374- func (s State ) WithInput (input * string ) * State {
433+ func (s State ) WithResumeInput (input * string ) * State {
375434 s .ResumeInput = input
376435 return & s
377436}
378437
379438func (s State ) ContinuationContentToolID () (string , error ) {
380- if s .Continuation .Result != nil {
439+ if s .Continuation != nil && s . Continuation .Result != nil {
381440 return s .ContinuationToolID , nil
382441 }
383442
443+ if s .InputContextContinuation != nil {
444+ return s .InputContextContinuation .ContinuationContentToolID ()
445+ }
446+
384447 for _ , subCall := range s .SubCalls {
385448 if s .SubCallID == subCall .CallID {
386449 return subCall .State .ContinuationContentToolID ()
@@ -390,10 +453,14 @@ func (s State) ContinuationContentToolID() (string, error) {
390453}
391454
392455func (s State ) ContinuationContent () (string , error ) {
393- if s .Continuation .Result != nil {
456+ if s .Continuation != nil && s . Continuation .Result != nil {
394457 return * s .Continuation .Result , nil
395458 }
396459
460+ if s .InputContextContinuation != nil {
461+ return s .InputContextContinuation .ContinuationContent ()
462+ }
463+
397464 for _ , subCall := range s .SubCalls {
398465 if s .SubCallID == subCall .CallID {
399466 return subCall .State .ContinuationContent ()
@@ -408,6 +475,10 @@ type Needed struct {
408475}
409476
410477func (r * Runner ) resume (callCtx engine.Context , monitor Monitor , env []string , state * State ) (* State , error ) {
478+ if state .StartContinuation {
479+ return nil , fmt .Errorf ("invalid state, resume should not have StartContinuation set to true" )
480+ }
481+
411482 progress , progressClose := streamProgress (& callCtx , monitor )
412483 defer progressClose ()
413484
@@ -451,7 +522,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
451522 err error
452523 )
453524
454- state , callResults , err = r .subCalls (callCtx , monitor , env , state )
525+ state , callResults , err = r .subCalls (callCtx , monitor , env , state , engine . NoCategory )
455526 if errMessage := (* builtin .ErrChatFinish )(nil ); errors .As (err , & errMessage ) && callCtx .Tool .Chat {
456527 return & State {
457528 Result : & errMessage .Message ,
@@ -477,12 +548,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
477548 }
478549 }
479550
480- if state .ResumeInput != nil {
481- engineResults = append (engineResults , engine.CallResult {
482- User : * state .ResumeInput ,
483- })
484- }
485-
486551 monitor .Event (Event {
487552 Time : time .Now (),
488553 CallContext : callCtx .GetCallContext (),
@@ -506,9 +571,15 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
506571 contentInput = state .Continuation .State .Input
507572 }
508573
509- callCtx .InputContext , err = r .getContext (callCtx , monitor , env , contentInput )
510- if err != nil {
511- return nil , err
574+ callCtx .InputContext , state , err = r .getContext (callCtx , state , monitor , env , contentInput )
575+ if err != nil || state .InputContextContinuation != nil {
576+ return state , err
577+ }
578+
579+ if state .ResumeInput != nil {
580+ engineResults = append (engineResults , engine.CallResult {
581+ User : * state .ResumeInput ,
582+ })
512583 }
513584
514585 nextContinuation , err := e .Continue (callCtx , state .Continuation .State , engineResults ... )
@@ -571,8 +642,8 @@ func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, moni
571642 return r .call (callCtx , monitor , env , input )
572643}
573644
574- func (r * Runner ) subCallResume (ctx context.Context , parentContext engine.Context , monitor Monitor , env []string , toolID , callID string , state * State ) (* State , error ) {
575- callCtx , err := parentContext .SubCall (ctx , toolID , callID , engine . NoCategory )
645+ func (r * Runner ) subCallResume (ctx context.Context , parentContext engine.Context , monitor Monitor , env []string , toolID , callID string , state * State , toolCategory engine. ToolCategory ) (* State , error ) {
646+ callCtx , err := parentContext .SubCall (ctx , toolID , callID , toolCategory )
576647 if err != nil {
577648 return nil , err
578649 }
@@ -593,11 +664,15 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
593664 return newParallelDispatcher (ctx )
594665}
595666
596- func (r * Runner ) subCalls (callCtx engine.Context , monitor Monitor , env []string , state * State ) (_ * State , callResults []SubCallResult , _ error ) {
667+ func (r * Runner ) subCalls (callCtx engine.Context , monitor Monitor , env []string , state * State , toolCategory engine. ToolCategory ) (_ * State , callResults []SubCallResult , _ error ) {
597668 var (
598669 resultLock sync.Mutex
599670 )
600671
672+ if state .InputContextContinuation != nil {
673+ return state , nil , nil
674+ }
675+
601676 if state .SubCallID != "" {
602677 if state .ResumeInput == nil {
603678 return nil , nil , fmt .Errorf ("invalid state, input must be set for sub call continuation on callID [%s]" , state .SubCallID )
@@ -608,7 +683,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
608683 found = true
609684 subState := * subCall .State
610685 subState .ResumeInput = state .ResumeInput
611- result , err := r .subCallResume (callCtx .Ctx , callCtx , monitor , env , subCall .ToolID , subCall .CallID , subCall .State .WithInput (state .ResumeInput ))
686+ result , err := r .subCallResume (callCtx .Ctx , callCtx , monitor , env , subCall .ToolID , subCall .CallID , subCall .State .WithResumeInput (state .ResumeInput ), toolCategory )
612687 if err != nil {
613688 return nil , nil , err
614689 }
@@ -618,7 +693,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
618693 State : result ,
619694 })
620695 // Clear the input, we have already processed it
621- state = state .WithInput (nil )
696+ state = state .WithResumeInput (nil )
622697 } else {
623698 callResults = append (callResults , subCall )
624699 }
0 commit comments