diff --git a/README.md b/README.md
index 3738ddee5..7bdf52a17 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,9 @@
## News
+
+### July 2022
+- Inference code and model weights to run our [retrieval-augmented diffusion models](https://arxiv.org/abs/2204.11824) are now available. See [this section](#retrieval-augmented-diffusion-models).
### April 2022
- Thanks to [Katherine Crowson](https://github.com/crowsonkb), classifier-free guidance received a ~2x speedup and the [PLMS sampler](https://arxiv.org/abs/2202.09778) is available. See also [this PR](https://github.com/CompVis/latent-diffusion/pull/51).
@@ -42,6 +45,74 @@ conda activate ldm
A general list of all available checkpoints is available via our [model zoo](#model-zoo).
If you use any of these models in your work, we are always happy to receive a [citation](#bibtex).
+## Retrieval-Augmented Diffusion Models
+
+We include inference code to run our retrieval-augmented diffusion models (RDMs) as described in [https://arxiv.org/abs/2204.11824](https://arxiv.org/abs/2204.11824).
+
+
+To get started, install the additionally required Python packages into your `ldm` environment
+```shell script
+pip install transformers==4.19.2 scann kornia==0.6.4 torchmetrics==0.6.0
+pip install git+https://github.com/arogozhnikov/einops.git
+```
+and download the trained weights (preliminary checkpoints):
+
+```bash
+mkdir -p models/rdm/rdm768x768/
+wget -O models/rdm/rdm768x768/model.ckpt https://ommer-lab.com/files/rdm/model.ckpt
+```
+As these models are conditioned on a set of CLIP image embeddings, our RDMs support different inference modes,
+which are described in the following.
+
+#### RDM with text-prompt only (no explicit retrieval needed)
+Since CLIP offers a shared image/text feature space, and RDMs learn to cover a neighborhood of a given
+example during training, we can directly take a CLIP text embedding of a given prompt and condition on it.
+Run this mode via
+```
+python scripts/knn2img.py --prompt "a happy bear reading a newspaper, oil on canvas"
+```
+
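+For orientation, here is a rough sketch of what this mode does under the hood. It is only an illustration of the pipeline that `scripts/knn2img.py` wraps for you; it assumes a CUDA device and the config/checkpoint paths from above:
+```python
+# Sketch only: text-only RDM sampling, mirroring scripts/knn2img.py.
+import torch
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.modules.encoders.modules import FrozenCLIPTextEmbedder
+
+config = OmegaConf.load("configs/retrieval-augmented-diffusion/768x768.yaml")
+model = instantiate_from_config(config.model)
+model.load_state_dict(torch.load("models/rdm/rdm768x768/model.ckpt",
+                                 map_location="cpu")["state_dict"], strict=False)
+model = model.cuda().eval()
+
+# encode the prompt with CLIP; the result has shape (batch_size, 1, 768)
+clip_text_encoder = FrozenCLIPTextEmbedder("ViT-L/14").cuda()
+c = clip_text_encoder.encode(3 * ["a happy bear reading a newspaper, oil on canvas"])
+
+sampler = DDIMSampler(model)
+with torch.no_grad(), model.ema_scope():
+    samples, _ = sampler.sample(S=50, conditioning=c, batch_size=c.shape[0],
+                                shape=[16, 48, 48],  # f16 latents for 768x768 outputs
+                                unconditional_guidance_scale=5.0,
+                                unconditional_conditioning=torch.zeros_like(c),
+                                verbose=False, eta=0.0)
+    imgs = model.decode_first_stage(samples)  # RGB images in [-1, 1]
+```
+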
+#### RDM with text-to-image retrieval
+
+To run an RDM conditioned on a text prompt and, additionally, on images retrieved for this prompt, you will also need to download the corresponding retrieval database.
+We provide two distinct databases, extracted from the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) and [ArtBench](https://github.com/liaopeiyuan/artbench) datasets.
+Interchanging the databases results in different capabilities of the model as visualized below, although the learned weights are the same in both cases.
+
+Download the retrieval databases, which contain the retrieval datasets ([OpenImages](https://storage.googleapis.com/openimages/web/index.html) (~11GB) and [ArtBench](https://github.com/liaopeiyuan/artbench) (~82MB)) compressed into CLIP image embeddings:
+```bash
+mkdir -p data/rdm/retrieval_databases
+wget -O data/rdm/retrieval_databases/artbench.zip https://ommer-lab.com/files/rdm/artbench_databases.zip
+wget -O data/rdm/retrieval_databases/openimages.zip https://ommer-lab.com/files/rdm/openimages_database.zip
+unzip data/rdm/retrieval_databases/artbench.zip -d data/rdm/retrieval_databases/
+unzip data/rdm/retrieval_databases/openimages.zip -d data/rdm/retrieval_databases/
+```
+We also provide trained [ScaNN](https://github.com/google-research/google-research/tree/master/scann) search indices for ArtBench. Download and extract via
+```bash
+mkdir -p data/rdm/searchers
+wget -O data/rdm/searchers/artbench.zip https://ommer-lab.com/files/rdm/artbench_searchers.zip
+unzip data/rdm/searchers/artbench.zip -d data/rdm/searchers
+```
+
+Since the index for OpenImages is large (~21 GB), we provide a script to create and save it for use during sampling. Note, however,
+that sampling with the OpenImages database is not possible without this index. Run the script via
+```bash
+python scripts/train_searcher.py
+```
+
+Retrieval-based, text-guided sampling with visual nearest neighbors can be started via
+```
+python scripts/knn2img.py --prompt "a happy pineapple" --use_neighbors --knn
+```
+Note that the maximum supported number of neighbors is 20.
+The database can be changed via the command-line parameter `--database`, which can be one of `[openimages, artbench-art_nouveau, artbench-baroque, artbench-expressionism, artbench-impressionism, artbench-post_impressionism, artbench-realism, artbench-renaissance, artbench-romanticism, artbench-surrealism, artbench-ukiyo_e]`.
+To use `--database openimages`, the above script (`scripts/train_searcher.py`) must be executed beforehand.
+Due to their relatively small size, the ArtBench databases are best suited for creating more abstract concepts and do not work well for detailed text control.
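+
+Under the hood, the retrieved neighbors are simply additional CLIP image embeddings that are concatenated with the CLIP text embedding of the prompt before sampling. A rough sketch of this conditioning step, reusing the `Searcher` class from `scripts/knn2img.py` (it assumes the databases and searchers from above are already in place and that you run it from the repository root):
+```python
+# Sketch only: how the retrieval-augmented conditioning is built (see scripts/knn2img.py).
+import torch
+from scripts.knn2img import Searcher
+from ldm.modules.encoders.modules import FrozenCLIPTextEmbedder
+
+clip_text_encoder = FrozenCLIPTextEmbedder("ViT-L/14").cuda()
+searcher = Searcher("artbench-surrealism")  # any entry of the --database list above
+
+c = clip_text_encoder.encode(["a happy pineapple"])                        # (1, 1, 768) CLIP text embedding
+nns = searcher(c, 10)                                                      # 10 nearest CLIP image embeddings
+c = torch.cat([c, torch.from_numpy(nns["nn_embeddings"]).cuda()], dim=1)   # (1, 1 + 10, 768)
+# c is then passed to sampler.sample(...) exactly as in the text-only mode above
+```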
+
+
+#### Coming Soon
+- better models
+- more resolutions
+- image-to-image retrieval
+
## Text-to-Image

@@ -273,6 +344,19 @@ Thanks for open-sourcing!
archivePrefix={arXiv},
primaryClass={cs.CV}
}
+
+@misc{blattmann2022retrievalaugmented,
+  doi = {10.48550/ARXIV.2204.11824},
+  url = {https://arxiv.org/abs/2204.11824},
+  author = {Blattmann, Andreas and Rombach, Robin and Oktay, Kaan and Ommer, Björn},
+  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
+  title = {Retrieval-Augmented Diffusion Models},
+  publisher = {arXiv},
+  year = {2022},
+  copyright = {arXiv.org perpetual, non-exclusive license}
+}
+
+
```
diff --git a/assets/rdm-preview.jpg b/assets/rdm-preview.jpg
new file mode 100644
index 000000000..3838b0f6b
Binary files /dev/null and b/assets/rdm-preview.jpg differ
diff --git a/configs/retrieval-augmented-diffusion/768x768.yaml b/configs/retrieval-augmented-diffusion/768x768.yaml
new file mode 100644
index 000000000..b51b1d837
--- /dev/null
+++ b/configs/retrieval-augmented-diffusion/768x768.yaml
@@ -0,0 +1,68 @@
+model:
+ base_learning_rate: 0.0001
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.0015
+ linear_end: 0.015
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: jpg
+ cond_stage_key: nix
+ image_size: 48
+ channels: 16
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_by_std: false
+ scale_factor: 0.22765929
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 48
+ in_channels: 16
+ out_channels: 16
+ model_channels: 448
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 3
+ - 4
+ use_scale_shift_norm: false
+ resblock_updown: false
+ num_head_channels: 32
+ use_spatial_transformer: true
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: true
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ monitor: val/rec_loss
+ embed_dim: 16
+ ddconfig:
+ double_z: true
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+ cond_stage_config:
+ target: torch.nn.Identity
\ No newline at end of file
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 506e39ba4..aa3031df6 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,6 +1,10 @@
import torch
import torch.nn as nn
from functools import partial
+import clip
+from einops import rearrange, repeat
+import kornia
+
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
@@ -129,3 +133,70 @@ def forward(self,x):
def encode(self, x):
return self(x)
+
+
+class FrozenCLIPTextEmbedder(nn.Module):
+ """
+ Uses the CLIP transformer encoder for text.
+ """
+ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+ super().__init__()
+ self.model, _ = clip.load(version, jit=False, device="cpu")
+ self.device = device
+ self.max_length = max_length
+ self.n_repeat = n_repeat
+ self.normalize = normalize
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = clip.tokenize(text).to(self.device)
+ z = self.model.encode_text(tokens)
+ if self.normalize:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+ return z
+
+ def encode(self, text):
+ z = self(text)
+ if z.ndim==2:
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+ return z
+
+
+class FrozenClipImageEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ return self.model.encode_image(self.preprocess(x))
+
diff --git a/ldm/util.py b/ldm/util.py
index 51839cb14..8ba38853e 100644
--- a/ldm/util.py
+++ b/ldm/util.py
@@ -2,6 +2,13 @@
import torch
import numpy as np
+from collections import abc
+from einops import rearrange
+from functools import partial
+
+import multiprocessing as mp
+from threading import Thread
+from queue import Queue
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
@@ -38,7 +45,7 @@ def ismap(x):
def isimage(x):
- if not isinstance(x,torch.Tensor):
+ if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
@@ -64,7 +71,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
- print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
@@ -83,4 +90,114 @@ def get_obj_from_str(string, reload=False):
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
- return getattr(importlib.import_module(module, package=None), cls)
\ No newline at end of file
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+ # create dummy dataset instance
+
+ # run prefetching
+ if idx_to_fn:
+ res = func(data, worker_id=idx)
+ else:
+ res = func(data)
+ Q.put([idx, res])
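+    # one "Done" sentinel per worker lets the consumer loop count how many workers have finished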
+ Q.put("Done")
+
+
+def parallel_data_prefetch(
+ func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
+):
+ # if target_data_type not in ["ndarray", "list"]:
+ # raise ValueError(
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+ # )
+ if isinstance(data, np.ndarray) and target_data_type == "list":
+ raise ValueError("list expected but function got ndarray.")
+ elif isinstance(data, abc.Iterable):
+ if isinstance(data, dict):
+ print(
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+ )
+ data = list(data.values())
+ if target_data_type == "ndarray":
+ data = np.asarray(data)
+ else:
+ data = list(data)
+ else:
+ raise TypeError(
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+ )
+
+ if cpu_intensive:
+ Q = mp.Queue(1000)
+ proc = mp.Process
+ else:
+ Q = Queue(1000)
+ proc = Thread
+ # spawn processes
+ if target_data_type == "ndarray":
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(np.array_split(data, n_proc))
+ ]
+ else:
+ step = (
+ int(len(data) / n_proc + 1)
+ if len(data) % n_proc != 0
+ else int(len(data) / n_proc)
+ )
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(
+ [data[i: i + step] for i in range(0, len(data), step)]
+ )
+ ]
+ processes = []
+ for i in range(n_proc):
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+ processes += [p]
+
+ # start processes
+ print(f"Start prefetching...")
+ import time
+
+ start = time.time()
+ gather_res = [[] for _ in range(n_proc)]
+ try:
+ for p in processes:
+ p.start()
+
+ k = 0
+ while k < n_proc:
+ # get result
+ res = Q.get()
+ if res == "Done":
+ k += 1
+ else:
+ gather_res[res[0]] = res[1]
+
+ except Exception as e:
+ print("Exception: ", e)
+ for p in processes:
+ p.terminate()
+
+ raise e
+ finally:
+ for p in processes:
+ p.join()
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+ if target_data_type == 'ndarray':
+ if not isinstance(gather_res[0], np.ndarray):
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+ # order outputs
+ return np.concatenate(gather_res, axis=0)
+ elif target_data_type == 'list':
+ out = []
+ for r in gather_res:
+ out.extend(r)
+ return out
+ else:
+ return gather_res
diff --git a/scripts/knn2img.py b/scripts/knn2img.py
new file mode 100644
index 000000000..e6eaaecab
--- /dev/null
+++ b/scripts/knn2img.py
@@ -0,0 +1,398 @@
+import argparse, os, sys, glob
+import clip
+import torch
+import torch.nn as nn
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+import scann
+import time
+from multiprocessing import cpu_count
+
+from ldm.util import instantiate_from_config, parallel_data_prefetch
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
+
+DATABASES = [
+ "openimages",
+ "artbench-art_nouveau",
+ "artbench-baroque",
+ "artbench-expressionism",
+ "artbench-impressionism",
+ "artbench-post_impressionism",
+ "artbench-realism",
+ "artbench-romanticism",
+ "artbench-renaissance",
+ "artbench-surrealism",
+ "artbench-ukiyo_e",
+]
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+
+class Searcher(object):
+ def __init__(self, database, retriever_version='ViT-L/14'):
+ assert database in DATABASES
+ # self.database = self.load_database(database)
+ self.database_name = database
+ self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
+ self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
+ self.retriever = self.load_retriever(version=retriever_version)
+ self.database = {'embedding': [],
+ 'img_id': [],
+ 'patch_coords': []}
+ self.load_database()
+ self.load_searcher()
+
+ def train_searcher(self, k,
+ metric='dot_product',
+ searcher_savedir=None):
+
+ print('Start training searcher')
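+        # embeddings are L2-normalized so that the dot-product metric amounts to cosine similarity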
+ searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
+ np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
+ k, metric)
+ self.searcher = searcher.score_brute_force().build()
+        print('Finished training searcher')
+
+ if searcher_savedir is not None:
+ print(f'Save trained searcher under "{searcher_savedir}"')
+ os.makedirs(searcher_savedir, exist_ok=True)
+ self.searcher.serialize(searcher_savedir)
+
+ def load_single_file(self, saved_embeddings):
+ compressed = np.load(saved_embeddings)
+ self.database = {key: compressed[key] for key in compressed.files}
+ print('Finished loading of clip embeddings.')
+
+ def load_multi_files(self, data_archive):
+ out_data = {key: [] for key in self.database}
+ for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
+ for key in d.files:
+ out_data[key].append(d[key])
+
+ return out_data
+
+ def load_database(self):
+
+ print(f'Load saved patch embedding from "{self.database_path}"')
+ file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
+
+ if len(file_content) == 1:
+ self.load_single_file(file_content[0])
+ elif len(file_content) > 1:
+ data = [np.load(f) for f in file_content]
+ prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
+ n_proc=min(len(data), cpu_count()), target_data_type='dict')
+
+ self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
+ self.database}
+ else:
+            raise ValueError(f'No npz-files in specified path "{self.database_path}". Does this directory exist?')
+
+ print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
+
+ def load_retriever(self, version='ViT-L/14', ):
+ model = FrozenClipImageEmbedder(model=version)
+ if torch.cuda.is_available():
+ model.cuda()
+ model.eval()
+ return model
+
+ def load_searcher(self):
+ print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
+ self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
+ print('Finished loading searcher.')
+
+ def search(self, x, k):
+ if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
+ self.train_searcher(k) # quickly fit searcher on the fly for small databases
+ assert self.searcher is not None, 'Cannot search with uninitialized searcher'
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ if len(x.shape) == 3:
+ x = x[:, 0]
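+        # queries are L2-normalized to match the unit-norm database embeddings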
+ query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
+
+ start = time.time()
+ nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
+ end = time.time()
+
+ out_embeddings = self.database['embedding'][nns]
+ out_img_ids = self.database['img_id'][nns]
+ out_pc = self.database['patch_coords'][nns]
+
+ out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
+ 'img_ids': out_img_ids,
+ 'patch_coords': out_pc,
+ 'queries': x,
+ 'exec_time': end - start,
+ 'nns': nns,
+ 'q_embeddings': query_embeddings}
+
+ return out
+
+ def __call__(self, x, n):
+ return self.search(x, n)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
+ # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ nargs="?",
+ default="a painting of a virus monster playing guitar",
+ help="the prompt to render"
+ )
+
+ parser.add_argument(
+ "--outdir",
+ type=str,
+ nargs="?",
+ help="dir to write results to",
+ default="outputs/txt2img-samples"
+ )
+
+ parser.add_argument(
+ "--skip_grid",
+ action='store_true',
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+ )
+
+ parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+ )
+
+ parser.add_argument(
+ "--n_repeat",
+ type=int,
+ default=1,
+ help="number of repeats in CLIP latent space",
+ )
+
+ parser.add_argument(
+ "--plms",
+ action='store_true',
+ help="use plms sampling",
+ )
+
+ parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=0.0,
+        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
+ )
+ parser.add_argument(
+ "--n_iter",
+ type=int,
+ default=1,
+ help="sample this often",
+ )
+
+ parser.add_argument(
+ "--H",
+ type=int,
+ default=768,
+ help="image height, in pixel space",
+ )
+
+ parser.add_argument(
+ "--W",
+ type=int,
+ default=768,
+ help="image width, in pixel space",
+ )
+
+ parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=3,
+ help="how many samples to produce for each given prompt. A.k.a batch size",
+ )
+
+ parser.add_argument(
+ "--n_rows",
+ type=int,
+ default=0,
+ help="rows in the grid (default: n_samples)",
+ )
+
+ parser.add_argument(
+ "--scale",
+ type=float,
+ default=5.0,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+ )
+
+ parser.add_argument(
+ "--from-file",
+ type=str,
+ help="if specified, load prompts from this file",
+ )
+
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/retrieval-augmented-diffusion/768x768.yaml",
+ help="path to config which constructs model",
+ )
+
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="models/rdm/rdm768x768/model.ckpt",
+ help="path to checkpoint of model",
+ )
+
+ parser.add_argument(
+ "--clip_type",
+ type=str,
+ default="ViT-L/14",
+ help="which CLIP model to use for retrieval and NN encoding",
+ )
+ parser.add_argument(
+ "--database",
+ type=str,
+ default='artbench-surrealism',
+ choices=DATABASES,
+ help="The database used for the search, only applied when --use_neighbors=True",
+ )
+ parser.add_argument(
+ "--use_neighbors",
+ default=False,
+ action='store_true',
+ help="Include neighbors in addition to text prompt for conditioning",
+ )
+ parser.add_argument(
+ "--knn",
+ default=10,
+ type=int,
+ help="The number of included neighbors, only applied when --use_neighbors=True",
+ )
+
+ opt = parser.parse_args()
+
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, f"{opt.ckpt}")
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
+
+ if opt.plms:
+ sampler = PLMSSampler(model)
+ else:
+ sampler = DDIMSampler(model)
+
+ os.makedirs(opt.outdir, exist_ok=True)
+ outpath = opt.outdir
+
+ batch_size = opt.n_samples
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+ if not opt.from_file:
+ prompt = opt.prompt
+ assert prompt is not None
+ data = [batch_size * [prompt]]
+
+ else:
+ print(f"reading prompts from {opt.from_file}")
+ with open(opt.from_file, "r") as f:
+ data = f.read().splitlines()
+ data = list(chunk(data, batch_size))
+
+ sample_path = os.path.join(outpath, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+ grid_count = len(os.listdir(outpath)) - 1
+
+ print(f"sampling scale for cfg is {opt.scale:.2f}")
+
+ searcher = None
+ if opt.use_neighbors:
+ searcher = Searcher(opt.database)
+
+ with torch.no_grad():
+ with model.ema_scope():
+ for n in trange(opt.n_iter, desc="Sampling"):
+ all_samples = list()
+ for prompts in tqdm(data, desc="data"):
+ print("sampling prompts:", prompts)
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = clip_text_encoder.encode(prompts)
+ uc = None
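+                    # retrieval mode: append the CLIP image embeddings of the k nearest neighbors to the text embedding along the token dimension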
+ if searcher is not None:
+ nn_dict = searcher(c, opt.knn)
+ c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
+ if opt.scale != 1.0:
+ uc = torch.zeros_like(c)
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+ conditioning=c,
+ batch_size=c.shape[0],
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=uc,
+ eta=opt.ddim_eta,
+ )
+
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for x_sample in x_samples_ddim:
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(sample_path, f"{base_count:05}.png"))
+ base_count += 1
+ all_samples.append(x_samples_ddim)
+
+ if not opt.skip_grid:
+ # additionally, save as grid
+ grid = torch.stack(all_samples, 0)
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+ grid = make_grid(grid, nrow=n_rows)
+
+ # to image
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+ Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+ grid_count += 1
+
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
diff --git a/scripts/train_searcher.py b/scripts/train_searcher.py
new file mode 100644
index 000000000..1e7904889
--- /dev/null
+++ b/scripts/train_searcher.py
@@ -0,0 +1,147 @@
+import os, sys
+import numpy as np
+import scann
+import argparse
+import glob
+from multiprocessing import cpu_count
+from tqdm import tqdm
+
+from ldm.util import parallel_data_prefetch
+
+
+def search_bruteforce(searcher):
+ return searcher.score_brute_force().build()
+
+
+def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
+ partioning_trainsize, num_leaves, num_leaves_to_search):
+ return searcher.tree(num_leaves=num_leaves,
+ num_leaves_to_search=num_leaves_to_search,
+ training_sample_size=partioning_trainsize). \
+ score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
+
+
+def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
+ return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
+ reorder_k).build()
+
+def load_datapool(dpath):
+
+
+ def load_single_file(saved_embeddings):
+ compressed = np.load(saved_embeddings)
+ database = {key: compressed[key] for key in compressed.files}
+ return database
+
+ def load_multi_files(data_archive):
+ database = {key: [] for key in data_archive[0].files}
+ for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
+ for key in d.files:
+ database[key].append(d[key])
+
+ return database
+
+ print(f'Load saved patch embedding from "{dpath}"')
+ file_content = glob.glob(os.path.join(dpath, '*.npz'))
+
+ if len(file_content) == 1:
+ data_pool = load_single_file(file_content[0])
+ elif len(file_content) > 1:
+ data = [np.load(f) for f in file_content]
+ prefetched_data = parallel_data_prefetch(load_multi_files, data,
+ n_proc=min(len(data), cpu_count()), target_data_type='dict')
+
+ data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
+ else:
+        raise ValueError(f'No npz-files in specified path "{dpath}". Does this directory exist?')
+
+ print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
+ return data_pool
+
+
+def train_searcher(opt,
+ metric='dot_product',
+ partioning_trainsize=None,
+ reorder_k=None,
+ # todo tune
+ aiq_thld=0.2,
+ dims_per_block=2,
+ num_leaves=None,
+ num_leaves_to_search=None,):
+
+ data_pool = load_datapool(opt.database)
+ k = opt.knn
+
+ if not reorder_k:
+ reorder_k = 2 * k
+
+ # normalize
+ searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
+ pool_size = data_pool['embedding'].shape[0]
+
+ print(*(['#'] * 100))
+ print('Initializing scaNN searcher with the following values:')
+ print(f'k: {k}')
+ print(f'metric: {metric}')
+ print(f'reorder_k: {reorder_k}')
+ print(f'anisotropic_quantization_threshold: {aiq_thld}')
+ print(f'dims_per_block: {dims_per_block}')
+ print(*(['#'] * 100))
+ print('Start training searcher....')
+ print(f'N samples in pool is {pool_size}')
+
+ # this reflects the recommended design choices proposed at
+ # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
+ if pool_size < 2e4:
+ print('Using brute force search.')
+ searcher = search_bruteforce(searcher)
+ elif 2e4 <= pool_size and pool_size < 1e5:
+ print('Using asymmetric hashing search and reordering.')
+ searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
+ else:
+        print('Using partitioning, asymmetric hashing search and reordering.')
+
+ if not partioning_trainsize:
+ partioning_trainsize = data_pool['embedding'].shape[0] // 10
+ if not num_leaves:
+ num_leaves = int(np.sqrt(pool_size))
+
+ if not num_leaves_to_search:
+ num_leaves_to_search = max(num_leaves // 20, 1)
+
+ print('Partitioning params:')
+ print(f'num_leaves: {num_leaves}')
+ print(f'num_leaves_to_search: {num_leaves_to_search}')
+ # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
+ searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
+ partioning_trainsize, num_leaves, num_leaves_to_search)
+
+    print('Finished training searcher')
+ searcher_savedir = opt.target_path
+ os.makedirs(searcher_savedir, exist_ok=True)
+ searcher.serialize(searcher_savedir)
+ print(f'Saved trained searcher under "{searcher_savedir}"')
+
+if __name__ == '__main__':
+ sys.path.append(os.getcwd())
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--database',
+ '-d',
+ default='data/rdm/retrieval_databases/openimages',
+ type=str,
+ help='path to folder containing the clip feature of the database')
+ parser.add_argument('--target_path',
+ '-t',
+ default='data/rdm/searchers/openimages',
+ type=str,
+ help='path to the target folder where the searcher shall be stored.')
+ parser.add_argument('--knn',
+ '-k',
+ default=20,
+ type=int,
+ help='number of nearest neighbors, for which the searcher shall be optimized')
+
+ opt, _ = parser.parse_known_args()
+
+ train_searcher(opt,)
\ No newline at end of file