Skip to content

Commit e4b7a65

Browse files
jmc-wanderclaude
andcommitted
Add KubernetesLoRABackend for real GPU fine-tuning on GKE
Replaces LocalNoOpBackend with a production-ready backend that submits K8s Jobs to spot GPU nodes for QLoRA fine-tuning via Unsloth. Training data and GGUF artifacts pass through GCS. The GGUF is downloaded and registered with Ollama so the existing validate/promote pipeline works. Production hardening: structured logging with run_id correlation, GCS retry with exponential backoff, blob existence + disk space checks, K8s resource limits (CPU/memory/GPU), pod event diagnostics on failure, Ollama registration retry, pipeline concurrency guard, /tmp cleanup in trainer container, pinned Docker image deps, and configurable backoff. New files: - src/apprentice/kubernetes_lora_backend.py (backend + config) - docker/trainer/Dockerfile (CUDA + Unsloth + llama.cpp) - docker/trainer/train.py (self-contained training script) - tests/test_kubernetes_lora_backend.py (57 tests) Modified: - config_loader: kubernetes_lora enum + GKE fields + cross-field validation - factory: backend selection + FineTuningOrchestrator wiring - serve.py: asyncio.to_thread + orchestrator + concurrency guard - fine_tuning_orchestrator: LocalNoOpOrchestratorConfig + K8s stub - pyproject.toml: gke optional deps - examples/apprentice.yaml: kubernetes_lora config example Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 77ce8d1 commit e4b7a65

10 files changed

Lines changed: 2173 additions & 21 deletions

File tree

docker/trainer/Dockerfile

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Trainer image: CUDA toolchain + Python 3.11 + Unsloth + llama.cpp quantizer.
FROM nvidia/cuda:12.1.1-devel-ubuntu22.04

# Avoid interactive prompts during package installation
ENV DEBIAN_FRONTEND=noninteractive

# Install Python 3.11 and system dependencies.
# python3.11-dev provides Python.h: Triton (pulled in by Unsloth) compiles its
# kernel launcher stubs at runtime and fails without the interpreter headers.
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.11 \
    python3.11-dev \
    python3.11-venv \
    python3-pip \
    git \
    cmake \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Make python3.11 the default
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 \
    && update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1

# Install Python dependencies — pinned versions for reproducibility
RUN pip install --no-cache-dir \
    "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@2025.3.19" \
    "xformers==0.0.29.post3" \
    "trl>=0.7,<0.13" \
    "peft>=0.13,<0.15" \
    "accelerate>=0.34,<1.0" \
    "bitsandbytes>=0.44,<0.45" \
    "google-cloud-storage>=2.10,<3.0" \
    "sentencepiece>=0.2,<0.3" \
    "protobuf>=4.25,<6.0" \
    "datasets>=3.0,<4.0"

# Build llama.cpp for GGUF conversion — pinned to a release tag.
# The source tree stays in /opt/llama.cpp because train.py runs
# /opt/llama.cpp/convert_hf_to_gguf.py from it.
ARG LLAMA_CPP_VERSION=b4722
RUN git clone --branch ${LLAMA_CPP_VERSION} --depth 1 \
    https://github.com/ggerganov/llama.cpp /opt/llama.cpp \
    && cd /opt/llama.cpp \
    && cmake -B build -DGGML_CUDA=ON \
    && cmake --build build --config Release -j$(nproc) \
    && cp build/bin/llama-quantize /usr/local/bin/llama-quantize

# train.py reads this variable to locate the quantizer binary
ENV LLAMA_CPP_PATH=/usr/local/bin/llama-quantize

WORKDIR /app
COPY train.py /app/train.py

ENTRYPOINT ["python", "train.py"]

docker/trainer/train.py

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
"""Apprentice Trainer — runs inside the K8s training Job container.
2+
3+
Performs QLoRA fine-tuning via Unsloth, merges LoRA adapters, converts to
4+
GGUF, and uploads the artifact + metrics to GCS.
5+
6+
All configuration is passed via environment variables set by the
7+
KubernetesLoRABackend when creating the K8s Job.
8+
9+
Production features:
10+
- GCS operations with exponential backoff retry
11+
- /tmp cleanup before and after training
12+
- Validation of required tools (convert script, llama-quantize)
13+
- Structured error exit codes (1=general, 2=GCS, 3=training, 4=conversion)
14+
"""
15+
16+
import json
17+
import os
18+
import shutil
19+
import subprocess
20+
import sys
21+
import time
22+
from pathlib import Path
23+
24+
25+
# ── Exit codes ──────────────────────────────────────────────────────────────
26+
EXIT_GCS_ERROR = 2
27+
EXIT_TRAINING_ERROR = 3
28+
EXIT_CONVERSION_ERROR = 4
29+
30+
31+
# ── Retry helper ────────────────────────────────────────────────────────────
32+
33+
34+
def _retry(fn, retries=3, base_delay=2.0, operation="operation"):
35+
"""Retry a callable with exponential backoff. Returns the result or re-raises."""
36+
last_exc = None
37+
for attempt in range(1, retries + 1):
38+
try:
39+
return fn()
40+
except Exception as e:
41+
last_exc = e
42+
if attempt < retries:
43+
delay = base_delay * (2 ** (attempt - 1))
44+
print(
45+
f"[trainer] {operation} failed (attempt {attempt}/{retries}): {e}. "
46+
f"Retrying in {delay:.1f}s...",
47+
file=sys.stderr,
48+
)
49+
time.sleep(delay)
50+
else:
51+
print(
52+
f"[trainer] {operation} failed after {retries} attempts: {e}",
53+
file=sys.stderr,
54+
)
55+
raise last_exc
56+
57+
58+
# ── Cleanup helper ──────────────────────────────────────────────────────────
59+
60+
61+
def _cleanup_work_dir(work_dir: Path) -> None:
62+
"""Remove the work directory if it exists."""
63+
if work_dir.exists():
64+
try:
65+
shutil.rmtree(str(work_dir))
66+
print(f"[trainer] Cleaned up {work_dir}")
67+
except Exception as e:
68+
print(f"[trainer] Warning: cleanup of {work_dir} failed: {e}", file=sys.stderr)
69+
70+
71+
# ── Main ────────────────────────────────────────────────────────────────────
72+
73+
74+
def main() -> None:
    """Run one fine-tuning job end to end, exiting with a stage-specific code.

    Reads configuration from environment variables, downloads the JSONL
    training set from GCS, fine-tunes the base model with LoRA adapters via
    Unsloth, merges the adapters, converts the result to a quantized GGUF,
    and uploads the GGUF plus a metrics JSON back to GCS. The work directory
    under /tmp is removed before the run and again on every exit path.

    Environment variables:
        GCS_BUCKET, GCS_PREFIX, RUN_ID: required, must be non-empty.
        BASE_MODEL, QUANTIZATION_TYPE, MAX_SEQ_LENGTH, LORA_RANK,
        LEARNING_RATE, NUM_EPOCHS, LLAMA_CPP_PATH: optional, with defaults.

    Raises:
        SystemExit: with code 1 for missing configuration or a failure while
            writing metrics, EXIT_GCS_ERROR for download/upload failures and
            empty training data, EXIT_TRAINING_ERROR for model loading,
            dataset preparation or training failures, and
            EXIT_CONVERSION_ERROR for missing tools, merge, or GGUF
            conversion failures.
        ValueError: if a numeric environment variable cannot be parsed.
    """
    import traceback

    # ── Parse environment variables ──────────────────────────────────────
    required_vars = ["GCS_BUCKET", "GCS_PREFIX", "RUN_ID"]
    for var in required_vars:
        # An empty value is as unusable as a missing one (it would produce
        # blob paths such as "//data.jsonl"), so reject both.
        if not os.environ.get(var):
            print(f"[trainer] FATAL: Required environment variable {var} is not set", file=sys.stderr)
            sys.exit(1)

    gcs_bucket = os.environ["GCS_BUCKET"]
    gcs_prefix = os.environ["GCS_PREFIX"]
    run_id = os.environ["RUN_ID"]
    base_model = os.environ.get("BASE_MODEL", "unsloth/llama-3.1-8b-bnb-4bit")
    quantization_type = os.environ.get("QUANTIZATION_TYPE", "Q4_K_M")
    max_seq_length = int(os.environ.get("MAX_SEQ_LENGTH", "2048"))
    lora_rank = int(os.environ.get("LORA_RANK", "16"))
    learning_rate = float(os.environ.get("LEARNING_RATE", "2e-4"))
    num_epochs = int(os.environ.get("NUM_EPOCHS", "3"))

    work_dir = Path("/tmp/training")

    # Clean up any leftover state from a previous run (spot preemption retry)
    _cleanup_work_dir(work_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    data_path = work_dir / "data.jsonl"
    model_dir = work_dir / "model"
    merged_dir = work_dir / "merged"
    gguf_path = work_dir / "model.gguf"
    metrics_path = work_dir / "metrics.json"

    gcs_data_blob = f"{gcs_prefix}/{run_id}/data.jsonl"
    gcs_gguf_blob = f"{gcs_prefix}/{run_id}/model.gguf"
    gcs_metrics_blob = f"{gcs_prefix}/{run_id}/metrics.json"

    print(f"[trainer] Starting run {run_id}")
    print(f"[trainer] Base model: {base_model}")
    print(f"[trainer] Quantization: {quantization_type}")
    print(f"[trainer] LoRA rank: {lora_rank}, LR: {learning_rate}, Epochs: {num_epochs}")

    # ── Validate required tools ──────────────────────────────────────────
    convert_script = Path("/opt/llama.cpp/convert_hf_to_gguf.py")
    llama_quantize = os.environ.get("LLAMA_CPP_PATH", "/usr/local/bin/llama-quantize")

    if not convert_script.exists():
        print(f"[trainer] FATAL: Convert script not found at {convert_script}", file=sys.stderr)
        sys.exit(EXIT_CONVERSION_ERROR)

    if not Path(llama_quantize).exists():
        print(f"[trainer] FATAL: llama-quantize not found at {llama_quantize}", file=sys.stderr)
        sys.exit(EXIT_CONVERSION_ERROR)

    train_start = time.time()

    # The current stage decides which exit code an unexpected exception maps
    # to, so the process exit status identifies the failing stage.
    stage, stage_exit_code = "GCS download", EXIT_GCS_ERROR

    try:
        # ── 1. Download training data from GCS ───────────────────────────
        print("[trainer] Downloading training data from GCS...")
        from google.cloud import storage

        gcs_client = _retry(
            lambda: storage.Client(),
            retries=3,
            operation="GCS client init",
        )
        bucket = gcs_client.bucket(gcs_bucket)
        blob = bucket.blob(gcs_data_blob)

        _retry(
            lambda: blob.download_to_filename(str(data_path)),
            retries=3,
            base_delay=5.0,
            operation=f"GCS download gs://{gcs_bucket}/{gcs_data_blob}",
        )

        # Parse JSONL into dataset; decode explicitly as UTF-8 so the result
        # does not depend on the container's locale.
        with open(data_path, encoding="utf-8") as f:
            examples = [json.loads(row) for row in map(str.strip, f) if row]

        if not examples:
            print("[trainer] FATAL: Training data file is empty", file=sys.stderr)
            sys.exit(EXIT_GCS_ERROR)

        print(f"[trainer] Loaded {len(examples)} training examples")

        # ── 2. Load model and tokenizer ──────────────────────────────────
        stage, stage_exit_code = "training", EXIT_TRAINING_ERROR
        print("[trainer] Loading model and tokenizer...")
        from unsloth import FastLanguageModel, is_bfloat16_supported

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=max_seq_length,
            load_in_4bit=True,
        )

        # ── 3. Apply LoRA adapters ───────────────────────────────────────
        print("[trainer] Applying LoRA adapters...")
        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_rank,
            lora_alpha=lora_rank * 2,
            lora_dropout=0,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            bias="none",
            use_gradient_checkpointing="unsloth",
        )

        # ── 4. Prepare dataset ───────────────────────────────────────────
        from datasets import Dataset

        def format_chat(example: dict) -> dict:
            text = tokenizer.apply_chat_template(
                example["messages"], tokenize=False, add_generation_prompt=False,
            )
            return {"text": text}

        dataset = Dataset.from_list(examples)
        dataset = dataset.map(format_chat)

        # ── 5. Train with SFTTrainer ─────────────────────────────────────
        print("[trainer] Starting training...")
        from trl import SFTTrainer
        from transformers import TrainingArguments

        # Choose mixed precision from the GPU's capabilities instead of
        # hard-coding fp16: Unsloth rejects fp16=True when the model was
        # loaded in bfloat16, which happens on bf16-capable GPUs.
        use_bf16 = is_bfloat16_supported()

        training_args = TrainingArguments(
            output_dir=str(model_dir),
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=not use_bf16,
            bf16=use_bf16,
            logging_steps=10,
            save_strategy="no",
            optim="adamw_8bit",
        )

        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            args=training_args,
            dataset_text_field="text",
            max_seq_length=max_seq_length,
            packing=False,
        )

        train_result = trainer.train()
        train_loss = train_result.training_loss
        train_steps = train_result.global_step
        print(f"[trainer] Training complete. Loss: {train_loss:.4f}, Steps: {train_steps}")

        # ── 6. Merge LoRA adapters ───────────────────────────────────────
        stage, stage_exit_code = "GGUF conversion", EXIT_CONVERSION_ERROR
        print("[trainer] Merging LoRA adapters...")
        merged_dir.mkdir(parents=True, exist_ok=True)
        model.save_pretrained_merged(str(merged_dir), tokenizer, save_method="merged_16bit")

        # ── 7. Convert to GGUF ───────────────────────────────────────────
        print(f"[trainer] Converting to GGUF ({quantization_type})...")

        # First convert HF model to f16 GGUF. Run the converter with the
        # interpreter executing this script so it sees the same packages.
        f16_gguf = work_dir / "model-f16.gguf"
        convert_result = subprocess.run(
            [sys.executable, str(convert_script), str(merged_dir),
             "--outfile", str(f16_gguf), "--outtype", "f16"],
            check=False,
            capture_output=True,
            text=True,
        )
        if convert_result.returncode != 0:
            print(f"[trainer] FATAL: HF-to-GGUF conversion failed:\n{convert_result.stderr}", file=sys.stderr)
            sys.exit(EXIT_CONVERSION_ERROR)

        # Then quantize
        quant_result = subprocess.run(
            [llama_quantize, str(f16_gguf), str(gguf_path), quantization_type],
            check=False,
            capture_output=True,
            text=True,
        )
        if quant_result.returncode != 0:
            print(f"[trainer] FATAL: Quantization failed:\n{quant_result.stderr}", file=sys.stderr)
            sys.exit(EXIT_CONVERSION_ERROR)

        gguf_size = gguf_path.stat().st_size
        print(f"[trainer] GGUF created: {gguf_size / 1024 / 1024:.1f} MB")

        # Clean up intermediate files to free disk space before upload
        if f16_gguf.exists():
            f16_gguf.unlink()
        if merged_dir.exists():
            shutil.rmtree(str(merged_dir))
        print("[trainer] Cleaned up intermediate files")

        # ── 8. Write metrics ─────────────────────────────────────────────
        stage, stage_exit_code = "metrics write", 1
        train_duration = time.time() - train_start
        metrics = {
            "final_loss": train_loss,
            "num_steps": train_steps,
            "num_epochs_completed": float(num_epochs),
            "training_duration_seconds": train_duration,
            "additional_metrics": {
                "gguf_size_bytes": gguf_size,
                "quantization_type": quantization_type,
                "lora_rank": lora_rank,
                "learning_rate": learning_rate,
                "num_examples": len(examples),
                "base_model": base_model,
            },
        }
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)

        # ── 9. Upload GGUF + metrics to GCS ──────────────────────────────
        stage, stage_exit_code = "GCS upload", EXIT_GCS_ERROR
        print("[trainer] Uploading GGUF to GCS...")
        gguf_blob = bucket.blob(gcs_gguf_blob)
        _retry(
            lambda: gguf_blob.upload_from_filename(str(gguf_path)),
            retries=3,
            base_delay=10.0,
            operation=f"GCS upload GGUF ({gguf_size / 1024 / 1024:.1f} MB)",
        )

        print("[trainer] Uploading metrics to GCS...")
        metrics_blob = bucket.blob(gcs_metrics_blob)
        _retry(
            lambda: metrics_blob.upload_from_filename(str(metrics_path)),
            retries=3,
            base_delay=5.0,
            operation="GCS upload metrics",
        )

        print(f"[trainer] Done. Artifacts at gs://{gcs_bucket}/{gcs_prefix}/{run_id}/")
        print(f"[trainer] Total duration: {train_duration:.0f}s")

    except Exception as e:
        # sys.exit() calls above raise SystemExit, which is not an Exception
        # subclass and therefore passes through with its own code.
        traceback.print_exc()
        print(f"[trainer] FATAL: {stage} failed: {type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(stage_exit_code)

    finally:
        # Always clean up /tmp to avoid disk pressure on shared nodes
        _cleanup_work_dir(work_dir)
317+
318+
319+
if __name__ == "__main__":
    # Script entry point. SystemExit raised inside main() is not an
    # Exception subclass, so it propagates untouched with its own exit code;
    # anything else is reported on stderr and mapped to the general code 1.
    try:
        main()
    except Exception as err:
        print(f"[trainer] FATAL: {type(err).__name__}: {err}", file=sys.stderr)
        sys.exit(1)

examples/apprentice.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ finetuning:
7272
output_dir: .apprentice/models/
7373
max_concurrent_jobs: 1
7474

75+
# ── Kubernetes LoRA backend (uncomment to enable) ──
76+
# backend: kubernetes_lora
77+
# gcs_bucket: "my-project-apprentice-training"
78+
# training_image: "gcr.io/my-project/apprentice-trainer:latest"
79+
# gpu_type: "nvidia-tesla-t4" # or nvidia-l4
80+
# k8s_namespace: "default"
81+
# service_account: "apprentice-trainer" # K8s SA with GCS access
82+
7583
audit:
7684
log_path: .apprentice/audit.log
7785
log_level: INFO

0 commit comments

Comments
 (0)