127 changes: 113 additions & 14 deletions rfdetr/config.py
@@ -4,14 +4,39 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from pydantic import BaseModel, field_validator, model_validator, Field
from pydantic_core.core_schema import ValidationInfo # for field_validator(info)
from typing import List, Optional, Literal
import os, torch

from pydantic import BaseModel
from typing import List, Optional, Literal, Type
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# centralize all supported encoder names (add dinov3).
EncoderName = Literal[
"dinov2_windowed_small",
"dinov2_windowed_base",
"dinov3_small",
"dinov3_base",
"dinov3_large",
]

def _encoder_default():
"""Default encoder name for the model config."""
# default to v2 unless explicitly overridden by env
val = os.getenv("RFD_ENCODER", "").strip() or "dinov2_windowed_small"

# guardrail: only accept known names
allowed = {
"dinov2_windowed_small","dinov2_windowed_base",
"dinov3_small","dinov3_base","dinov3_large"
}
return val if val in allowed else "dinov2_windowed_small"
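
A quick illustration of the env-driven default (a sketch, not part of the diff; it assumes `_encoder_default` and `EncoderName` from this module are in scope). It also shows how the hand-written `allowed` set could be derived from the `EncoderName` Literal via `typing.get_args`, keeping the two in sync automatically:

```python
import os
from typing import get_args

# A known name from the env wins over the built-in default.
os.environ["RFD_ENCODER"] = "dinov3_base"
assert _encoder_default() == "dinov3_base"

# Unknown names fall back to the v2 default instead of raising.
os.environ["RFD_ENCODER"] = "dinov3_gigantic"
assert _encoder_default() == "dinov2_windowed_small"

# No env override: plain v2 default.
os.environ.pop("RFD_ENCODER", None)
assert _encoder_default() == "dinov2_windowed_small"

# Possible refactor: derive the allowed set from the Literal instead of hand-listing it.
assert set(get_args(EncoderName)) == {
    "dinov2_windowed_small", "dinov2_windowed_base",
    "dinov3_small", "dinov3_base", "dinov3_large",
}
```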

class ModelConfig(BaseModel):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
"""Base configuration for RF-DETR models."""
# WAS: only dinov2_windowed_*; NOW: include dinov3_* as drop-in options
encoder: EncoderName = _encoder_default()

out_feature_indexes: List[int]
dec_layers: int
two_stage: bool = True
@@ -33,39 +58,109 @@ class ModelConfig(BaseModel):
group_detr: int = 13
gradient_checkpointing: bool = False
positional_encoding_size: int
# used only when encoder startswith("dinov3")
dinov3_prefer_hf: bool = True # try HF first, then hub fallback
# Declared after dinov3_prefer_hf, and with validate_default=True, so the env-fallback
# validator below sees the flag in info.data and also runs when the fields are left unset.
dinov3_repo_dir: Optional[str] = Field(default=None, validate_default=True) # e.g., r"D:\repos\dinov3"
dinov3_weights_path: Optional[str] = Field(default=None, validate_default=True) # e.g., r"C:\models\dinov3-vitb16.pth"
dinov3_hf_token: Optional[str] = None # or rely on HUGGINGFACE_HUB_TOKEN

# force /16 for v3
@field_validator("patch_size", mode="after")
def _coerce_patch_for_dinov3(cls, v, info: ValidationInfo):
"""Ensure patch size is 16 for DINOv3 encoders."""
enc = str(info.data.get("encoder", ""))
return 16 if enc.startswith("dinov3") else v

# keep pos-encoding grid consistent with resolution / patch
@field_validator("positional_encoding_size", mode="after")
def _sync_pos_enc_with_resolution(cls, v, info: ValidationInfo):
"""Sync positional encoding size with resolution and patch size."""
values = info.data or {}
res, ps = values.get("resolution"), values.get("patch_size")
return max(1, res // ps) if (res and ps) else v

# env fallbacks for local repo/weights when *not* preferring HF
@field_validator("dinov3_repo_dir", "dinov3_weights_path", mode="after")
def _fallback_to_env(cls, v, info: ValidationInfo):
"""Fallback to environment variables if not set."""
values = info.data or {}
if (not v) and str(values.get("encoder","")).startswith("dinov3") and not values.get("dinov3_prefer_hf", True):
env_map = {"dinov3_repo_dir": "DINOV3_REPO", "dinov3_weights_path": "DINOV3_WEIGHTS"}
env_key = env_map[info.field_name]
return os.getenv(env_key, v)
return v

# neutralize windowing for v3 (avoid accidental asserts downstream)
@field_validator("num_windows", mode="after")
def _neutralize_windows_for_dinov3(cls, v, info: ValidationInfo):
"""Neutralize windowing for DINOv3 encoders."""
enc = str((info.data or {}).get("encoder",""))
return 1 if enc.startswith("dinov3") else v

# auto-fit out_feature_indexes to avoid projector shape mismatches
@field_validator("out_feature_indexes", mode="after")
def _coerce_out_feats_for_backbone(cls, v, info: ValidationInfo):
"""Ensure out_feature_indexes are compatible with the encoder."""
enc = str((info.data or {}).get("encoder",""))
if enc.startswith("dinov3"):
# DINOv3 path: default to fewer, stable high-level features
return v if len(v) == 2 else [8, 11]
return v

# Final safety net: once the whole model is built, normalize settings for DINOv3.
@model_validator(mode="after")
def _final_autofix_for_dinov3(self):
"""Final adjustments after model construction."""
enc = str(self.encoder)
if enc.startswith("dinov3"):
# enforce /16 patch + matching pos-enc grid
self.patch_size = 16
if self.resolution:
self.positional_encoding_size = max(1, self.resolution // self.patch_size)
# windowing is a no-op for v3
self.num_windows = 1
# most important: use 2 high-level features to match projector weights across v2/v3
if len(self.out_feature_indexes) != 2:
self.out_feature_indexes = [8, 11]
return self

class RFDETRBaseConfig(ModelConfig):
"""
The configuration for an RF-DETR Base model.
"""
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
# Allow choosing dinov3_* without changing call sites
encoder: EncoderName = _encoder_default()
hidden_dim: int = 256
patch_size: int = 14
num_windows: int = 4
patch_size: int = 14 # will auto-become 16 if encoder startswith("dinov3")
num_windows: int = 4 # ignored by DINOv3 branch
dec_layers: int = 3
sa_nheads: int = 8
ca_nheads: int = 16
dec_n_points: int = 2
num_queries: int = 300
num_select: int = 300
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
out_feature_indexes: List[int] = [2, 5, 8, 11]
out_feature_indexes: List[int] = [2, 4, 5, 9]
pretrain_weights: Optional[str] = "rf-detr-base.pth"
resolution: int = 560
positional_encoding_size: int = 37
#resolution: int = 504 # 560//16=35 when dinov3_* is used
resolution: int = 512 # 512//16=32 → pos grid auto=32 for both v2/v3
positional_encoding_size: int = 36 # will auto-sync to resolution//patch_size
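
As a reference for reviewers, a minimal check (not part of this PR) of what the validators above guarantee once a DINOv3 encoder is selected. It assumes `RFDETRBaseConfig` can be instantiated with its defaults as in the existing code base, and it relies only on the `model_validator(mode="after")` safety net in `ModelConfig`, so it holds regardless of field-validation order:

```python
from rfdetr.config import RFDETRBaseConfig

cfg = RFDETRBaseConfig(encoder="dinov3_base")

assert cfg.patch_size == 16                                   # forced to /16 for DINOv3
assert cfg.num_windows == 1                                   # windowing is a no-op for v3
assert cfg.positional_encoding_size == cfg.resolution // 16   # 512 // 16 == 32
assert len(cfg.out_feature_indexes) == 2                      # e.g. [8, 11], matching the projector

# The DINOv2 defaults are left essentially untouched.
base = RFDETRBaseConfig()
assert base.patch_size == 14 and base.num_windows == 4
```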


class RFDETRLargeConfig(RFDETRBaseConfig):
"""
The configuration for an RF-DETR Large model.
"""
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
encoder: EncoderName = "dinov2_windowed_base"
hidden_dim: int = 384
sa_nheads: int = 12
ca_nheads: int = 24
dec_n_points: int = 4
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
pretrain_weights: Optional[str] = "rf-detr-large.pth"


class RFDETRNanoConfig(RFDETRBaseConfig):
"""
The configuration for an RF-DETR Nano model.
@@ -74,10 +169,11 @@ class RFDETRNanoConfig(RFDETRBaseConfig):
num_windows: int = 2
dec_layers: int = 2
patch_size: int = 16
resolution: int = 384
resolution: int = 384 # 384//16=24 → pos grid auto=24 for both v2/v3
positional_encoding_size: int = 24
pretrain_weights: Optional[str] = "rf-detr-nano.pth"


class RFDETRSmallConfig(RFDETRBaseConfig):
"""
The configuration for an RF-DETR Small model.
@@ -86,10 +182,11 @@ class RFDETRSmallConfig(RFDETRBaseConfig):
num_windows: int = 2
dec_layers: int = 3
patch_size: int = 16
resolution: int = 512
resolution: int = 512 # 512//16=32 → pos grid auto=32
positional_encoding_size: int = 32
pretrain_weights: Optional[str] = "rf-detr-small.pth"


class RFDETRMediumConfig(RFDETRBaseConfig):
"""
The configuration for an RF-DETR Medium model.
@@ -98,10 +195,12 @@ class RFDETRMediumConfig(RFDETRBaseConfig):
num_windows: int = 2
dec_layers: int = 4
patch_size: int = 16
resolution: int = 576
#resolution: int = 504 # 576//16=36 → pos grid auto=36
resolution: int = 512
positional_encoding_size: int = 36
pretrain_weights: Optional[str] = "rf-detr-medium.pth"


class TrainConfig(BaseModel):
lr: float = 1e-4
lr_encoder: float = 1.5e-4
43 changes: 25 additions & 18 deletions rfdetr/engine.py
@@ -21,7 +21,7 @@
import sys
from typing import Iterable
import random

from contextlib import nullcontext
import torch
import torch.nn.functional as F

@@ -39,12 +39,16 @@
from rfdetr.util.misc import NestedTensor
import numpy as np


def get_autocast_args(args):
"""Return autocast arguments based on the DEPRECATED_AMP flag and args."""
use_cuda = torch.cuda.is_available()
enabled = bool(getattr(args, "amp", False) and use_cuda)
if DEPRECATED_AMP:
return {'enabled': args.amp, 'dtype': torch.bfloat16}
return {"enabled": enabled, "dtype": torch.bfloat16}
else:
return {'device_type': 'cuda', 'enabled': args.amp, 'dtype': torch.bfloat16}

# only use CUDA autocast when CUDA exists
return {"device_type": "cuda", "enabled": enabled, "dtype": torch.bfloat16}
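
For illustration, how this helper is meant to be consumed (a sketch; it assumes `args` is any object with a boolean `amp` attribute, e.g. an `argparse.Namespace`, and that `autocast`, `DEPRECATED_AMP`, and `get_autocast_args` are in scope as in engine.py):

```python
from argparse import Namespace

import torch

args = Namespace(amp=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 4).to(device)
x = torch.randn(2, 8, device=device)

# On a CUDA machine this runs the forward pass under bf16 autocast;
# on CPU/MPS `enabled` resolves to False, so it is a plain fp32 forward.
with autocast(**get_autocast_args(args)):
    y = model(x)
```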

def train_one_epoch(
model: torch.nn.Module,
@@ -75,11 +79,11 @@ def train_one_epoch(
print("Grad accum steps: ", args.grad_accum_steps)
print("Total batch size: ", batch_size * utils.get_world_size())

# Add gradient scaler for AMP
use_amp = bool(getattr(args, "amp", False) and torch.cuda.is_available())
if DEPRECATED_AMP:
scaler = GradScaler(enabled=args.amp)
scaler = GradScaler(enabled=use_amp)
else:
scaler = GradScaler('cuda', enabled=args.amp)
scaler = GradScaler("cuda", enabled=use_amp)

optimizer.zero_grad()
assert batch_size % args.grad_accum_steps == 0
@@ -113,7 +117,9 @@
scales = compute_multi_scale_scales(args.resolution, args.expanded_scales, args.patch_size, args.num_windows)
random.seed(it)
scale = random.choice(scales)
with torch.inference_mode():
# DO NOT use inference_mode() here; it creates inference tensors
#with torch.inference_mode():
with torch.no_grad():
samples.tensors = F.interpolate(samples.tensors, size=scale, mode='bilinear', align_corners=False)
samples.mask = F.interpolate(samples.mask.unsqueeze(1).float(), size=scale, mode='nearest').squeeze(1).bool()
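
The switch from `inference_mode()` to `no_grad()` above matters because tensors created under `inference_mode` are marked as inference tensors and cannot later be saved for backward, which breaks the training step that follows; tensors created under `no_grad` carry no such restriction. A standalone repro (illustration only, not part of the diff):

```python
import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(2, 3)

with torch.inference_mode():
    y = x * 2  # y is an inference tensor
# (y * w).sum().backward()  # would raise: "Inference tensors cannot be saved for backward"

with torch.no_grad():
    z = x * 2  # ordinary tensor with requires_grad=False
(z * w).sum().backward()  # fine: z can feed a later autograd graph
```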

@@ -124,16 +130,17 @@
new_samples = NestedTensor(new_samples_tensors, samples.mask[start_idx:final_idx])
new_samples = new_samples.to(device)
new_targets = [{k: v.to(device) for k, v in t.items()} for t in targets[start_idx:final_idx]]

with autocast(**get_autocast_args(args)):
outputs = model(new_samples, new_targets)
loss_dict = criterion(outputs, new_targets)
weight_dict = criterion.weight_dict
losses = sum(
(1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
for k in loss_dict.keys()
if k in weight_dict
)
torch.set_grad_enabled(True) # safety
with torch.inference_mode(False):
with autocast(**get_autocast_args(args)):
outputs = model(new_samples, new_targets)
loss_dict = criterion(outputs, new_targets)
weight_dict = criterion.weight_dict
losses = sum(
(1 / args.grad_accum_steps) * loss_dict[k] * weight_dict[k]
for k in loss_dict.keys()
if k in weight_dict
)


scaler.scale(losses).backward()