3 changes: 1 addition & 2 deletions docs/PROFILER.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Torch Profiler Usage

Reference: Adapted from `sglang/python/sglang/bench_one_batch.py` and `sglang/python/sglang/srt/managers/scheduler_profiler_mixin.py`

Reference: Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_one_batch.py .
## Quick Start

### Basic Usage
Expand Down
15 changes: 13 additions & 2 deletions examples/parallelism/run_flux_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
maybe_init_distributed,
maybe_destroy_distributed,
MemoryTracker,
create_profiler_from_args,
)
import cache_dit

Expand Down Expand Up @@ -69,11 +70,14 @@


def run_pipe(pipe: FluxPipeline):
steps = 28 if args.steps is None else args.steps
if args.profile and args.steps is None:
steps = 3
image = pipe(
prompt,
height=height,
width=width,
num_inference_steps=28 if args.steps is None else args.steps,
num_inference_steps=steps,
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
return image
Expand All @@ -91,7 +95,14 @@ def run_pipe(pipe: FluxPipeline):
memory_tracker.__enter__()

start = time.time()
image = run_pipe(pipe)
if args.profile:
profiler = create_profiler_from_args(args, profile_name="flux_cp_inference")
with profiler:
image = run_pipe(pipe)
if rank == 0:
print(f"Profiler traces saved to: {profiler.output_dir}/{profiler.trace_path.name}")
else:
image = run_pipe(pipe)
end = time.time()

if memory_tracker:
Expand Down
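Both example scripts use the object returned by `create_profiler_from_args` as a context manager that exposes `output_dir` and `trace_path`. The helper's implementation is not part of this diff; purely for orientation, a minimal sketch of a wrapper with that interface, built on the standard `torch.profiler` API, could look like the following. The class name, the `profile_dir` argument, and the trace file naming are assumptions, not cache-dit's actual code.

```python
# Hypothetical sketch only: cache-dit's real helper lives in src/cache_dit/profiler.py
# and is not shown in this diff. This illustrates the interface the examples rely on:
# a context manager exposing `output_dir` and `trace_path` attributes.
from pathlib import Path

import torch
from torch.profiler import ProfilerActivity, profile


class SimpleProfiler:
    def __init__(self, output_dir: str, profile_name: str):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.trace_path = self.output_dir / f"{profile_name}.trace.json.gz"
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)
        self._prof = profile(activities=activities, record_shapes=True)

    def __enter__(self):
        self._prof.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._prof.__exit__(exc_type, exc_val, exc_tb)
        # Export a Chrome/Perfetto-compatible trace once profiling has stopped.
        self._prof.export_chrome_trace(str(self.trace_path))
        return False


def create_profiler_from_args(args, profile_name: str) -> SimpleProfiler:
    # Assumes the argument namespace carries an output directory; the attribute
    # name `profile_dir` is illustrative, not taken from the diff.
    return SimpleProfiler(getattr(args, "profile_dir", "./profiles"), profile_name)
```

With an interface of this shape, the rank-0 print in the hunk above points at the exported trace file, which can be opened in chrome://tracing or Perfetto.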
15 changes: 13 additions & 2 deletions examples/parallelism/run_flux_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
maybe_init_distributed,
maybe_destroy_distributed,
MemoryTracker,
create_profiler_from_args,
)
import cache_dit

Expand Down Expand Up @@ -64,11 +65,14 @@


def run_pipe(warmup: bool = False):
steps = 5 if warmup else (28 if args.steps is None else args.steps)
if args.profile and args.steps is None and not warmup:
steps = 3
image = pipe(
prompt,
height=1024 if args.height is None else args.height,
width=1024 if args.width is None else args.width,
num_inference_steps=5 if warmup else (28 if args.steps is None else args.steps),
num_inference_steps=steps,
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
return image
Expand All @@ -86,7 +90,14 @@ def run_pipe(warmup: bool = False):
memory_tracker.__enter__()

start = time.time()
image = run_pipe()
if args.profile:
profiler = create_profiler_from_args(args, profile_name="flux_tp_inference")
with profiler:
image = run_pipe()
if rank == 0:
print(f"Profiler traces saved to: {profiler.output_dir}/{profiler.trace_path.name}")
else:
image = run_pipe()
end = time.time()

if memory_tracker:
Expand Down
3 changes: 1 addition & 2 deletions src/cache_dit/profiler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""
Torch Profiler for cache-dit.

Reference: Adapted from sglang/python/sglang/bench_one_batch.py and
sglang/python/sglang/srt/managers/scheduler_profiler_mixin.py
Reference: Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_one_batch.py
"""

import logging
Expand Down
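The updated docstring points to sglang's `bench_one_batch.py` as the reference, whose profiling path is built on `torch.profiler`. Independent of cache-dit's own wrapper, that underlying pattern is standard PyTorch and reduces to roughly the sketch below; the workload shown is a placeholder, not the examples' actual pipeline call.

```python
# Plain torch.profiler usage, independent of cache-dit's helper; standard PyTorch API only.
import torch
from torch.profiler import ProfilerActivity, profile

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    # Placeholder workload; in the examples this would be a few denoising steps.
    x = torch.randn(1024, 1024)
    (x @ x).sum().item()

# Summarize hot ops and export a trace viewable in chrome://tracing or Perfetto.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
prof.export_chrome_trace("trace.json")
```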