diff --git a/components/backend/handlers/sessions.go b/components/backend/handlers/sessions.go index 290df07bd..1c40ac8da 100755 --- a/components/backend/handlers/sessions.go +++ b/components/backend/handlers/sessions.go @@ -13,6 +13,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "sort" "strings" "sync" @@ -1474,6 +1475,172 @@ func UpdateSessionDisplayName(c *gin.Context) { c.JSON(http.StatusOK, session) } +// SwitchModel switches the LLM model for a running session +// POST /api/projects/:projectName/agentic-sessions/:sessionName/model +func SwitchModel(c *gin.Context) { + project := c.GetString("project") + sessionName := c.Param("sessionName") + _, k8sDyn := GetK8sClientsForRequest(c) + if k8sDyn == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or missing token"}) + c.Abort() + return + } + + var req struct { + Model string `json:"model" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body: model is required"}) + return + } + + if req.Model == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "model must not be empty"}) + return + } + + gvr := GetAgenticSessionV1Alpha1Resource() + + // Get current session + item, err := k8sDyn.Resource(gvr).Namespace(project).Get(context.TODO(), sessionName, v1.GetOptions{}) + if err != nil { + if errors.IsNotFound(err) { + c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get session"}) + return + } + + // Ensure session is Running + if err := ensureRuntimeMutationAllowed(item); err != nil { + c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) + return + } + + // Get current model for comparison + spec, ok := item.Object["spec"].(map[string]interface{}) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid session spec"}) + return + } + llmSettings, _, _ := unstructured.NestedMap(spec, "llmSettings") + previousModel, _ := llmSettings["model"].(string) + + // No-op if same model + if previousModel == req.Model { + session := types.AgenticSession{ + APIVersion: item.GetAPIVersion(), + Kind: item.GetKind(), + } + if meta, ok := item.Object["metadata"].(map[string]interface{}); ok { + session.Metadata = meta + } + session.Spec = parseSpec(spec) + if status, ok := item.Object["status"].(map[string]interface{}); ok { + session.Status = parseStatus(status) + } + c.JSON(http.StatusOK, session) + return + } + + // Update the CR first to validate RBAC (user needs update permission). + // This ensures a user with only get access cannot trigger a runner-side + // model switch without also being allowed to persist the change. + if llmSettings == nil { + llmSettings = map[string]interface{}{} + } + llmSettings["model"] = req.Model + spec["llmSettings"] = llmSettings + + updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{}) + if err != nil { + log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"}) + return + } + + // Proxy to runner — if runner rejects (e.g., agent is mid-generation), revert the CR. + // Sanitize the CR name against a strict allowlist to prevent SSRF. + sanitizedName, err := sanitizeK8sName(item.GetName()) + if err != nil { + log.Printf("Invalid session name %q for model switch: %v", item.GetName(), err) + revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session name"}) + return + } + sanitizedProject, err := sanitizeK8sName(project) + if err != nil { + log.Printf("Invalid project name %q for model switch: %v", project, err) + revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid project name"}) + return + } + serviceName := getRunnerServiceName(sanitizedName) + runnerURL := fmt.Sprintf("http://%s.%s.svc.cluster.local:8001/model", serviceName, sanitizedProject) + runnerReq := map[string]string{"model": req.Model} + reqBody, _ := json.Marshal(runnerReq) + + httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", runnerURL, bytes.NewReader(reqBody)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create runner request"}) + return + } + httpReq.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(httpReq) + if err != nil { + log.Printf("Failed to proxy model switch to runner for session %s: %v", sessionName, err) + // Revert the CR update on the server-returned object + revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) + c.JSON(http.StatusBadGateway, gin.H{"error": "Failed to reach session runner"}) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + log.Printf("Runner rejected model switch for session %s: %d %s", sessionName, resp.StatusCode, string(body)) + // Revert the CR update on the server-returned object + revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) + // Forward runner's status code and error + c.Data(resp.StatusCode, "application/json", body) + return + } + + session := types.AgenticSession{ + APIVersion: updated.GetAPIVersion(), + Kind: updated.GetKind(), + } + if meta, ok := updated.Object["metadata"].(map[string]interface{}); ok { + session.Metadata = meta + } + if s, ok := updated.Object["spec"].(map[string]interface{}); ok { + session.Spec = parseSpec(s) + } + if status, ok := updated.Object["status"].(map[string]interface{}); ok { + session.Status = parseStatus(status) + } + + c.JSON(http.StatusOK, session) +} + +// revertModelSwitch restores the previous model on the server-returned CR object. +// Called when the runner rejects a model switch after the CR was already updated. +func revertModelSwitch(updated *unstructured.Unstructured, previousModel string, k8sDyn dynamic.Interface, gvr schema.GroupVersionResource, namespace string) { + if updatedSpec, ok := updated.Object["spec"].(map[string]interface{}); ok { + if updatedLLM, ok := updatedSpec["llmSettings"].(map[string]interface{}); ok { + updatedLLM["model"] = previousModel + _, err := k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), updated, v1.UpdateOptions{}) + if err != nil { + log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err) + } + } + } +} + // SelectWorkflow sets the active workflow for a session // POST /api/projects/:projectName/agentic-sessions/:sessionName/workflow func SelectWorkflow(c *gin.Context) { @@ -1945,6 +2112,20 @@ func RemoveRepo(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "Repository removed", "session": session}) } +// k8sNameRegexp matches valid Kubernetes resource names (RFC 1123 DNS label). +var k8sNameRegexp = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-]*[a-z0-9])?$`) + +// sanitizeK8sName validates that name is a valid Kubernetes resource name +// and returns it unchanged if valid, or returns an error. This breaks the +// taint chain for static analysis (CodeQL SSRF) by proving the value matches +// a strict allowlist before it reaches any network call. +func sanitizeK8sName(name string) (string, error) { + if len(name) == 0 || len(name) > 253 || !k8sNameRegexp.MatchString(name) { + return "", fmt.Errorf("invalid Kubernetes resource name: %q", name) + } + return name, nil +} + // getRunnerServiceName returns the K8s Service name for a session's runner. // The runner serves both AG-UI and content endpoints on port 8001. func getRunnerServiceName(session string) string { diff --git a/components/backend/routes.go b/components/backend/routes.go index 29dac8242..9fb63050f 100755 --- a/components/backend/routes.go +++ b/components/backend/routes.go @@ -59,6 +59,7 @@ func registerRoutes(r *gin.Engine) { projectGroup.GET("/agentic-sessions/:sessionName/repos/status", handlers.GetReposStatus) projectGroup.DELETE("/agentic-sessions/:sessionName/repos/:repoName", handlers.RemoveRepo) projectGroup.PUT("/agentic-sessions/:sessionName/displayname", handlers.UpdateSessionDisplayName) + projectGroup.POST("/agentic-sessions/:sessionName/model", handlers.SwitchModel) // OAuth integration - requires user auth like all other session endpoints projectGroup.GET("/agentic-sessions/:sessionName/oauth/:provider/url", handlers.GetOAuthURL) diff --git a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/__tests__/live-model-selector.test.tsx b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/__tests__/live-model-selector.test.tsx new file mode 100644 index 000000000..3acc1691a --- /dev/null +++ b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/__tests__/live-model-selector.test.tsx @@ -0,0 +1,67 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, screen } from '@testing-library/react'; +import { LiveModelSelector } from '../live-model-selector'; +import type { ListModelsResponse } from '@/types/api'; + +const mockAnthropicModels: ListModelsResponse = { + models: [ + { id: 'claude-haiku-4-5', label: 'Claude Haiku 4.5', provider: 'anthropic', isDefault: false }, + { id: 'claude-sonnet-4-5', label: 'Claude Sonnet 4.5', provider: 'anthropic', isDefault: true }, + { id: 'claude-opus-4-6', label: 'Claude Opus 4.6', provider: 'anthropic', isDefault: false }, + ], + defaultModel: 'claude-sonnet-4-5', +}; + +const mockUseModels = vi.fn(() => ({ data: mockAnthropicModels })); + +vi.mock('@/services/queries/use-models', () => ({ + useModels: () => mockUseModels(), +})); + +describe('LiveModelSelector', () => { + const defaultProps = { + projectName: 'test-project', + currentModel: 'claude-sonnet-4-5', + onSelect: vi.fn(), + }; + + beforeEach(() => { + vi.clearAllMocks(); + mockUseModels.mockReturnValue({ data: mockAnthropicModels }); + }); + + it('renders with current model name displayed', () => { + render(); + const button = screen.getByRole('button'); + expect(button.textContent).toContain('Claude Sonnet 4.5'); + }); + + it('renders with model id fallback when model not in list', () => { + render( + + ); + const button = screen.getByRole('button'); + expect(button.textContent).toContain('unknown-model-id'); + }); + + it('shows spinner when switching', () => { + render(); + const spinner = document.querySelector('.animate-spin'); + expect(spinner).not.toBeNull(); + }); + + it('button is disabled when disabled prop is true', () => { + render(); + const button = screen.getByRole('button'); + expect((button as HTMLButtonElement).disabled).toBe(true); + }); + + it('button is disabled when switching prop is true', () => { + render(); + const button = screen.getByRole('button'); + expect((button as HTMLButtonElement).disabled).toBe(true); + }); +}); diff --git a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/live-model-selector.tsx b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/live-model-selector.tsx new file mode 100644 index 000000000..4c8aec5e0 --- /dev/null +++ b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/components/live-model-selector.tsx @@ -0,0 +1,91 @@ +"use client"; + +import { useMemo } from "react"; +import { ChevronDown, Loader2 } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuRadioGroup, + DropdownMenuRadioItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { useModels } from "@/services/queries/use-models"; + +type LiveModelSelectorProps = { + projectName: string; + currentModel: string; + provider?: string; + disabled?: boolean; + switching?: boolean; + onSelect: (model: string) => void; +}; + +export function LiveModelSelector({ + projectName, + currentModel, + provider, + disabled, + switching, + onSelect, +}: LiveModelSelectorProps) { + const { data: modelsData, isLoading, isError } = useModels(projectName, true, provider); + + const models = useMemo(() => { + return modelsData?.models.map((m) => ({ id: m.id, name: m.label })) ?? []; + }, [modelsData]); + + const currentModelName = + models.find((m) => m.id === currentModel)?.name ?? currentModel; + + return ( + + + + + + {isLoading ? ( +
+ +
+ ) : isError ? ( +
+ Failed to load models +
+ ) : models.length > 0 ? ( + { + if (modelId !== currentModel) { + onSelect(modelId); + } + }} + > + {models.map((model) => ( + + {model.name} + + ))} + + ) : ( +
+ No models available +
+ )} +
+
+ ); +} diff --git a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx index 2a6b6d2d0..76db8ec55 100755 --- a/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx +++ b/components/frontend/src/app/projects/[name]/sessions/[sessionName]/page.tsx @@ -30,6 +30,7 @@ import { TaskTranscriptViewer } from "./components/task-transcript-viewer"; import { ExplorerPanel } from "./components/explorer/explorer-panel"; import { SessionSettingsModal } from "./components/session-settings-modal"; import { WorkflowSelector } from "./components/workflow-selector"; +import { LiveModelSelector } from "./components/live-model-selector"; import { useExplorerState } from "./hooks/use-explorer-state"; import { useFileTabs } from "./hooks/use-file-tabs"; @@ -58,6 +59,7 @@ import { useCurrentUser, sessionKeys, useRunnerTypes, + useSwitchSessionModel, } from "@/services/queries"; import { useCapabilities } from "@/services/queries/use-capabilities"; import { @@ -188,13 +190,13 @@ export default function ProjectSessionDetailPage({ // Fetch runner capabilities and derive agent display name const { data: capabilities } = useCapabilities(projectName, sessionName, phase === "Running"); const { data: runnerTypes } = useRunnerTypes(projectName); - const agentName = useMemo(() => { + const currentRunner = useMemo(() => { if (capabilities?.framework && runnerTypes) { - const matched = runnerTypes.find((rt) => rt.id === capabilities.framework); - if (matched) return matched.displayName; + return runnerTypes.find((rt) => rt.id === capabilities.framework); } return undefined; }, [capabilities?.framework, runnerTypes]); + const agentName = currentRunner?.displayName; // Track the current Langfuse trace ID for feedback association const [langfuseTraceId, setLangfuseTraceId] = useState(null); @@ -215,6 +217,7 @@ export default function ProjectSessionDetailPage({ const aguiSendMessage = aguiStream.sendMessage; const aguiInterrupt = aguiStream.interrupt; const isRunActive = aguiStream.isRunActive; + const switchModelMutation = useSwitchSessionModel(); const aguiConnectRef = useRef(aguiStream.connect); // Keep connect ref up to date @@ -739,6 +742,17 @@ export default function ProjectSessionDetailPage({ } }; + const handleModelSwitch = useCallback((model: string) => { + switchModelMutation.mutate( + { projectName, sessionName, model }, + { + onError: (error: Error) => { + toast.error(`Failed to switch model: ${error.message}`); + }, + }, + ); + }, [projectName, sessionName, switchModelMutation]); + // Phase 1: convert committed messages + streaming tool cards into display format. // Does NOT depend on currentMessage / currentReasoning so it skips the full // O(n) traversal during text-streaming deltas (the most frequent event type). @@ -1642,6 +1656,18 @@ export default function ProjectSessionDetailPage({ onAddRepository={handleOpenContextModal} onUploadFile={handleOpenUploadModal} projectName={projectName} + modelSlot={ + phase === "Running" ? ( + + ) : undefined + } workflowSlot={ void; onUploadFile?: () => void; workflowSlot?: React.ReactNode; + modelSlot?: React.ReactNode; projectName?: string; }; @@ -181,6 +182,7 @@ export const ChatInputBox: React.FC = ({ onAddRepository, onUploadFile, workflowSlot, + modelSlot, projectName, }) => { const textareaRef = useRef(null); @@ -667,8 +669,9 @@ export const ChatInputBox: React.FC = ({ - {/* Right side: Workflow selector + Send/Stop buttons */} + {/* Right side: Model + Workflow selector + Send/Stop buttons */}
+ {modelSlot} {workflowSlot} {isRunActive ? ( diff --git a/components/frontend/src/components/session/MessagesTab.tsx b/components/frontend/src/components/session/MessagesTab.tsx old mode 100644 new mode 100755 index 89b9ef5c2..c17494324 --- a/components/frontend/src/components/session/MessagesTab.tsx +++ b/components/frontend/src/components/session/MessagesTab.tsx @@ -53,11 +53,12 @@ export type MessagesTabProps = { onAddRepository?: () => void; onUploadFile?: () => void; workflowSlot?: React.ReactNode; + modelSlot?: React.ReactNode; projectName?: string; }; -const MessagesTab: React.FC = ({ session, streamMessages, chatInput, setChatInput, onSendChat, onSendToolAnswer, onInterrupt, onGoToResults, onContinue, workflowMetadata, onCommandClick, isRunActive = false, showWelcomeExperience, welcomeExperienceComponent, activeWorkflow, userHasInteracted = false, queuedMessages = [], hasRealMessages = false, onCancelQueuedMessage, onUpdateQueuedMessage, onPasteImage, onClearQueue, agentName, onAddRepository, onUploadFile, workflowSlot, projectName }) => { +const MessagesTab: React.FC = ({ session, streamMessages, chatInput, setChatInput, onSendChat, onSendToolAnswer, onInterrupt, onGoToResults, onContinue, workflowMetadata, onCommandClick, isRunActive = false, showWelcomeExperience, welcomeExperienceComponent, activeWorkflow, userHasInteracted = false, queuedMessages = [], hasRealMessages = false, onCancelQueuedMessage, onUpdateQueuedMessage, onPasteImage, onClearQueue, agentName, onAddRepository, onUploadFile, workflowSlot, modelSlot, projectName }) => { const [sendingChat, setSendingChat] = useState(false); const showSystemMessages = false; const [waitingDotCount, setWaitingDotCount] = useState(0); @@ -294,6 +295,7 @@ const MessagesTab: React.FC = ({ session, streamMessages, chat onAddRepository={onAddRepository} onUploadFile={onUploadFile} workflowSlot={workflowSlot} + modelSlot={modelSlot} projectName={projectName} />
diff --git a/components/frontend/src/hooks/agui/event-handlers.ts b/components/frontend/src/hooks/agui/event-handlers.ts old mode 100644 new mode 100755 index 80996bc5c..590e140f5 --- a/components/frontend/src/hooks/agui/event-handlers.ts +++ b/components/frontend/src/hooks/agui/event-handlers.ts @@ -1062,6 +1062,20 @@ function handleCustomEvent( return { ...state, backgroundTasks: tasks } } + // Model switch confirmation — inject a system message into the conversation + if (name === 'ambient:model_switched') { + const previousModel = value.previousModel as string + const newModel = value.newModel as string + const msg: PlatformMessage = { + id: `model-switch-${Date.now()}`, + role: 'assistant', + content: `Model switched from **${previousModel}** to **${newModel}**`, + timestamp: new Date().toISOString(), + metadata: { isModelSwitch: true }, + } + return { ...state, messages: [...state.messages, msg] } + } + // Other custom events (hooks) — pass through unchanged return state } diff --git a/components/frontend/src/services/api/sessions.ts b/components/frontend/src/services/api/sessions.ts index 2cc1fd1cc..e1489c64a 100755 --- a/components/frontend/src/services/api/sessions.ts +++ b/components/frontend/src/services/api/sessions.ts @@ -302,6 +302,20 @@ export async function saveToGoogleDrive( ); } +/** + * Switch the LLM model for a running session + */ +export async function switchSessionModel( + projectName: string, + sessionName: string, + model: string +): Promise { + return apiClient.post( + `/projects/${projectName}/agentic-sessions/${sessionName}/model`, + { model } + ); +} + // --- Capabilities --- export type CapabilitiesResponse = { diff --git a/components/frontend/src/services/queries/use-sessions.ts b/components/frontend/src/services/queries/use-sessions.ts old mode 100644 new mode 100755 index 0a5b89109..19933b0a8 --- a/components/frontend/src/services/queries/use-sessions.ts +++ b/components/frontend/src/services/queries/use-sessions.ts @@ -345,6 +345,31 @@ export function useUpdateSessionDisplayName() { }); } +/** + * Hook to switch the LLM model for a running session + */ +export function useSwitchSessionModel() { + const queryClient = useQueryClient(); + + return useMutation({ + mutationFn: ({ + projectName, + sessionName, + model, + }: { + projectName: string; + sessionName: string; + model: string; + }) => sessionsApi.switchSessionModel(projectName, sessionName, model), + onSuccess: (_data, { projectName, sessionName }) => { + queryClient.invalidateQueries({ + queryKey: sessionKeys.detail(projectName, sessionName), + refetchType: 'all', + }); + }, + }); +} + /** * Hook to fetch session export data (AG-UI events + legacy messages) */ diff --git a/components/runners/ambient-runner/ambient_runner/app.py b/components/runners/ambient-runner/ambient_runner/app.py old mode 100644 new mode 100755 index f73b2d6df..c8a0b464e --- a/components/runners/ambient-runner/ambient_runner/app.py +++ b/components/runners/ambient-runner/ambient_runner/app.py @@ -222,6 +222,10 @@ def add_ambient_endpoints( app.include_router(interrupt_router) app.include_router(health_router) + from ambient_runner.endpoints.model import router as model_router + + app.include_router(model_router) + # Optional platform endpoints if enable_capabilities: from ambient_runner.endpoints.capabilities import router as cap_router diff --git a/components/runners/ambient-runner/ambient_runner/endpoints/model.py b/components/runners/ambient-runner/ambient_runner/endpoints/model.py new file mode 100644 index 000000000..4ff4a0fd0 --- /dev/null +++ b/components/runners/ambient-runner/ambient_runner/endpoints/model.py @@ -0,0 +1,120 @@ +"""POST /model — Switch the LLM model at runtime.""" + +import asyncio +import logging +import os + +from fastapi import APIRouter, HTTPException, Request + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Serialise model changes to prevent concurrent switches +_model_change_lock = asyncio.Lock() + + +@router.post("/model") +async def switch_model(request: Request): + """Switch the LLM model used by this session. + + The agent must be idle (not mid-generation). If a run is in + progress the endpoint returns 422. + """ + bridge = request.app.state.bridge + context = bridge.context + if not context: + raise HTTPException(status_code=503, detail="Context not initialized") + + body = await request.json() + new_model = (body.get("model") or "").strip() + + if not new_model: + raise HTTPException(status_code=400, detail="model is required") + + previous_model = os.getenv("LLM_MODEL", "") + + if new_model == previous_model: + return { + "message": "Model unchanged", + "model": new_model, + } + + # Check if agent is mid-generation. + # The session manager holds a per-thread asyncio.Lock during runs. + session_manager = getattr(bridge, "_session_manager", None) + if session_manager: + thread_id = context.session_id if context else "" + lock = session_manager.get_lock(thread_id) if thread_id else None + if lock and lock.locked(): + raise HTTPException( + status_code=422, + detail="Cannot switch model while agent is generating a response. Wait for the current turn to complete.", + ) + + # Fast-reject if another switch is already in progress. + # asyncio is single-threaded, so no yield between locked() and acquire(). + if _model_change_lock.locked(): + raise HTTPException( + status_code=409, + detail="A model switch is already in progress", + ) + async with _model_change_lock: + return await _perform_model_switch(bridge, context, new_model, previous_model) + + +async def _perform_model_switch(bridge, context, new_model: str, previous_model: str) -> dict: + """Execute the model switch: update env, rebuild adapter, emit event.""" + logger.info(f"Switching model from '{previous_model}' to '{new_model}'") + + # Update environment variable (read by setup_sdk_authentication on next init) + os.environ["LLM_MODEL"] = new_model + + # Also update the Vertex ID mapping if applicable + use_vertex = os.getenv("USE_VERTEX", "").strip().lower() in ("1", "true", "yes") + if use_vertex: + # Clear the manifest override so auth.py re-derives from the new LLM_MODEL + os.environ.pop("LLM_MODEL_VERTEX_ID", None) + + # Emit confirmation event BEFORE mark_dirty destroys the session manager + _emit_model_switched_event(bridge, context, new_model, previous_model) + + # Signal adapter rebuild — stops current workers, preserves session IDs + bridge.mark_dirty() + + logger.info(f"Model switch complete: {previous_model} -> {new_model}") + + return { + "message": "Model switched", + "model": new_model, + "previousModel": previous_model, + } + + +def _emit_model_switched_event(bridge, context, new_model: str, previous_model: str): + """Push a custom AG-UI event to notify the frontend of the model switch.""" + try: + from ag_ui.core import CustomEvent, EventType + + event = CustomEvent( + type=EventType.CUSTOM, + name="ambient:model_switched", + value={ + "previousModel": previous_model, + "newModel": new_model, + }, + ) + + # Route to the between-run event queue so the frontend picks it up + session_manager = getattr(bridge, "_session_manager", None) + if session_manager: + thread_id = context.session_id if context else "" + worker = session_manager.get_existing(thread_id) + if worker: + worker._between_run_queue.put_nowait(event) + logger.info("Model switch event emitted to between-run queue") + return + + logger.warning("No active worker to emit model switch event") + except Exception as e: + logger.warning(f"Failed to emit model switch event: {e}") diff --git a/components/runners/ambient-runner/tests/test_model_endpoint.py b/components/runners/ambient-runner/tests/test_model_endpoint.py new file mode 100644 index 000000000..3f39e5fc0 --- /dev/null +++ b/components/runners/ambient-runner/tests/test_model_endpoint.py @@ -0,0 +1,167 @@ +"""Unit tests for the POST /model endpoint.""" + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from ambient_runner.endpoints.model import router + + +def _make_mock_bridge( + *, + session_id="test-session", + has_context=True, + lock_locked=False, + has_worker=True, +): + """Create a mock bridge with configurable session manager and context.""" + bridge = MagicMock() + + if has_context: + bridge.context = MagicMock() + bridge.context.session_id = session_id + else: + bridge.context = None + + # Session manager with per-thread lock + lock = asyncio.Lock() + if lock_locked: + # Simulate a locked lock by acquiring it (non-async safe for test setup) + lock._locked = True + + session_manager = MagicMock() + session_manager.get_lock.return_value = lock + + if has_worker: + worker = MagicMock() + worker._between_run_queue = asyncio.Queue() + session_manager.get_existing.return_value = worker + else: + session_manager.get_existing.return_value = None + + bridge._session_manager = session_manager + + return bridge + + +@pytest.fixture(autouse=True) +def _reset_model_change_lock(): + """Ensure the module-level _model_change_lock is released between tests.""" + from ambient_runner.endpoints import model as mod + + # Replace with a fresh lock so no test leaks state + mod._model_change_lock = asyncio.Lock() + yield + + +@pytest.fixture +def make_client(): + """Factory to create a test client with a mock bridge.""" + + def _factory(*, env_model="claude-sonnet-4-5", **bridge_kwargs): + app = FastAPI() + app.state.bridge = _make_mock_bridge(**bridge_kwargs) + app.include_router(router) + with patch.dict("os.environ", {"LLM_MODEL": env_model}): + client = TestClient(app) + return client, app.state.bridge + + return _factory + + +class TestModelEndpoint: + """Test POST /model request handling.""" + + def test_success_switches_model(self, make_client): + """POST /model with a valid new model returns 200 with model and previousModel.""" + with patch.dict("os.environ", {"LLM_MODEL": "claude-sonnet-4-5"}): + client, bridge = make_client(env_model="claude-sonnet-4-5") + resp = client.post("/model", json={"model": "claude-opus-4"}) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "Model switched" + assert data["model"] == "claude-opus-4" + assert data["previousModel"] == "claude-sonnet-4-5" + bridge.mark_dirty.assert_called_once() + + def test_empty_model_returns_400(self, make_client): + """POST /model with an empty model string returns 400.""" + client, _ = make_client() + resp = client.post("/model", json={"model": ""}) + + assert resp.status_code == 400 + assert "model is required" in resp.json()["detail"] + + def test_whitespace_only_model_returns_400(self, make_client): + """POST /model with whitespace-only model returns 400.""" + client, _ = make_client() + resp = client.post("/model", json={"model": " "}) + + assert resp.status_code == 400 + assert "model is required" in resp.json()["detail"] + + def test_missing_model_field_returns_400(self, make_client): + """POST /model with no model field in body returns 400.""" + client, _ = make_client() + resp = client.post("/model", json={}) + + assert resp.status_code == 400 + assert "model is required" in resp.json()["detail"] + + def test_same_model_returns_unchanged(self, make_client): + """POST /model with same model as current returns 200 with 'Model unchanged'.""" + with patch.dict("os.environ", {"LLM_MODEL": "claude-sonnet-4-5"}): + client, bridge = make_client(env_model="claude-sonnet-4-5") + resp = client.post("/model", json={"model": "claude-sonnet-4-5"}) + + assert resp.status_code == 200 + data = resp.json() + assert data["message"] == "Model unchanged" + assert data["model"] == "claude-sonnet-4-5" + assert "previousModel" not in data + bridge.mark_dirty.assert_not_called() + + def test_context_not_initialized_returns_503(self, make_client): + """POST /model when bridge.context is None returns 503.""" + client, _ = make_client(has_context=False) + resp = client.post("/model", json={"model": "claude-opus-4"}) + + assert resp.status_code == 503 + assert "Context not initialized" in resp.json()["detail"] + + def test_locked_run_returns_422(self, make_client): + """POST /model while agent is mid-generation returns 422.""" + client, _ = make_client(lock_locked=True) + with patch.dict("os.environ", {"LLM_MODEL": "claude-sonnet-4-5"}): + resp = client.post("/model", json={"model": "claude-opus-4"}) + + assert resp.status_code == 422 + assert "Cannot switch model" in resp.json()["detail"] + + def test_updates_env_variable(self, make_client): + """POST /model updates the LLM_MODEL environment variable.""" + with patch.dict("os.environ", {"LLM_MODEL": "claude-sonnet-4-5"}): + client, _ = make_client(env_model="claude-sonnet-4-5") + resp = client.post("/model", json={"model": "claude-opus-4"}) + import os + + assert resp.status_code == 200 + assert os.environ["LLM_MODEL"] == "claude-opus-4" + + def test_emits_event_to_worker_queue(self, make_client): + """POST /model emits a model_switched event to the between-run queue.""" + with patch.dict("os.environ", {"LLM_MODEL": "claude-sonnet-4-5"}): + client, bridge = make_client(env_model="claude-sonnet-4-5") + resp = client.post("/model", json={"model": "claude-opus-4"}) + + assert resp.status_code == 200 + worker = bridge._session_manager.get_existing.return_value + assert not worker._between_run_queue.empty() + event = worker._between_run_queue.get_nowait() + assert event.name == "ambient:model_switched" + assert event.value["newModel"] == "claude-opus-4" + assert event.value["previousModel"] == "claude-sonnet-4-5"