Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
390 changes: 390 additions & 0 deletions packages/ai-server/scripts/cost_probe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,390 @@
#!/usr/bin/env python3
# pyright: reportMissingImports=false
"""Single raw_post 의 전체 Gemini 파이프라인 비용을 실측한다.

목적: cost-tracking 시스템 (DB pricing + per-call recorder) 을 만들기 전에
"한 raw_post 처리에 실제로 token / image / grounding query 가 얼마나
드는지" 를 ad-hoc 으로 측정. 결과로 plan 의 비용 추정치를 보강.

원리:
- prod assets DB 에서 image_url 있는 raw_post 1건 random pick.
- prod 파이프라인이 만드는 8가지 호출 타입을 1회씩 (item 단위는 N=3 cap):
1. hero_reframe gemini-2.5-flash-image image out
2. subject gemini-2.5-flash image+text
3. items gemini-2.5-pro image+text → N items
4. spots gemini-2.5-flash hero+items
5. thumbnail × N gemini-2.5-flash-image image out
6. url_grounded × N gemini-2.5-flash + googleSearch
7. url_filter × N gemini-2.5-flash + image
- 각 호출의 usage_metadata 캡처 (prompt_token_count / candidates_token_count /
cached_content_token_count, image 모델은 candidates_token_count 가 image
bytes 분 ≈ 1290).
- 끝에 합계 + Gemini 공시 단가 (스크립트 상단 상수) 로 USD 환산.

단가 상수는 *프로브 한정* (production code 는 plan 대로 DB SOT).
출처는 `https://ai.google.dev/pricing` (2026-05 기준 manual 입력).

Usage:
cd packages/ai-server
uv run python scripts/cost_probe.py
uv run python scripts/cost_probe.py --id <raw_post_uuid>
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

import asyncpg
import httpx
from google import genai
from google.genai import types as genai_types


_THIS = Path(__file__).resolve()
_AI_SERVER_ROOT = _THIS.parent.parent
sys.path.insert(0, str(_AI_SERVER_ROOT))

from src.services.raw_posts.processors.prompts import ( # noqa: E402
HERO_REFRAME_PROMPT,
ITEM_THUMBNAIL_PROMPT,
ITEMS_PROMPT,
SPOTS_PROMPT,
SUBJECT_PROMPT,
)
from src.services.raw_posts.processors.items_parser import _ItemsResponse # noqa: E402
from src.services.raw_posts.processors.spots_parser import _SpotsResponse # noqa: E402
from src.services.raw_posts.processors.subject_parser import _SubjectDraft # noqa: E402


logger = logging.getLogger("cost-probe")


# === Gemini 단가 (USD / unit) — 2026-05 ai.google.dev/pricing 기준 manual ===
# 프로브 스크립트 한정. 운영 시스템은 DB pricing 테이블 SOT.
PRICING = {
"gemini-2.5-pro": {
"input_token": 1.25 / 1_000_000, # ≤200k context
"output_token": 10.0 / 1_000_000,
"cached_input_token": 0.3125 / 1_000_000,
},
"gemini-2.5-flash": {
"input_token": 0.30 / 1_000_000,
"output_token": 2.50 / 1_000_000,
"cached_input_token": 0.075 / 1_000_000,
},
"gemini-2.5-flash-image": {
"image_output": 0.039, # per image (1024x1024)
"input_token": 0.30 / 1_000_000, # image input still tokenized
},
"grounding_query": 35.0 / 1_000, # $35 / 1000 google search queries
}

ITEM_CAP = 3 # 실제 prod 평균 ≈ 5, probe 는 3 으로 cap (비용 ↓)


@dataclass
class CallLog:
step: str
model: str
ok: bool
prompt_tokens: int = 0
completion_tokens: int = 0
cached_tokens: int = 0
image_output: int = 0
grounding_queries: int = 0
latency_ms: int = 0
err: Optional[str] = None
est_cost_usd: float = 0.0


calls: list[CallLog] = []


def _extract_usage(resp: Any) -> tuple[int, int, int]:
um = getattr(resp, "usage_metadata", None)
if not um:
return (0, 0, 0)
return (
getattr(um, "prompt_token_count", 0) or 0,
getattr(um, "candidates_token_count", 0) or 0,
getattr(um, "cached_content_token_count", 0) or 0,
)


def _price_text(model: str, prompt: int, completion: int, cached: int) -> float:
p = PRICING.get(model, {})
cost = 0.0
cost += (prompt - (cached or 0)) * p.get("input_token", 0)
cost += completion * p.get("output_token", 0)
cost += (cached or 0) * p.get("cached_input_token", p.get("input_token", 0))
return cost


async def _download(url: str) -> tuple[bytes, str]:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as c:
r = await c.get(url)
r.raise_for_status()
ct = (r.headers.get("content-type") or "image/jpeg").split(";")[0].strip()
return r.content, ct


async def _pick_sample(db_url: str, raw_id: Optional[str]) -> dict:
conn = await asyncpg.connect(db_url)
try:
if raw_id:
row = await conn.fetchrow(
"SELECT id::text AS id, image_url, caption, parse_result FROM public.raw_posts WHERE id=$1::uuid",
raw_id,
)
else:
row = await conn.fetchrow(
"""
SELECT id::text AS id, image_url, caption, parse_result
FROM public.raw_posts
WHERE image_url IS NOT NULL AND image_url != ''
AND status = 'COMPLETED'
AND parse_result IS NOT NULL
AND jsonb_typeof(parse_result -> 'items') = 'array'
AND jsonb_array_length(parse_result -> 'items') >= 3
ORDER BY random() LIMIT 1
"""
)
if not row:
raise SystemExit("no eligible raw_post found")
return dict(row)
finally:
await conn.close()


async def _call_text(
client: genai.Client,
*,
step: str,
model: str,
contents: list,
response_schema: Optional[Any] = None,
) -> Any:
cfg = genai_types.GenerateContentConfig(temperature=0.1)
if response_schema is not None:
cfg = genai_types.GenerateContentConfig(
response_mime_type="application/json",
response_schema=response_schema,
temperature=0.1,
)
t0 = time.monotonic()
log = CallLog(step=step, model=model, ok=False)
try:
resp = await client.aio.models.generate_content(model=model, contents=contents, config=cfg)
log.ok = True
p, c, ca = _extract_usage(resp)
log.prompt_tokens, log.completion_tokens, log.cached_tokens = p, c, ca
log.est_cost_usd = _price_text(model, p, c, ca)
return resp
except Exception as exc: # noqa: BLE001
log.err = f"{type(exc).__name__}: {exc}"
raise
finally:
log.latency_ms = int((time.monotonic() - t0) * 1000)
calls.append(log)


async def _call_image(
client: genai.Client,
*,
step: str,
image_bytes: bytes,
content_type: str,
prompt: str,
aspect_ratio: str,
image_size: str,
) -> Any:
cfg = genai_types.GenerateContentConfig(
response_modalities=["IMAGE"],
image_config=genai_types.ImageConfig(aspect_ratio=aspect_ratio, image_size=image_size),
)
t0 = time.monotonic()
log = CallLog(step=step, model="gemini-2.5-flash-image", ok=False)
try:
resp = await client.aio.models.generate_content(
model="gemini-2.5-flash-image",
contents=[
genai_types.Part.from_bytes(data=image_bytes, mime_type=content_type),
prompt,
],
config=cfg,
)
log.ok = True
p, c, ca = _extract_usage(resp)
log.prompt_tokens = p
log.image_output = 1
log.est_cost_usd = (
PRICING["gemini-2.5-flash-image"]["image_output"]
+ p * PRICING["gemini-2.5-flash-image"]["input_token"]
)
return resp
except Exception as exc: # noqa: BLE001
log.err = f"{type(exc).__name__}: {exc}"
raise
finally:
log.latency_ms = int((time.monotonic() - t0) * 1000)
calls.append(log)


async def _call_grounded(
api_key: str, *, step: str, brand: str, title: str, model: str
) -> dict:
prompt = f"Find official PDP URLs for fashion item: brand={brand!r}, product={title!r}. Return top 5 URLs."
payload = {
"contents": [{"role": "user", "parts": [{"text": prompt}]}],
"tools": [{"googleSearch": {}}],
"generationConfig": {"temperature": 0.2},
}
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
t0 = time.monotonic()
log = CallLog(step=step, model=model, ok=False, grounding_queries=1)
try:
async with httpx.AsyncClient(timeout=60.0) as c:
r = await c.post(url, params={"key": api_key}, json=payload)
r.raise_for_status()
data = r.json()
log.ok = True
um = (data or {}).get("usageMetadata") or {}
p = um.get("promptTokenCount", 0) or 0
cc = um.get("candidatesTokenCount", 0) or 0
log.prompt_tokens, log.completion_tokens = p, cc
log.est_cost_usd = (
_price_text(model, p, cc, 0) + PRICING["grounding_query"] * 1
)
return data
except Exception as exc: # noqa: BLE001
log.err = f"{type(exc).__name__}: {exc}"
raise
finally:
log.latency_ms = int((time.monotonic() - t0) * 1000)
calls.append(log)


async def main() -> int:
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
parser = argparse.ArgumentParser()
parser.add_argument("--id", type=str, default=None)
args = parser.parse_args()

db_url = os.environ.get("ASSETS_DATABASE_URL")
api_key = os.environ.get("GEMINI_API_KEY")
if not db_url or not api_key:
raise SystemExit("ASSETS_DATABASE_URL + GEMINI_API_KEY required (source .env.backend.prod)")

row = await _pick_sample(db_url, args.id)
logger.info("sample raw_post: id=%s image_url=%s", row["id"], row["image_url"])
image_bytes, content_type = await _download(row["image_url"])
logger.info("image: %d bytes, ct=%s", len(image_bytes), content_type)

client = genai.Client(api_key=api_key)
img_part = genai_types.Part.from_bytes(data=image_bytes, mime_type=content_type)

# === 1. hero_reframe (flash-image) ===
logger.info("[1/7] hero_reframe …")
await _call_image(
client, step="hero_reframe", image_bytes=image_bytes, content_type=content_type,
prompt=HERO_REFRAME_PROMPT, aspect_ratio="4:5", image_size="2K",
)

# === 2. subject (flash) ===
logger.info("[2/7] subject …")
caption_text = (row.get("caption") or "").strip()
subj_contents = [img_part, SUBJECT_PROMPT + (f"\n\nCaption: {caption_text[:300]}" if caption_text else "")]
await _call_text(client, step="subject", model="gemini-2.5-flash",
contents=subj_contents, response_schema=_SubjectDraft)

# === 3. items (pro) ===
logger.info("[3/7] items …")
items_resp = await _call_text(client, step="items", model="gemini-2.5-pro",
contents=[img_part, ITEMS_PROMPT],
response_schema=_ItemsResponse)
# raw image_url 은 합성 전이라 items parser 가 0 으로 돌아오는 게 정상.
# 실제 N 은 prod parse_result.items 에서 가져와 per-item 호출 시뮬레이션.
cached: list[dict] = []
pr = row.get("parse_result")
if pr:
try:
pr_obj = pr if isinstance(pr, dict) else json.loads(pr)
cached = (pr_obj or {}).get("items", []) or []
except Exception: # noqa: BLE001
pass
n_items = min(len(cached), ITEM_CAP)
logger.info("items in cached parse_result: %d (probing %d)", len(cached), n_items)
items = cached

# === 4. spots (flash) — image + items text ===
logger.info("[4/7] spots …")
items_text = json.dumps([{"brand": (it or {}).get("brand"), "product": (it or {}).get("product")} for it in items[:n_items]])
await _call_text(client, step="spots", model="gemini-2.5-flash",
contents=[img_part, SPOTS_PROMPT + "\n\nitems=" + items_text],
response_schema=_SpotsResponse)

# === 5. thumbnail × N (flash-image) ===
for i in range(n_items):
logger.info("[5/7] thumbnail %d/%d …", i + 1, n_items)
await _call_image(
client, step=f"thumbnail#{i+1}", image_bytes=image_bytes, content_type=content_type,
prompt=ITEM_THUMBNAIL_PROMPT, aspect_ratio="1:1", image_size="1K",
)

# === 6. url_search.grounded × N (flash + googleSearch) ===
for i in range(n_items):
it = items[i] or {}
brand = (it.get("brand") or "unknown brand")[:80]
product = (it.get("product") or "unknown product")[:120]
logger.info("[6/7] url_grounded %d/%d (%s · %s) …", i + 1, n_items, brand, product)
try:
await _call_grounded(api_key, step=f"url_grounded#{i+1}", brand=brand, title=product,
model="gemini-2.5-flash")
except Exception as exc: # noqa: BLE001
logger.warning("url_grounded failed: %s", exc)

# === 7. url_search.filter × N (flash + thumbnail image) ===
for i in range(n_items):
logger.info("[7/7] url_filter %d/%d …", i + 1, n_items)
await _call_text(
client, step=f"url_filter#{i+1}", model="gemini-2.5-flash",
contents=[img_part, "Evaluate top product URL candidates. Return JSON with best_url, confidence, domain_class."],
)

# === 결과 출력 ===
print()
print(f"{'step':<22}{'model':<26}{'ok':<4}{'in tok':>8}{'out tok':>9}{'img':>5}{'grnd':>6}{'lat ms':>9}{'$':>10}")
print("-" * 99)
total = {"in": 0, "out": 0, "img": 0, "grnd": 0, "ms": 0, "usd": 0.0}
for c in calls:
print(f"{c.step:<22}{c.model:<26}{('✓' if c.ok else '✗'):<4}"
f"{c.prompt_tokens:>8}{c.completion_tokens:>9}{c.image_output:>5}"
f"{c.grounding_queries:>6}{c.latency_ms:>9}{c.est_cost_usd:>10.5f}")
total["in"] += c.prompt_tokens; total["out"] += c.completion_tokens
total["img"] += c.image_output; total["grnd"] += c.grounding_queries
total["ms"] += c.latency_ms; total["usd"] += c.est_cost_usd
print("-" * 99)
print(f"{'TOTAL':<22}{'':<26}{'':<4}{total['in']:>8}{total['out']:>9}{total['img']:>5}"
f"{total['grnd']:>6}{total['ms']:>9}{total['usd']:>10.5f}")
print()
print(f"raw_post id = {row['id']}")
print(f"items detected (prod 평균 ≈ 5) = {len(items)}, probe 사용 = {n_items}")
print(f"실측 ${total['usd']:.4f} / 1 raw_post (ITEM_CAP={ITEM_CAP})")
if n_items > 0:
per_item = total['usd'] - sum(c.est_cost_usd for c in calls if not c.step.startswith(('thumbnail', 'url_')))
prod_est = (total['usd'] - per_item) + per_item / n_items * 5
print(f"prod 추정 (5 items): ${prod_est:.4f}")
return 0


if __name__ == "__main__":
sys.exit(asyncio.run(main()))
Loading
Loading