RDM #111 (Merged)

84 changes: 84 additions & 0 deletions README.md
@@ -20,6 +20,9 @@
</p>

## News

### July 2022
- Inference code and model weights to run our [retrieval-augmented diffusion models](https://arxiv.org/abs/2204.11824) are now available. See [this section](#retrieval-augmented-diffusion-models).
### April 2022
- Thanks to [Katherine Crowson](https://github.com/crowsonkb), classifier-free guidance received a ~2x speedup and the [PLMS sampler](https://arxiv.org/abs/2202.09778) is available. See also [this PR](https://github.com/CompVis/latent-diffusion/pull/51).

@@ -42,6 +45,74 @@ conda activate ldm
A general list of all available checkpoints is available via our [model zoo](#model-zoo).
If you use any of these models in your work, we are always happy to receive a [citation](#bibtex).

## Retrieval Augmented Diffusion Models
![rdm-figure](assets/rdm-preview.jpg)
We include inference code to run our retrieval-augmented diffusion models (RDMs) as described in [https://arxiv.org/abs/2204.11824](https://arxiv.org/abs/2204.11824).


To get started, install the additionally required Python packages into your `ldm` environment
```shell script
pip install transformers==4.19.2 scann kornia==0.6.4 torchmetrics==0.6.0
pip install git+https://github.com/arogozhnikov/einops.git
```
and download the trained weights (preliminary checkpoints):

```bash
mkdir -p models/rdm/rdm768x768/
wget -O models/rdm/rdm768x768/model.ckpt https://ommer-lab.com/files/rdm/model.ckpt
```
As these models are conditioned on a set of CLIP image embeddings, our RDMs support different inference modes,
which are described below.
#### RDM with text-prompt only (no explicit retrieval needed)
Since CLIP offers a shared image/text feature space, and RDMs learn to cover a neighborhood of a given
example during training, we can directly take a CLIP text embedding of a given prompt and condition on it.
Run this mode via
```
python scripts/knn2img.py --prompt "a happy bear reading a newspaper, oil on canvas"
```
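
For orientation, here is a minimal sketch of what this mode does under the hood, using the `FrozenCLIPTextEmbedder` added in this PR; the sampler plumbing inside `scripts/knn2img.py` is omitted, so treat this as an illustration rather than the script's exact code.

```python
import torch
from ldm.modules.encoders.modules import FrozenCLIPTextEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"

# Embed the prompt with CLIP's text encoder. Because CLIP text and image
# embeddings share one feature space, the RDM can cross-attend to this text
# embedding in place of retrieved CLIP image embeddings.
clip_text = FrozenCLIPTextEmbedder(device=device, n_repeat=1, normalize=True).to(device)
clip_text.freeze()

with torch.no_grad():
    c = clip_text.encode(["a happy bear reading a newspaper, oil on canvas"])  # (1, 1, 768)

# `c` is the cross-attention conditioning that the diffusion sampler receives in this mode.
```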

#### RDM with text-to-image retrieval

To run an RDM conditioned on a text prompt and, additionally, on images retrieved from this prompt, you will also need to download the corresponding retrieval database.
We provide two distinct databases, extracted from the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) and [ArtBench](https://github.com/liaopeiyuan/artbench) datasets.
Interchanging the databases results in different capabilities of the model, as visualized below, although the learned weights are the same in both cases.

Download the retrieval databases, which contain the retrieval datasets ([OpenImages](https://storage.googleapis.com/openimages/web/index.html) (~11GB) and [ArtBench](https://github.com/liaopeiyuan/artbench) (~82MB)) compressed into CLIP image embeddings:
```bash
mkdir -p data/rdm/retrieval_databases
wget -O data/rdm/retrieval_databases/artbench.zip https://ommer-lab.com/files/rdm/artbench_databases.zip
wget -O data/rdm/retrieval_databases/openimages.zip https://ommer-lab.com/files/rdm/openimages_database.zip
unzip data/rdm/retrieval_databases/artbench.zip -d data/rdm/retrieval_databases/
unzip data/rdm/retrieval_databases/openimages.zip -d data/rdm/retrieval_databases/
```
We also provide trained [ScaNN](https://github.com/google-research/google-research/tree/master/scann) search indices for ArtBench. Download and extract via
```bash
mkdir -p data/rdm/searchers
wget -O data/rdm/searchers/artbench.zip https://ommer-lab.com/files/rdm/artbench_searchers.zip
unzip data/rdm/searchers/artbench.zip -d data/rdm/searchers
```

Since the index for OpenImages is large (~21 GB), we provide a script to create and save it for use during sampling. Note, however,
that sampling with the OpenImages database is not possible without this index. Run the script via
```bash
python scripts/train_searcher.py
```
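
For reference, building such a ScaNN index over the database's CLIP image embeddings looks roughly like the sketch below. The `.npz` file layout, the `embedding` key, and the partitioning/quantization parameters are illustrative assumptions, not necessarily what `scripts/train_searcher.py` uses.

```python
import glob
import os

import numpy as np
import scann

# Load the precomputed CLIP image embeddings of the retrieval database
# (file layout and the "embedding" key are assumptions for illustration).
parts = sorted(glob.glob("data/rdm/retrieval_databases/openimages/*.npz"))
database_embeddings = np.concatenate([np.load(p)["embedding"] for p in parts]).astype(np.float32)
# Unit-normalize so that dot-product search corresponds to cosine similarity.
database_embeddings /= np.linalg.norm(database_embeddings, axis=1, keepdims=True)

# Build an approximate nearest-neighbour searcher and serialize it to disk.
searcher = (
    scann.scann_ops_pybind.builder(database_embeddings, 20, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)
os.makedirs("data/rdm/searchers/openimages", exist_ok=True)
searcher.serialize("data/rdm/searchers/openimages")
```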

Retrieval-based text-guided sampling with visual nearest neighbors can be started via
```
python scripts/knn2img.py --prompt "a happy pineapple" --use_neighbors --knn <number_of_neighbors>
```
Note that the maximum supported number of neighbors is 20.
The database can be changed via the command-line parameter `--database`, which can be one of `[openimages, artbench-art_nouveau, artbench-baroque, artbench-expressionism, artbench-impressionism, artbench-post_impressionism, artbench-realism, artbench-renaissance, artbench-romanticism, artbench-surrealism, artbench-ukiyo_e]`.
To use `--database openimages`, the script above (`scripts/train_searcher.py`) must be executed first.
Due to their relatively small size, the ArtBench databases are best suited for creating more abstract concepts and do not work well for detailed text control.
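
Conceptually, the retrieval step that precedes sampling looks like the sketch below; the searcher path and the database loading (reused from the previous sketch) are assumptions, and the conditioning logic of `scripts/knn2img.py` is simplified.

```python
import glob

import numpy as np
import scann
import torch
from ldm.modules.encoders.modules import FrozenCLIPTextEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
k = 10  # number of neighbors; the maximum supported value is 20

# 1) Embed the text prompt in CLIP space.
clip_text = FrozenCLIPTextEmbedder(device=device, normalize=True).to(device)
with torch.no_grad():
    query = clip_text.encode(["a happy pineapple"])[:, 0].cpu().numpy()  # (1, 768)

# 2) Retrieve the k nearest CLIP *image* embeddings from the chosen database
#    (database layout as assumed in the previous sketch).
parts = sorted(glob.glob("data/rdm/retrieval_databases/openimages/*.npz"))
database_embeddings = np.concatenate([np.load(p)["embedding"] for p in parts]).astype(np.float32)
searcher = scann.scann_ops_pybind.load_searcher("data/rdm/searchers/openimages")
ids, _ = searcher.search_batched(query, final_num_neighbors=k)
neighbors = database_embeddings[ids[0]]  # (k, 768)

# 3) The retrieved set is stacked into the conditioning the RDM cross-attends to.
c = torch.from_numpy(neighbors)[None].float().to(device)  # (1, k, 768)
```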


#### Coming Soon
- better models
- more resolutions
- image-to-image retrieval

## Text-to-Image
![text2img-figure](assets/txt2img-preview.png)

@@ -273,6 +344,19 @@ Thanks for open-sourcing!
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@misc{blattmann2022retrievalaugmented,
      title={Retrieval-Augmented Diffusion Models},
      author={Andreas Blattmann and Robin Rombach and Kaan Oktay and Björn Ommer},
      year={2022},
      eprint={2204.11824},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}


```


Binary file added assets/rdm-preview.jpg
68 changes: 68 additions & 0 deletions configs/retrieval-augmented-diffusion/768x768.yaml
@@ -0,0 +1,68 @@
model:
  base_learning_rate: 0.0001
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.015
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: jpg
    cond_stage_key: nix
    image_size: 48
    channels: 16
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_by_std: false
    scale_factor: 0.22765929
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 48
        in_channels: 16
        out_channels: 16
        model_channels: 448
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        use_scale_shift_norm: false
        resblock_updown: false
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: true
    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        monitor: val/rec_loss
        embed_dim: 16
        ddconfig:
          double_z: true
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 16
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: torch.nn.Identity
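
As elsewhere in this repository, the config above is consumed through `ldm.util.instantiate_from_config`; the following is a minimal sketch of turning it into a ready-to-sample model (the `strict=False` flag mirrors the repo's loading scripts but is an assumption here).

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Build the LatentDiffusion model described by the YAML above.
config = OmegaConf.load("configs/retrieval-augmented-diffusion/768x768.yaml")
model = instantiate_from_config(config.model)

# Load the released RDM weights and move the model to the GPU for sampling.
state_dict = torch.load("models/rdm/rdm768x768/model.ckpt", map_location="cpu")["state_dict"]
model.load_state_dict(state_dict, strict=False)
model = model.cuda().eval()
```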
71 changes: 71 additions & 0 deletions ldm/modules/encoders/modules.py
@@ -1,6 +1,10 @@
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
import kornia


from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test

@@ -129,3 +133,70 @@ def forward(self,x):

    def encode(self, x):
        return self(x)


class FrozenCLIPTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """
    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device="cpu")
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim == 2:
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z


class FrozenClipImageEmbedder(nn.Module):
    """
    Uses the CLIP image encoder.
    """
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=False,
        ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def preprocess(self, x):
        # resize to CLIP's expected 224x224 input resolution
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        # map inputs from [-1, 1] to [0, 1]
        x = (x + 1.) / 2.
        # renormalize according to CLIP's mean/std statistics
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1, 1]
        return self.model.encode_image(self.preprocess(x))
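
A short, CPU-only usage sketch for the two embedders added above; the random image tensor only illustrates the expected input range and output shapes.

```python
import torch
from ldm.modules.encoders.modules import FrozenCLIPTextEmbedder, FrozenClipImageEmbedder

# Instantiate on CPU so that CLIP is loaded in full precision.
text_encoder = FrozenCLIPTextEmbedder(device="cpu")
text_encoder.freeze()
image_encoder = FrozenClipImageEmbedder(model="ViT-L/14", device="cpu")

with torch.no_grad():
    z_txt = text_encoder.encode(["a happy pineapple"])   # (1, n_repeat=1, 768)
    fake_images = torch.rand(2, 3, 256, 256) * 2 - 1     # inputs are expected in [-1, 1]
    z_img = image_encoder(fake_images)                   # (2, 768) CLIP image features

print(z_txt.shape, z_img.shape)
```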
