Skip to content

Commit 4d46ee1

Browse files
committed
Quota helpers
1 parent 8e414a2 commit 4d46ee1

1 file changed

Lines changed: 67 additions & 0 deletions

File tree

src/utils/quota.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Quota handling helper functions."""
2+
3+
import psycopg2
4+
5+
from fastapi import HTTPException, status
6+
7+
from quota.quota_limiter import QuotaLimiter
8+
9+
from log import get_logger
10+
11+
logger = get_logger(__name__)
12+
13+
14+
def consume_tokens(
15+
quota_limiters: list[QuotaLimiter],
16+
user_id: str,
17+
input_tokens: int,
18+
output_tokens: int,
19+
) -> None:
20+
"""Consume tokens from cluster and/or user quotas."""
21+
# check if any quota limiter is configured
22+
for quota_limiter in quota_limiters:
23+
quota_limiter.consume_tokens(
24+
input_tokens=input_tokens,
25+
output_tokens=output_tokens,
26+
subject_id=user_id,
27+
)
28+
29+
30+
def check_tokens_available(quota_limiters: list[QuotaLimiter], user_id: str) -> None:
31+
"""Check if tokens are available for user."""
32+
try:
33+
for quota_limiter in quota_limiters:
34+
quota_limiter.ensure_available_quota(subject_id=user_id)
35+
except psycopg2.Error as pg_error:
36+
message = "Error communicating with quota database backend"
37+
logger.error(message)
38+
raise HTTPException(
39+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
40+
detail={
41+
"response": message,
42+
"cause": str(pg_error),
43+
},
44+
) from pg_error
45+
except Exception as quota_exceed_error:
46+
message = "The quota has been exceeded"
47+
logger.error(message)
48+
raise HTTPException(
49+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
50+
detail={
51+
"response": message,
52+
"cause": str(quota_exceed_error),
53+
},
54+
) from quota_exceed_error
55+
56+
57+
def get_available_quotas(
58+
quota_limiters: list[QuotaLimiter],
59+
user_id: str,
60+
) -> dict[str, int]:
61+
"""Get quota available from all quota limiters."""
62+
available_quotas: dict[str, int] = {}
63+
for quota_limiter in quota_limiters:
64+
name = quota_limiter.__class__.__name__
65+
available_quota = quota_limiter.available_quota(user_id)
66+
available_quotas[name] = available_quota
67+
return available_quotas

0 commit comments

Comments
 (0)