diff --git a/README.md b/README.md
index 3738ddee5..7bdf52a17 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,9 @@

 ## News
+
+### July 2022
+- Inference code and model weights to run our [retrieval-augmented diffusion models](https://arxiv.org/abs/2204.11824) are now available. See [this section](#retrieval-augmented-diffusion-models).

 ### April 2022
 - Thanks to [Katherine Crowson](https://github.com/crowsonkb), classifier-free guidance received a ~2x speedup and the [PLMS sampler](https://arxiv.org/abs/2202.09778) is available. See also [this PR](https://github.com/CompVis/latent-diffusion/pull/51).

@@ -42,6 +45,74 @@ conda activate ldm

 A general list of all available checkpoints is available in via our [model zoo](#model-zoo).
 If you use any of these models in your work, we are always happy to receive a [citation](#bibtex).

+## Retrieval-Augmented Diffusion Models
+![rdm-figure](assets/rdm-preview.jpg)
+We include inference code to run our retrieval-augmented diffusion models (RDMs) as described in [https://arxiv.org/abs/2204.11824](https://arxiv.org/abs/2204.11824).
+
+To get started, install the additionally required Python packages into your `ldm` environment
+```shell
+pip install transformers==4.19.2 scann kornia==0.6.4 torchmetrics==0.6.0
+pip install git+https://github.com/arogozhnikov/einops.git
+```
+and download the trained weights (preliminary checkpoints):
+
+```bash
+mkdir -p models/rdm/rdm768x768/
+wget -O models/rdm/rdm768x768/model.ckpt https://ommer-lab.com/files/rdm/model.ckpt
+```
+As these models are conditioned on a set of CLIP image embeddings, our RDMs support different inference modes,
+which are described below.
+
+#### RDM with text-prompt only (no explicit retrieval needed)
+Since CLIP offers a shared image/text feature space, and RDMs learn to cover a neighborhood of a given
+example during training, we can directly take the CLIP text embedding of a given prompt and condition the model on it.
+Run this mode via
+```
+python scripts/knn2img.py --prompt "a happy bear reading a newspaper, oil on canvas"
+```
+
+#### RDM with text-to-image retrieval
+
+To run an RDM conditioned on a text prompt and, additionally, on images retrieved for this prompt, you will also need to download the corresponding retrieval database.
+We provide two distinct databases extracted from the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) and [ArtBench](https://github.com/liaopeiyuan/artbench) datasets.
+Interchanging the databases results in different capabilities of the model as visualized below, although the learned weights are the same in both cases.
+
+Download the retrieval databases, which contain the retrieval datasets ([OpenImages](https://storage.googleapis.com/openimages/web/index.html) (~11 GB) and [ArtBench](https://github.com/liaopeiyuan/artbench) (~82 MB)) compressed into CLIP image embeddings:
+```bash
+mkdir -p data/rdm/retrieval_databases
+wget -O data/rdm/retrieval_databases/artbench.zip https://ommer-lab.com/files/rdm/artbench_databases.zip
+wget -O data/rdm/retrieval_databases/openimages.zip https://ommer-lab.com/files/rdm/openimages_database.zip
+unzip data/rdm/retrieval_databases/artbench.zip -d data/rdm/retrieval_databases/
+unzip data/rdm/retrieval_databases/openimages.zip -d data/rdm/retrieval_databases/
+```
+We also provide trained [ScaNN](https://github.com/google-research/google-research/tree/master/scann) search indices for ArtBench.
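+For orientation, here is a minimal sketch (not part of the repo) of how such an index is queried; it mirrors `Searcher.search` in `scripts/knn2img.py`. The path assumes the ArtBench searchers downloaded in the next step, and the random query merely stands in for a real CLIP embedding:
+```python
+import numpy as np
+import scann
+
+# load one of the serialized ArtBench indices
+searcher = scann.scann_ops_pybind.load_searcher("data/rdm/searchers/artbench-surrealism")
+
+# queries are L2-normalized 768-d CLIP (ViT-L/14) embeddings of shape (n_queries, 768)
+queries = np.random.randn(2, 768).astype(np.float32)
+queries = queries / np.linalg.norm(queries, axis=1, keepdims=True)
+
+# per query: indices of the k nearest database entries plus their similarity scores
+nns, scores = searcher.search_batched(queries, final_num_neighbors=10)
+```
+At sampling time, `scripts/knn2img.py` issues the same call with the CLIP text embedding of your prompt and uses the returned indices to look up the neighbors' CLIP image embeddings in the retrieval database.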
+Download and extract the ArtBench searchers via
+```bash
+mkdir -p data/rdm/searchers
+wget -O data/rdm/searchers/artbench.zip https://ommer-lab.com/files/rdm/artbench_searchers.zip
+unzip data/rdm/searchers/artbench.zip -d data/rdm/searchers
+```
+
+Since the index for OpenImages is large (~21 GB), we provide a script to create and save it for use during sampling. Note, however,
+that sampling with the OpenImages database will not be possible without this index. Run the script via
+```bash
+python scripts/train_searcher.py
+```
+
+Retrieval-based text-guided sampling with visual nearest neighbors can be started via
+```
+python scripts/knn2img.py --prompt "a happy pineapple" --use_neighbors --knn <number_of_neighbors>
+```
+Note that the maximum supported number of neighbors is 20.
+The database can be changed via the command-line parameter `--database`, which can be one of `[openimages, artbench-art_nouveau, artbench-baroque, artbench-expressionism, artbench-impressionism, artbench-post_impressionism, artbench-realism, artbench-renaissance, artbench-romanticism, artbench-surrealism, artbench-ukiyo_e]`.
+To use `--database openimages`, the script above (`scripts/train_searcher.py`) must be executed beforehand.
+Due to their relatively small size, the ArtBench databases are best suited for creating more abstract concepts and do not work well for detailed text control.
+
+
+#### Coming Soon
+- better models
+- more resolutions
+- image-to-image retrieval
+
 ## Text-to-Image
 ![text2img-figure](assets/txt2img-preview.png)

@@ -273,6 +344,19 @@ Thanks for open-sourcing!
       archivePrefix={arXiv},
       primaryClass={cs.CV}
 }
+
+@misc{blattmann2022retrieval,
+      title={Retrieval-Augmented Diffusion Models},
+      author={Blattmann, Andreas and Rombach, Robin and Oktay, Kaan and Ommer, Björn},
+      year={2022},
+      eprint={2204.11824},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV}
+}
+
+
 ```
diff --git a/assets/rdm-preview.jpg b/assets/rdm-preview.jpg
new file mode 100644
index 000000000..3838b0f6b
Binary files /dev/null and b/assets/rdm-preview.jpg differ
diff --git a/configs/retrieval-augmented-diffusion/768x768.yaml b/configs/retrieval-augmented-diffusion/768x768.yaml
new file mode 100644
index 000000000..b51b1d837
--- /dev/null
+++ b/configs/retrieval-augmented-diffusion/768x768.yaml
@@ -0,0 +1,68 @@
+model:
+  base_learning_rate: 0.0001
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.0015
+    linear_end: 0.015
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: jpg
+    cond_stage_key: nix
+    image_size: 48
+    channels: 16
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_by_std: false
+    scale_factor: 0.22765929
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 48
+        in_channels: 16
+        out_channels: 16
+        model_channels: 448
+        attention_resolutions:
+        - 4
+        - 2
+        - 1
+        num_res_blocks: 2
+        channel_mult:
+        - 1
+        - 2
+        - 3
+        - 4
+        use_scale_shift_norm: false
+        resblock_updown: false
+        num_head_channels: 32
+        use_spatial_transformer: true
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: true
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        monitor: val/rec_loss
+        embed_dim: 16
+        ddconfig:
+          double_z: true
+          z_channels: 16
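+          # f=16 autoencoder: 768x768 RGB inputs are encoded to 48x48x16 latents, matching image_size/channels in the diffusion params above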
resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 16 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: torch.nn.Identity \ No newline at end of file diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 506e39ba4..aa3031df6 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,6 +1,10 @@ import torch import torch.nn as nn from functools import partial +import clip +from einops import rearrange, repeat +import kornia + from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test @@ -129,3 +133,70 @@ def forward(self,x): def encode(self, x): return self(x) + + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. 
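+        # x is now in [0, 1]; the step below applies the CLIP preprocessing mean/std registered in __init__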
+ # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + diff --git a/ldm/util.py b/ldm/util.py index 51839cb14..8ba38853e 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -2,6 +2,13 @@ import torch import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue from inspect import isfunction from PIL import Image, ImageDraw, ImageFont @@ -38,7 +45,7 @@ def ismap(x): def isimage(x): - if not isinstance(x,torch.Tensor): + if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -64,7 +71,7 @@ def mean_flat(tensor): def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params @@ -83,4 +90,114 @@ def get_obj_from_str(string, reload=False): if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) \ No newline at end of file + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
+ ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/scripts/knn2img.py b/scripts/knn2img.py new file mode 100644 index 000000000..e6eaaecab --- /dev/null +++ b/scripts/knn2img.py @@ -0,0 +1,398 @@ +import argparse, os, sys, glob +import clip +import torch +import torch.nn as nn +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import scann +import time +from multiprocessing import cpu_count + +from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder + +DATABASES = [ + "openimages", + "artbench-art_nouveau", + "artbench-baroque", + "artbench-expressionism", + "artbench-impressionism", + "artbench-post_impressionism", + "artbench-realism", + "artbench-romanticism", + "artbench-renaissance", + "artbench-surrealism", + "artbench-ukiyo_e", +] + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): + assert database in DATABASES + # self.database = self.load_database(database) + self.database_name = database + self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' + self.database_path = 
f'data/rdm/retrieval_databases/{self.database_name}' + self.retriever = self.load_retriever(version=retriever_version) + self.database = {'embedding': [], + 'img_id': [], + 'patch_coords': []} + self.load_database() + self.load_searcher() + + def train_searcher(self, k, + metric='dot_product', + searcher_savedir=None): + + print('Start training searcher') + searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / + np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], + k, metric) + self.searcher = searcher.score_brute_force().build() + print('Finish training searcher') + + if searcher_savedir is not None: + print(f'Save trained searcher under "{searcher_savedir}"') + os.makedirs(searcher_savedir, exist_ok=True) + self.searcher.serialize(searcher_savedir) + + def load_single_file(self, saved_embeddings): + compressed = np.load(saved_embeddings) + self.database = {key: compressed[key] for key in compressed.files} + print('Finished loading of clip embeddings.') + + def load_multi_files(self, data_archive): + out_data = {key: [] for key in self.database} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + out_data[key].append(d[key]) + + return out_data + + def load_database(self): + + print(f'Load saved patch embedding from "{self.database_path}"') + file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + + if len(file_content) == 1: + self.load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(self.load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in + self.database} + else: + raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') + + def load_retriever(self, version='ViT-L/14', ): + model = FrozenClipImageEmbedder(model=version) + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + def load_searcher(self): + print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) + print('Finished loading searcher.') + + def search(self, x, k): + if self.searcher is None and self.database['embedding'].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + if len(x.shape) == 3: + x = x[:, 0] + query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis] + + start = time.time() + nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) + end = time.time() + + out_embeddings = self.database['embedding'][nns] + out_img_ids = self.database['img_id'][nns] + out_pc = self.database['patch_coords'][nns] + + out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings} + + return out + + def __call__(self, x, n): + return self.search(x, n) + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser() + # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) + # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + + parser.add_argument( + "--n_repeat", + type=int, + default=1, + help="number of repeats in CLIP latent space", + ) + + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", + ) + + parser.add_argument( + "--H", + type=int, + default=768, + help="image height, in pixel space", + ) + + parser.add_argument( + "--W", + type=int, + default=768, + help="image width, in pixel space", + ) + + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a batch size", + ) + + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + + parser.add_argument( + "--scale", + type=float, + default=5.0, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + + parser.add_argument( + "--config", + type=str, + default="configs/retrieval-augmented-diffusion/768x768.yaml", + help="path to config which constructs model", + ) + + parser.add_argument( + "--ckpt", + type=str, + default="models/rdm/rdm768x768/model.ckpt", + help="path to checkpoint of model", + ) + + parser.add_argument( + "--clip_type", + type=str, + default="ViT-L/14", + help="which CLIP model to use for retrieval and NN encoding", + ) + parser.add_argument( + "--database", + type=str, + default='artbench-surrealism', + choices=DATABASES, + help="The database used for the search, only applied when --use_neighbors=True", + ) + parser.add_argument( + "--use_neighbors", + default=False, + action='store_true', + help="Include neighbors in addition to text prompt for conditioning", + ) + parser.add_argument( + "--knn", + default=10, + type=int, + help="The number of included neighbors, only applied when --use_neighbors=True", + ) + + opt = parser.parse_args() + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert 
prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + print(f"sampling scale for cfg is {opt.scale:.2f}") + + searcher = None + if opt.use_neighbors: + searcher = Searcher(opt.database) + + with torch.no_grad(): + with model.ema_scope(): + for n in trange(opt.n_iter, desc="Sampling"): + all_samples = list() + for prompts in tqdm(data, desc="data"): + print("sampling prompts:", prompts) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = clip_text_encoder.encode(prompts) + uc = None + if searcher is not None: + nn_dict = searcher(c, opt.knn) + c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + if opt.scale != 1.0: + uc = torch.zeros_like(c) + if isinstance(prompts, tuple): + prompts = list(prompts) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/scripts/train_searcher.py b/scripts/train_searcher.py new file mode 100644 index 000000000..1e7904889 --- /dev/null +++ b/scripts/train_searcher.py @@ -0,0 +1,147 @@ +import os, sys +import numpy as np +import scann +import argparse +import glob +from multiprocessing import cpu_count +from tqdm import tqdm + +from ldm.util import parallel_data_prefetch + + +def search_bruteforce(searcher): + return searcher.score_brute_force().build() + + +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search): + return searcher.tree(num_leaves=num_leaves, + num_leaves_to_search=num_leaves_to_search, + training_sample_size=partioning_trainsize). 
\ + score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + + +def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): + return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( + reorder_k).build() + +def load_datapool(dpath): + + + def load_single_file(saved_embeddings): + compressed = np.load(saved_embeddings) + database = {key: compressed[key] for key in compressed.files} + return database + + def load_multi_files(data_archive): + database = {key: [] for key in data_archive[0].files} + for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for key in d.files: + database[key].append(d[key]) + + return database + + print(f'Load saved patch embedding from "{dpath}"') + file_content = glob.glob(os.path.join(dpath, '*.npz')) + + if len(file_content) == 1: + data_pool = load_single_file(file_content[0]) + elif len(file_content) > 1: + data = [np.load(f) for f in file_content] + prefetched_data = parallel_data_prefetch(load_multi_files, data, + n_proc=min(len(data), cpu_count()), target_data_type='dict') + + data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + else: + raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') + + print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') + return data_pool + + +def train_searcher(opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None,): + + data_pool = load_datapool(opt.database) + k = opt.knn + + if not reorder_k: + reorder_k = 2 * k + + # normalize + # embeddings = + searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + pool_size = data_pool['embedding'].shape[0] + + print(*(['#'] * 100)) + print('Initializing scaNN searcher with the following values:') + print(f'k: {k}') + print(f'metric: {metric}') + print(f'reorder_k: {reorder_k}') + print(f'anisotropic_quantization_threshold: {aiq_thld}') + print(f'dims_per_block: {dims_per_block}') + print(*(['#'] * 100)) + print('Start training searcher....') + print(f'N samples in pool is {pool_size}') + + # this reflects the recommended design choices proposed at + # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md + if pool_size < 2e4: + print('Using brute force search.') + searcher = search_bruteforce(searcher) + elif 2e4 <= pool_size and pool_size < 1e5: + print('Using asymmetric hashing search and reordering.') + searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + else: + print('Using using partioning, asymmetric hashing search and reordering.') + + if not partioning_trainsize: + partioning_trainsize = data_pool['embedding'].shape[0] // 10 + if not num_leaves: + num_leaves = int(np.sqrt(pool_size)) + + if not num_leaves_to_search: + num_leaves_to_search = max(num_leaves // 20, 1) + + print('Partitioning params:') + print(f'num_leaves: {num_leaves}') + print(f'num_leaves_to_search: {num_leaves_to_search}') + # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) + searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, + partioning_trainsize, num_leaves, num_leaves_to_search) + + 
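+    # the index serialized below is what Searcher.load_searcher() in scripts/knn2img.py expects to find under data/rdm/searchers/<database>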
print('Finish training searcher') + searcher_savedir = opt.target_path + os.makedirs(searcher_savedir, exist_ok=True) + searcher.serialize(searcher_savedir) + print(f'Saved trained searcher under "{searcher_savedir}"') + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--database', + '-d', + default='data/rdm/retrieval_databases/openimages', + type=str, + help='path to folder containing the clip feature of the database') + parser.add_argument('--target_path', + '-t', + default='data/rdm/searchers/openimages', + type=str, + help='path to the target folder where the searcher shall be stored.') + parser.add_argument('--knn', + '-k', + default=20, + type=int, + help='number of nearest neighbors, for which the searcher shall be optimized') + + opt, _ = parser.parse_known_args() + + train_searcher(opt,) \ No newline at end of file