Merged
47 commits (b374931 … 74cc09b) by DefTruth, Nov 21, 2025: feat: Ulysses Anything without padding
3 changes: 2 additions & 1 deletion README.md
@@ -26,7 +26,7 @@

## 🔥Highlight

We are excited to announce that the 🎉[**v1.1.0**](https://github.com/vipshop/cache-dit/releases/tag/v1.1.0) version of cache-dit has finally been released! It brings **[🔥Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism)** and **[🔥Tensor Parallelism](./docs/User_Guide.md#️hybrid-tensor-parallelism)** to cache-dit, **thus making** it a Unified and Flexible Inference Engine for 🤗DiTs. Key features: **Unified Cache APIs**, **Forward Pattern Matching**, **Block Adapter**, **DBCache**, **DBPrune**, **Cache CFG**, **TaylorSeer**, **SCM**, **Context Parallelism**, **Tensor Parallelism** and **🎉SOTA** performance.
We are excited to announce that the 🎉[**v1.1.0**](https://github.com/vipshop/cache-dit/releases/tag/v1.1.0) version of cache-dit has finally been released! It brings **[🔥Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism)** and **[🔥Tensor Parallelism](./docs/User_Guide.md#️hybrid-tensor-parallelism)** to cache-dit, **thus making** it a Unified and Flexible Inference Engine for 🤗DiTs. Key features: **Unified Cache APIs**, **Forward Pattern Matching**, **Block Adapter**, **DBCache**, **DBPrune**, **Cache CFG**, **TaylorSeer**, **SCM**, **Context Parallelism (w/ [UAA](./docs/User_Guide.md#uaa-ulysses-anything-attention))**, **Tensor Parallelism** and **🎉SOTA** performance.

```bash
pip3 install -U cache-dit # Also, pip3 install git+https://github.com/huggingface/diffusers.git (latest)
@@ -256,6 +256,7 @@ For more advanced features such as **Unified Cache APIs**, **Forward Pattern Mat
- [🔥Hybrid TaylorSeer Calibrator](./docs/User_Guide.md#taylorseer-calibrator)
- [🤖SCM: Steps Computation Masking](./docs/User_Guide.md#steps-mask)
- [⚡️Hybrid Context Parallelism](./docs/User_Guide.md#context-parallelism)
- [🤖UAA: Ulysses Anything Attention](./docs/User_Guide.md#ulysses-anything-attention)
- [⚡️Hybrid Tensor Parallelism](./docs/User_Guide.md#tensor-parallelism)
- [🤖Low-bits Quantization](./docs/User_Guide.md#quantization)
- [🤖How to use FP8 Attention](./docs/User_Guide.md#fp8-attention)
Binary file added assets/uaa/flux.C0_Q0_NONE_Ulysses2.png
45 changes: 45 additions & 0 deletions docs/User_Guide.md
@@ -29,6 +29,7 @@
- [🔥Hybrid TaylorSeer Calibrator](#taylorseer)
- [🤖SCM: Steps Computation Masking](#steps-mask)
- [⚡️Hybrid Context Parallelism](#context-parallelism)
- [🤖UAA: Ulysses Anything Attention](#ulysses-anything-attention)
- [⚡️Hybrid Tensor Parallelism](#tensor-parallelism)
- [🤖Low-bits Quantization](#quantization)
- [🤖How to use FP8 Attention](#fp8-attention)
@@ -714,6 +715,50 @@ cache_dit.enable_cache(
# torchrun --nproc_per_node=2 parallel_cache.py
```

## 🤖UAA: Ulysses Anything Attention

<div id="ulysses-anything-attention"></div>

We have implemented **[📚UAA: Ulysses Anything Attention](#uaa-ulysses-anything-attention)**: an Ulysses Attention variant that supports **arbitrary seq_len** with nearly 🎉**zero overhead** (**✅~0** communication overhead, **✅~0** IO-access overhead and **✅0** value padding). The default Ulysses Attention requires that the seq_len of the input hidden_states **be divisible by the number of devices**, which imposes **significant limitations** on the practical application of Ulysses.

For example, in Text-to-Image and Image-to-Video tasks, the length of user-supplied prompts is often variable, and it is difficult to guarantee that it is divisible by the number of devices. To address this, we developed a **padding-free** Ulysses Attention (UAA) for **arbitrary seq_len**, which enhances the versatility of Ulysses.
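The core idea can be sketched as follows: instead of requiring `seq_len % world_size == 0`, each rank owns a shard whose length differs from the others by at most one token, so no padding values are ever materialized. This is a conceptual sketch only, not cache-dit's actual implementation; the function name is illustrative.

```python
def uneven_shard_sizes(seq_len: int, world_size: int) -> list[int]:
    """Split seq_len across ranks without padding: the first
    (seq_len % world_size) ranks each take one extra token."""
    base, rem = divmod(seq_len, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]

# A 77-token prompt on 4 devices: nothing is dropped, nothing is padded,
# and the shard sizes still sum to the full sequence length.
sizes = uneven_shard_sizes(77, 4)
assert sum(sizes) == 77
assert max(sizes) - min(sizes) <= 1
print(sizes)  # → [20, 19, 19, 19]
```

With padding-based Ulysses, the same 77-token sequence on 4 devices would be padded to 80 tokens; the uneven split avoids that extra compute and memory entirely.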

```python
# pip3 install "cache-dit[parallelism]"
import cache_dit
from cache_dit import DBCacheConfig, ParallelismConfig

cache_dit.enable_cache(
pipe_or_adapter,
cache_config=DBCacheConfig(...),
# Set `experimental_ulysses_anything` as True to enable UAA
parallelism_config=ParallelismConfig(
ulysses_size=2,
parallel_kwargs={
"experimental_ulysses_anything": True
},
),
)
# torchrun --nproc_per_node=2 parallel_cache_ulysses_anything.py
```

Compared to standard Ulysses Attention, **UAA** adds only an **extra all-gather** op on scalar values to collect each rank's seq_len. To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, please add the **✅gloo** backend in `init_process_group`; this significantly reduces communication latency.

```python
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
```

Please note that Ulysses Anything Attention is currently an **experimental** feature: it has not undergone large-scale testing, and it may introduce a slight performance degradation when the `cpu:gloo` backend is not available.
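Conceptually, the one extra step UAA performs is gathering a single scalar (the local seq_len) from every rank; once those lengths are known, each rank can derive the variable split sizes for the subsequent all-to-all. The sketch below simulates that bookkeeping in plain Python, without real process groups; the function name and structure are illustrative, not cache-dit's internals.

```python
def plan_ulysses_anything(local_seq_lens: list[int]) -> list[list[int]]:
    """Given the gathered per-rank seq_len values (the only extra
    communication UAA adds), return each rank's all-to-all receive
    split sizes: rank r receives shard r of every source rank's
    local sequence, with shards differing by at most one token."""
    world_size = len(local_seq_lens)
    plans = []
    for r in range(world_size):
        splits = []
        for src_len in local_seq_lens:
            base, rem = divmod(src_len, world_size)
            splits.append(base + (1 if r < rem else 0))
        plans.append(splits)
    return plans

# Two ranks holding 40 and 37 tokens (37 is not divisible by 2):
plans = plan_ulysses_anything([40, 37])
assert plans == [[20, 19], [20, 18]]
# Every source rank's tokens are fully distributed, padding-free:
for src, src_len in enumerate([40, 37]):
    assert sum(p[src] for p in plans) == src_len
```

In the real distributed setting these split sizes would feed an all-to-all with per-rank input/output splits (e.g. `torch.distributed.all_to_all_single` accepts explicit split-size lists), which is why no padding to a common length is needed.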


<div align="center">

|FLUX.1: L20x2, Ulysses|FLUX.1: UAA w/ Gloo|FLUX.1: UAA w/o Gloo|Qwen-Image: L20x2 w/ Ulysses|Qwen-Image: L20x4 w/ UAA|
|:---:|:---:|:---:|:---:|:---:|
|13.87s|🎉13.88s|14.75s| |❌Ulysses failed|
|<img src="../assets/uaa/flux.C0_Q0_NONE_Ulysses2.png" width=180px>|<img src="../assets/uaa/flux.C0_Q0_NONE_Ulysses2_ulysses_anything.png" width=180px>|<img src="../assets/uaa/flux.C0_Q0_NONE_Ulysses2_ulysses_anything.png" width=180px>|<img src="../assets/uaa/qwen-image.C1_Q1_float8_weight_only_NONE_Ulysses2.png" width=180px>|<img src="../assets/uaa/qwen-image.C1_Q1_float8_weight_only_NONE_Ulysses4_ulysses_anything.png" width=180px>|

</div>

## ⚡️Hybrid Tensor Parallelism

<div id="tensor-parallelism"></div>
19 changes: 16 additions & 3 deletions examples/utils.py
@@ -128,6 +128,12 @@ def get_args(
default=False,
help="Track and report peak GPU memory usage",
)
parser.add_argument(
"--ulysses-anything",
action="store_true",
default=False,
help="Enable Ulysses Anything Attention for context parallelism",
)
return parser.parse_args() if parse else parser


@@ -149,8 +155,12 @@ def cachify(
if args.parallel_type in ["tp"]
else ParallelismBackend.NATIVE_DIFFUSER
)

parallel_kwargs = (
{"attention_backend": ("_native_cudnn" if not args.attn else args.attn)}
{
"attention_backend": ("_native_cudnn" if not args.attn else args.attn),
"experimental_ulysses_anything": args.ulysses_anything,
}
if backend == ParallelismBackend.NATIVE_DIFFUSER
else None
)
@@ -199,17 +209,20 @@ def strify(args, pipe_or_stats):
quantize_type = args.quantize_type if args.quantize else ""
if quantize_type != "":
quantize_type = f"_{quantize_type}"
return (
base_str = (
f"C{int(args.compile)}_Q{int(args.quantize)}{quantize_type}_"
f"{cache_dit.strify(pipe_or_stats)}"
)
if args.ulysses_anything:
base_str += "_ulysses_anything"
return base_str


def maybe_init_distributed(args=None):
if args is not None:
if args.parallel_type is not None:
dist.init_process_group(
backend="nccl",
backend="cpu:gloo,cuda:nccl" if args.ulysses_anything else "nccl",
)
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
22 changes: 22 additions & 0 deletions src/cache_dit/__init__.py
@@ -46,6 +46,28 @@ def quantize(*args, **kwargs):
)


try:
from cache_dit.parallelism import disable_ulysses_anything
from cache_dit.parallelism import enable_ulysses_anything

except ImportError as e: # noqa: F841
err_msg = str(e)

def enable_ulysses_anything(*args, **kwargs):
raise ImportError(
"Ulysses Anything Attention requires additional dependencies. "
"Please install cache-dit[parallelism] or cache-dit[all] "
f"to use this feature. Error message: {err_msg}"
)

def disable_ulysses_anything(*args, **kwargs):
raise ImportError(
"Ulysses Anything Attention requires additional dependencies. "
"Please install cache-dit[parallelism] or cache-dit[all] "
f"to use this feature. Error message: {err_msg}"
)


NONE = CacheType.NONE
DBCache = CacheType.DBCache
DBPrune = CacheType.DBPrune
2 changes: 2 additions & 0 deletions src/cache_dit/parallelism/__init__.py
@@ -1,4 +1,6 @@
from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_diffusers import enable_ulysses_anything
from cache_dit.parallelism.backends.native_diffusers import disable_ulysses_anything
from cache_dit.parallelism.parallel_interface import enable_parallelism
from cache_dit.parallelism.parallel_interface import maybe_pad_prompt
@@ -1,6 +1,12 @@
from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
ContextParallelismPlannerRegister,
)
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.attention import (
enable_ulysses_anything,
)
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.attention import (
disable_ulysses_anything,
)
from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
maybe_enable_parallelism,
)
@@ -10,6 +10,7 @@
ContextParallelConfig,
)
from .attention import maybe_resigter_native_attention_backend
from .attention import enable_ulysses_anything
from .cp_planners import *

try:
@@ -46,6 +47,12 @@ def maybe_enable_context_parallelism(
ring_degree=parallelism_config.ring_size,
)
if cp_config is not None:
experimental_ulysses_anything = parallelism_config.parallel_kwargs.get(
"experimental_ulysses_anything", False
)
if experimental_ulysses_anything:
enable_ulysses_anything()

attention_backend = parallelism_config.parallel_kwargs.get("attention_backend", None)
if hasattr(transformer, "enable_parallelism"):
if hasattr(transformer, "set_attention_backend"):
@@ -2,3 +2,10 @@ def maybe_resigter_native_attention_backend():
"""Maybe re-register native attention backend to enable context parallelism."""
# Import custom attention backend ensuring registration
from ._attention_dispatch import _native_attention


from ._templated_ulysses_anything import (
enable_ulysses_anything,
is_ulysses_anything_enabled,
disable_ulysses_anything,
)
@@ -10,7 +10,6 @@
_check_shape,
TemplatedRingAttention,
TemplatedUlyssesAttention,
# _all_to_all_single,
)
from diffusers.models._modeling_parallel import ParallelConfig
except ImportError:
@@ -20,6 +19,9 @@
"pip3 install git+https://github.com/huggingface/diffusers.git"
)
from cache_dit.logger import init_logger
from ._templated_ulysses_anything import TemplatedUlyssesAnythingAttention
from ._templated_ulysses_anything import is_ulysses_anything_enabled


logger = init_logger(__name__)

@@ -117,20 +119,36 @@ def _templated_context_parallel_attention_v2(
_parallel_config,
)
elif _parallel_config.context_parallel_config.ulysses_degree > 1:
return TemplatedUlyssesAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
if is_ulysses_anything_enabled():
return TemplatedUlyssesAnythingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
else:
return TemplatedUlyssesAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
else:
raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
