4 changes: 2 additions & 2 deletions README.md
@@ -65,8 +65,8 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
| **🎉[Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen-Image 4-bits](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[Qwen...Lightning](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...Lightning 4-bits](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[CogVideoX](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[OmniGen](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
| **🎉[Wan 2.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Sigma](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[Wan 2.1 VACE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Alpha](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[Wan 2.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Sigma](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
| **🎉[Wan 2.1 VACE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Alpha](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
| **🎉[Wan 2.2](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[CogVideoX 1.5](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
| **🎉[HunyuanVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Sana](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
| **🎉[LTXVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ | **🎉[VisualCloze](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
4 changes: 2 additions & 2 deletions docs/User_Guide.md
@@ -75,8 +75,8 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
| **🎉[Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen-Image 4-bits](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[Qwen...Lightning](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...Lightning 4-bits](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[CogVideoX](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[OmniGen](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
| **🎉[Wan 2.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Sigma](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[Wan 2.1 VACE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Alpha](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
| **🎉[Wan 2.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Sigma](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
| **🎉[Wan 2.1 VACE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PixArt Alpha](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
| **🎉[Wan 2.2](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[CogVideoX 1.5](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
| **🎉[HunyuanVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Sana](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
| **🎉[LTXVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ | **🎉[VisualCloze](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
93 changes: 93 additions & 0 deletions examples/parallelism/run_pixart_tp.py
@@ -0,0 +1,93 @@
import os
import sys

sys.path.append("..")

import time
import torch
from diffusers import (
    Transformer2DModel,
    PixArtSigmaPipeline,
    PixArtAlphaPipeline,
)
from utils import (
    get_args,
    strify,
    cachify,
    maybe_init_distributed,
    maybe_destroy_distributed,
)
import cache_dit


args = get_args()
print(args)

rank, device = maybe_init_distributed(args)

# Support both PixArt-Alpha and PixArt-Sigma models
model_id = os.environ.get(
"PIXART_DIR",
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
# Alternative models:
# "PixArt-alpha/PixArt-XL-2-1024-MS",
# "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
)

# Determine pipeline type based on model
if "Sigma" in model_id:
    pipeline_class = PixArtSigmaPipeline
else:
    pipeline_class = PixArtAlphaPipeline

transformer = Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)

pipe = pipeline_class.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)

if args.cache or args.parallel_type is not None:
    cachify(args, pipe)

torch.cuda.empty_cache()
pipe.enable_model_cpu_offload(device=device)
pipe.set_progress_bar_config(disable=rank != 0)


def run_pipe(warmup: bool = False):
    image = pipe(
        "A small cactus with a happy face in the Sahara desert.",
        height=1024 if args.height is None else args.height,
        width=1024 if args.width is None else args.width,
        num_inference_steps=50 if not warmup else 5,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).images[0]
    return image


# Warmup
_ = run_pipe(warmup=True)

start = time.time()
image = run_pipe()
end = time.time()

if rank == 0:
    stats = cache_dit.summary(pipe)
    time_cost = end - start
    model_name = "pixart-sigma" if "Sigma" in model_id else "pixart-alpha"
    save_path = f"{model_name}.{strify(args, stats)}.png"

    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving image to {save_path}")
    image.save(save_path)

maybe_destroy_distributed()
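Note: the script expects one process per tensor-parallel rank, so a typical launch would look something like `torchrun --nproc_per_node=2 run_pixart_tp.py` followed by the parallelism flags defined in the shared examples `utils` module (`get_args` / `maybe_init_distributed`). Those flag names are not part of this diff, so treat that command as an assumption rather than the documented invocation.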
@@ -0,0 +1,97 @@
import torch
from torch import nn
from torch.distributed import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from cache_dit.logger import init_logger
from cache_dit.parallelism.parallel_config import ParallelismConfig

from .tp_plan_registers import (
    TensorParallelismPlanner,
    TensorParallelismPlannerRegister,
)

logger = init_logger(__name__)


@TensorParallelismPlannerRegister.register("PixArt")
class PixArtTensorParallelismPlanner(TensorParallelismPlanner):
    def apply(
        self,
        transformer: nn.Module,
        parallelism_config: ParallelismConfig,
        **_kwargs,
    ) -> nn.Module:
        assert (
            parallelism_config.tp_size is not None
            and parallelism_config.tp_size > 1
        ), (
            "parallel_config.tp_size must be set and greater than 1 for "
            "tensor parallelism"
        )

        device_type = torch.accelerator.current_accelerator().type
        tp_mesh: DeviceMesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=[parallelism_config.tp_size],
        )

        transformer = self.parallelize_transformer(
            transformer=transformer,
            tp_mesh=tp_mesh,
        )

        return transformer

    def parallelize_transformer(
        self,
        transformer: nn.Module,
        tp_mesh: DeviceMesh,
    ):
        """
        Parallelize PixArt transformer blocks.

        PixArt uses BasicTransformerBlock with:
        - Self-attention (attn1)
        - Cross-attention (attn2)
        - Feed-forward network (ff)
        - Standard normalization layers
        """
        for i, block in enumerate(transformer.transformer_blocks):
            # Split attention heads across TP devices
            block.attn1.heads //= tp_mesh.size()
            block.attn2.heads //= tp_mesh.size()

            # Create layer plan for tensor parallelism
            layer_plan = {
                # Self-attention projections (column-wise)
                "attn1.to_q": ColwiseParallel(),
                "attn1.to_k": ColwiseParallel(),
                "attn1.to_v": ColwiseParallel(),
                "attn1.to_out.0": RowwiseParallel(),
                # Cross-attention projections (column-wise)
                "attn2.to_q": ColwiseParallel(),
                "attn2.to_k": ColwiseParallel(),
                "attn2.to_v": ColwiseParallel(),
                "attn2.to_out.0": RowwiseParallel(),
                # Feed-forward network
                "ff.net.0.proj": ColwiseParallel(),
                "ff.net.2": RowwiseParallel(),
            }

            # Apply tensor parallelism to the block
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )

            logger.debug(
                f"Parallelized PixArt block {i} with TP size {tp_mesh.size()}"
            )

        return transformer
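For intuition, here is a minimal sketch (not part of this PR, and assuming a 2-GPU `torchrun` launch) of the same Colwise/Rowwise pairing on a toy feed-forward block: the first linear is sharded column-wise so each rank holds a slice of the hidden features, and the second is sharded row-wise, which performs the reduction that leaves the output replicated on every rank. That is the same role `to_out.0` and `ff.net.2` play in the plan above. The `ToyFFN` module and its sizes are illustrative assumptions.

```python
# Illustrative sketch only; run with: torchrun --nproc_per_node=2 toy_tp.py
import os

import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFFN(nn.Module):
    def __init__(self, dim: int = 64, hidden: int = 256):
        super().__init__()
        self.up = nn.Linear(dim, hidden)  # sharded column-wise: output features split
        self.down = nn.Linear(hidden, dim)  # sharded row-wise: input features split

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(torch.nn.functional.gelu(self.up(x)))


# torchrun sets LOCAL_RANK; pin each process to its own GPU before building the mesh.
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
mesh = init_device_mesh("cuda", mesh_shape=(2,))

ffn = ToyFFN().to("cuda")
parallelize_module(
    module=ffn,
    device_mesh=mesh,
    parallelize_plan={"up": ColwiseParallel(), "down": RowwiseParallel()},
)

# Each rank computes with its weight shard; RowwiseParallel reduces the partial
# results, so `out` is the full, replicated activation on every rank.
out = ffn(torch.randn(4, 64, device="cuda"))
print(out.shape)  # torch.Size([4, 64])
```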
@@ -4,6 +4,7 @@
from .tp_plan_hunyuan_dit import HunyuanDiTTensorParallelismPlanner
from .tp_plan_kandinsky5 import Kandinsky5TensorParallelismPlanner
from .tp_plan_mochi import MochiTensorParallelismPlanner
from .tp_plan_pixart import PixArtTensorParallelismPlanner
from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
from .tp_plan_registers import TensorParallelismPlannerRegister
from .tp_plan_wan import WanTensorParallelismPlanner
@@ -14,6 +15,7 @@
"HunyuanDiTTensorParallelismPlanner",
"Kandinsky5TensorParallelismPlanner",
"MochiTensorParallelismPlanner",
"PixArtTensorParallelismPlanner",
"QwenImageTensorParallelismPlanner",
"TensorParallelismPlannerRegister",
"WanTensorParallelismPlanner",