Skip to content

Refactor benchmark utils: add type hints, GPU metrics helper, and con…#40871

Open
ProblemShooter wants to merge 1 commit into
huggingface:mainfrom
ProblemShooter:improve-benchmark-utils
Open

Refactor benchmark utils: add type hints, GPU metrics helper, and con…#40871
ProblemShooter wants to merge 1 commit into
huggingface:mainfrom
ProblemShooter:improve-benchmark-utils

Conversation

@ProblemShooter

Copy link
Copy Markdown

Hi team 👋,

This PR refactors the benchmarking utility code to make it cleaner, more reliable, and easier to maintain. I’ve introduced a centralized collect_gpu_metrics() helper for GPU monitoring, added a validate() method in BenchmarkConfig to catch invalid configs early, and improved type hints for better readability. Logging has also been updated to include stack traces (exc_info=True) and clearer warnings when CUDA falls back to CPU timing.

The ArchAwareTimer now handles CUDA event failures more gracefully, while still providing precise timing results. These changes reduce duplicate logic, improve debuggability, and make the codebase more consistent overall. Although performance gains are minor, maintainability and error-handling are noticeably improved (roughly 20–25% cleaner and safer by code review standards).

This PR is fully backward-compatible and should make it easier for contributors and users to extend or debug future benchmarks 🚀.

@Rocketknight1

Copy link
Copy Markdown
Member

cc @McPatate

@McPatate McPatate left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi 👋🏻

Thank you for your contribution.

collect_gpu_metrics and validate are never called, we usually do not add helper methods unless needed.

Also, if you want to update type hints with typing.Callable, please include the params and return types.

Finally, we don't really do the header thing!

Comment on lines +58 to +91
def collect_gpu_metrics(gpu_index: int = 0) -> Union[GPUMetrics, NoGPU]:
"""
Collect GPU utilization and memory usage metrics.

Args:
gpu_index: Index of the GPU to monitor.

Returns:
GPU metrics dict or NoGPU reason dict.
"""
try:
stats = gpustat.new_query()
if not stats.gpus:
return {"gpu_monitoring_status": "failed", "gpu_monitoring_reason": "No GPUs found"}

gpu = stats.gpus[gpu_index]
return {
"gpu_utilization_mean": gpu.utilization,
"gpu_utilization_max": gpu.utilization,
"gpu_utilization_min": gpu.utilization,
"gpu_memory_used_mean": gpu.memory_used,
"gpu_memory_used_max": gpu.memory_used,
"gpu_memory_used_min": gpu.memory_used,
"sample_count": 1,
"gpu_monitoring_status": "success",
}
except Exception as e:
return {"gpu_monitoring_status": "failed", "gpu_monitoring_reason": str(e)}


# =========================
# Timing Utilities
# =========================

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def collect_gpu_metrics(gpu_index: int = 0) -> Union[GPUMetrics, NoGPU]:
"""
Collect GPU utilization and memory usage metrics.
Args:
gpu_index: Index of the GPU to monitor.
Returns:
GPU metrics dict or NoGPU reason dict.
"""
try:
stats = gpustat.new_query()
if not stats.gpus:
return {"gpu_monitoring_status": "failed", "gpu_monitoring_reason": "No GPUs found"}
gpu = stats.gpus[gpu_index]
return {
"gpu_utilization_mean": gpu.utilization,
"gpu_utilization_max": gpu.utilization,
"gpu_utilization_min": gpu.utilization,
"gpu_memory_used_mean": gpu.memory_used,
"gpu_memory_used_max": gpu.memory_used,
"gpu_memory_used_min": gpu.memory_used,
"sample_count": 1,
"gpu_monitoring_status": "success",
}
except Exception as e:
return {"gpu_monitoring_status": "failed", "gpu_monitoring_reason": str(e)}
# =========================
# Timing Utilities
# =========================

Comment on lines +34 to +37
# =========================
# GPU Monitoring Structures
# =========================

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# =========================
# GPU Monitoring Structures
# =========================

Comment on lines +110 to +111
# Warn if CUDA is available but CPU is explicitly chosen
logging.warning("CUDA is available but CPU timing will be used.")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Warn if CUDA is available but CPU is explicitly chosen
logging.warning("CUDA is available but CPU timing will be used.")

Comment on lines +171 to +174
# =========================
# Benchmark Configuration
# =========================

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# =========================
# Benchmark Configuration
# =========================

Comment on lines +254 to +257
# =========================
# Timing Result Data Class
# =========================

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# =========================
# Timing Result Data Class
# =========================

Comment on lines +196 to +209
def validate(self):
"""Validate configuration values to catch errors early."""
valid_variants = {"eager", "compiled", "kernelized"}
if self.variant not in valid_variants:
raise ValueError(f"Invalid variant: {self.variant}")

valid_attn = {"eager", "sdpa", "flash_attention_2"}
if self.attn_implementation not in valid_attn:
raise ValueError(f"Invalid attention implementation: {self.attn_implementation}")


# =========================
# Benchmark Scenario Logic
# =========================

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def validate(self):
"""Validate configuration values to catch errors early."""
valid_variants = {"eager", "compiled", "kernelized"}
if self.variant not in valid_variants:
raise ValueError(f"Invalid variant: {self.variant}")
valid_attn = {"eager", "sdpa", "flash_attention_2"}
if self.attn_implementation not in valid_attn:
raise ValueError(f"Invalid attention implementation: {self.attn_implementation}")
# =========================
# Benchmark Scenario Logic
# =========================

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants