Skip to content

Commit 5411c84

Browse files
committed
Move limit middlewares from splunklib.ai.hooks to splunklib.ai.limits
Also changed the TokenLimitExceededException to accept an int, instead of a float.
1 parent cbdb570 commit 5411c84

7 files changed

Lines changed: 180 additions & 167 deletions

File tree

splunklib/ai/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ class. The default for that limit is suppressed automatically - the other defaul
958958
remain active:
959959

960960
```py
961-
from splunklib.ai.hooks import TokenLimitMiddleware, StepLimitMiddleware, TimeoutLimitMiddleware
961+
from splunklib.ai.limits import TokenLimitMiddleware, StepLimitMiddleware, TimeoutLimitMiddleware
962962

963963
async with Agent(
964964
...,

splunklib/ai/base_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pydantic import BaseModel
2222

2323
from splunklib.ai.conversation_store import ConversationStore
24-
from splunklib.ai.hooks import (
24+
from splunklib.ai.limits import (
2525
DEFAULT_STEP_LIMIT,
2626
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
2727
DEFAULT_TIMEOUT_SECONDS,

splunklib/ai/hooks.py

Lines changed: 0 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import inspect
22
from collections.abc import Awaitable, Callable
3-
from time import monotonic
43
from typing import Any, override
54

65
from splunklib.ai.messages import AgentResponse
@@ -12,44 +11,6 @@
1211
ModelRequest,
1312
ModelResponse,
1413
)
15-
from splunklib.ai.structured_output import StructuredOutputGenerationException
16-
17-
DEFAULT_TIMEOUT_SECONDS: float = 600.0
18-
DEFAULT_STEP_LIMIT: int = 100
19-
DEFAULT_TOKEN_LIMIT: int = 200_000
20-
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3
21-
22-
23-
class AgentStopException(Exception):
24-
"""Custom exception to indicate conversation stopping conditions."""
25-
26-
27-
class TokenLimitExceededException(AgentStopException):
28-
"""Raised by `Agent.invoke`, when token limit exceeds"""
29-
30-
def __init__(self, token_limit: float) -> None:
31-
super().__init__(f"Token limit of {token_limit} exceeded.")
32-
33-
34-
class StepsLimitExceededException(AgentStopException):
35-
"""Raised by `Agent.invoke`, when steps limit exceeds"""
36-
37-
def __init__(self, steps_limit: int) -> None:
38-
super().__init__(f"Steps limit of {steps_limit} exceeded.")
39-
40-
41-
class TimeoutExceededException(AgentStopException):
42-
"""Raised by `Agent.invoke`, when timeout exceeds"""
43-
44-
def __init__(self, timeout_seconds: float) -> None:
45-
super().__init__(f"Timed out after {timeout_seconds} seconds.")
46-
47-
48-
class StructuredOutputRetryLimitExceededException(AgentStopException):
49-
"""Raised by `Agent.invoke`, when structured output retry limit exceeds"""
50-
51-
def __init__(self, retry_count: int) -> None:
52-
super().__init__(f"Structured output retry limit of {retry_count} exceeded")
5314

5415

5516
def before_model(
@@ -132,123 +93,3 @@ async def agent_middleware(
13293
return handler_response
13394

13495
return _Middleware()
135-
136-
137-
class TokenLimitMiddleware(AgentMiddleware):
138-
"""Stops agent execution when the token count of messages passed to the model exceeds the given limit."""
139-
140-
_limit: int
141-
142-
def __init__(self, limit: int) -> None:
143-
self._limit = limit
144-
145-
@override
146-
async def model_middleware(
147-
self,
148-
request: ModelRequest,
149-
handler: ModelMiddlewareHandler,
150-
) -> ModelResponse:
151-
if request.state.token_count >= self._limit:
152-
raise TokenLimitExceededException(token_limit=self._limit)
153-
return await handler(request)
154-
155-
156-
class StepLimitMiddleware(AgentMiddleware):
157-
"""Stops agent execution when the number of steps taken reaches the given limit."""
158-
159-
_limit: int
160-
161-
def __init__(self, limit: int) -> None:
162-
self._limit = limit
163-
164-
@override
165-
async def model_middleware(
166-
self,
167-
request: ModelRequest,
168-
handler: ModelMiddlewareHandler,
169-
) -> ModelResponse:
170-
if request.state.total_steps >= self._limit:
171-
raise StepsLimitExceededException(steps_limit=self._limit)
172-
return await handler(request)
173-
174-
175-
class TimeoutLimitMiddleware(AgentMiddleware):
176-
"""Stops agent execution when wall-clock time within an invoke exceeds the given seconds.
177-
178-
The deadline resets on every invoke call - it measures time from the start of
179-
each invocation, not from agent construction.
180-
181-
Do not share instances between agents.
182-
"""
183-
184-
_seconds: float
185-
_deadline_per_thread_id: dict[str, float]
186-
187-
def __init__(self, seconds: float) -> None:
188-
self._seconds = seconds
189-
self._deadline_per_thread_id = {}
190-
191-
@override
192-
async def agent_middleware(
193-
self,
194-
request: AgentRequest,
195-
handler: AgentMiddlewareHandler,
196-
) -> AgentResponse[Any | None]:
197-
try:
198-
# Agent loop starting.
199-
self._deadline_per_thread_id[request.thread_id] = (
200-
monotonic() + self._seconds
201-
)
202-
return await handler(request)
203-
finally:
204-
del self._deadline_per_thread_id[request.thread_id] # don't leak memory
205-
206-
@override
207-
async def model_middleware(
208-
self,
209-
request: ModelRequest,
210-
handler: ModelMiddlewareHandler,
211-
) -> ModelResponse:
212-
if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]:
213-
raise TimeoutExceededException(timeout_seconds=self._seconds)
214-
return await handler(request)
215-
216-
217-
class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
218-
"""Stops agent execution when the agent exceeds structured output
219-
retry limit during a single agent loop invocation.
220-
"""
221-
222-
_limit: int
223-
_retries_per_thread_id: dict[str, int]
224-
225-
def __init__(self, limit: int) -> None:
226-
self._limit = limit
227-
self._retries_per_thread_id = {}
228-
229-
@override
230-
async def agent_middleware(
231-
self,
232-
request: AgentRequest,
233-
handler: AgentMiddlewareHandler,
234-
) -> AgentResponse[Any | None]:
235-
try:
236-
# Agent loop starting.
237-
self._retries_per_thread_id[request.thread_id] = 0
238-
return await handler(request)
239-
finally:
240-
del self._retries_per_thread_id[request.thread_id] # don't leak memory
241-
242-
@override
243-
async def model_middleware(
244-
self,
245-
request: ModelRequest,
246-
handler: ModelMiddlewareHandler,
247-
) -> ModelResponse:
248-
try:
249-
return await handler(request)
250-
except StructuredOutputGenerationException:
251-
self._retries_per_thread_id[request.state.thread_id] += 1
252-
if self._retries_per_thread_id[request.state.thread_id] > self._limit:
253-
raise StructuredOutputRetryLimitExceededException(self._limit)
254-
raise # re-raise, to retry structured output generation

splunklib/ai/limits.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from time import monotonic
2+
from typing import Any, override
3+
4+
from splunklib.ai.messages import AgentResponse
5+
from splunklib.ai.middleware import (
6+
AgentMiddleware,
7+
AgentMiddlewareHandler,
8+
AgentRequest,
9+
ModelMiddlewareHandler,
10+
ModelRequest,
11+
ModelResponse,
12+
)
13+
from splunklib.ai.structured_output import StructuredOutputGenerationException
14+
15+
DEFAULT_TIMEOUT_SECONDS: float = 600.0
16+
DEFAULT_STEP_LIMIT: int = 100
17+
DEFAULT_TOKEN_LIMIT: int = 200_000
18+
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3
19+
20+
21+
class AgentStopException(Exception):
22+
"""Custom exception to indicate conversation stopping conditions."""
23+
24+
25+
class TokenLimitExceededException(AgentStopException):
26+
"""Raised by `Agent.invoke`, when token limit exceeds"""
27+
28+
def __init__(self, token_limit: int) -> None:
29+
super().__init__(f"Token limit of {token_limit} exceeded.")
30+
31+
32+
class StepsLimitExceededException(AgentStopException):
33+
"""Raised by `Agent.invoke`, when steps limit exceeds"""
34+
35+
def __init__(self, steps_limit: int) -> None:
36+
super().__init__(f"Steps limit of {steps_limit} exceeded.")
37+
38+
39+
class TimeoutExceededException(AgentStopException):
40+
"""Raised by `Agent.invoke`, when timeout exceeds"""
41+
42+
def __init__(self, timeout_seconds: float) -> None:
43+
super().__init__(f"Timed out after {timeout_seconds} seconds.")
44+
45+
46+
class StructuredOutputRetryLimitExceededException(AgentStopException):
47+
"""Raised by `Agent.invoke`, when structured output retry limit exceeds"""
48+
49+
def __init__(self, retry_count: int) -> None:
50+
super().__init__(f"Structured output retry limit of {retry_count} exceeded")
51+
52+
53+
class TokenLimitMiddleware(AgentMiddleware):
54+
"""Stops agent execution when the token count of messages passed to the model exceeds the given limit."""
55+
56+
_limit: int
57+
58+
def __init__(self, limit: int) -> None:
59+
self._limit = limit
60+
61+
@override
62+
async def model_middleware(
63+
self,
64+
request: ModelRequest,
65+
handler: ModelMiddlewareHandler,
66+
) -> ModelResponse:
67+
if request.state.token_count >= self._limit:
68+
raise TokenLimitExceededException(token_limit=self._limit)
69+
return await handler(request)
70+
71+
72+
class StepLimitMiddleware(AgentMiddleware):
73+
"""Stops agent execution when the number of steps taken reaches the given limit."""
74+
75+
_limit: int
76+
77+
def __init__(self, limit: int) -> None:
78+
self._limit = limit
79+
80+
@override
81+
async def model_middleware(
82+
self,
83+
request: ModelRequest,
84+
handler: ModelMiddlewareHandler,
85+
) -> ModelResponse:
86+
if request.state.total_steps >= self._limit:
87+
raise StepsLimitExceededException(steps_limit=self._limit)
88+
return await handler(request)
89+
90+
91+
class TimeoutLimitMiddleware(AgentMiddleware):
92+
"""Stops agent execution when wall-clock time within an invoke exceeds the given seconds.
93+
94+
The deadline resets on every invoke call - it measures time from the start of
95+
each invocation, not from agent construction.
96+
97+
Do not share instances between agents.
98+
"""
99+
100+
_seconds: float
101+
_deadline_per_thread_id: dict[str, float]
102+
103+
def __init__(self, seconds: float) -> None:
104+
self._seconds = seconds
105+
self._deadline_per_thread_id = {}
106+
107+
@override
108+
async def agent_middleware(
109+
self,
110+
request: AgentRequest,
111+
handler: AgentMiddlewareHandler,
112+
) -> AgentResponse[Any | None]:
113+
try:
114+
# Agent loop starting.
115+
self._deadline_per_thread_id[request.thread_id] = (
116+
monotonic() + self._seconds
117+
)
118+
return await handler(request)
119+
finally:
120+
del self._deadline_per_thread_id[request.thread_id] # don't leak memory
121+
122+
@override
123+
async def model_middleware(
124+
self,
125+
request: ModelRequest,
126+
handler: ModelMiddlewareHandler,
127+
) -> ModelResponse:
128+
if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]:
129+
raise TimeoutExceededException(timeout_seconds=self._seconds)
130+
return await handler(request)
131+
132+
133+
class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
134+
"""Stops agent execution when the agent exceeds structured output
135+
retry limit during a single agent loop invocation.
136+
"""
137+
138+
_limit: int
139+
_retries_per_thread_id: dict[str, int]
140+
141+
def __init__(self, limit: int) -> None:
142+
self._limit = limit
143+
self._retries_per_thread_id = {}
144+
145+
@override
146+
async def agent_middleware(
147+
self,
148+
request: AgentRequest,
149+
handler: AgentMiddlewareHandler,
150+
) -> AgentResponse[Any | None]:
151+
try:
152+
# Agent loop starting.
153+
self._retries_per_thread_id[request.thread_id] = 0
154+
return await handler(request)
155+
finally:
156+
del self._retries_per_thread_id[request.thread_id] # don't leak memory
157+
158+
@override
159+
async def model_middleware(
160+
self,
161+
request: ModelRequest,
162+
handler: ModelMiddlewareHandler,
163+
) -> ModelResponse:
164+
try:
165+
return await handler(request)
166+
except StructuredOutputGenerationException:
167+
self._retries_per_thread_id[request.state.thread_id] += 1
168+
if self._retries_per_thread_id[request.state.thread_id] > self._limit:
169+
raise StructuredOutputRetryLimitExceededException(self._limit)
170+
raise # re-raise, to retry structured output generation

tests/integration/ai/test_hooks.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@
1818
from splunklib.ai import Agent
1919
from splunklib.ai.conversation_store import InMemoryStore
2020
from splunklib.ai.hooks import (
21+
after_agent,
22+
after_model,
23+
before_agent,
24+
before_model,
25+
)
26+
from splunklib.ai.limits import (
2127
StepLimitMiddleware,
2228
StepsLimitExceededException,
2329
TimeoutExceededException,
2430
TimeoutLimitMiddleware,
2531
TokenLimitExceededException,
2632
TokenLimitMiddleware,
27-
after_agent,
28-
after_model,
29-
before_agent,
30-
before_model,
3133
)
3234
from splunklib.ai.messages import AgentResponse, AIMessage, HumanMessage
3335
from splunklib.ai.middleware import (

tests/integration/ai/test_structured_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pydantic.dataclasses import dataclass
2222

2323
from splunklib.ai import Agent
24-
from splunklib.ai.hooks import (
24+
from splunklib.ai.limits import (
2525
StructuredOutputRetryLimitExceededException,
2626
StructuredOutputRetryLimitMiddleware,
2727
)

0 commit comments

Comments
 (0)