Closed
96 commits
24c7052
(hacky) allow attention_head_dim as Tuple
patil-suraj Nov 7, 2022
87e61c8
add pipeline
patil-suraj Nov 7, 2022
36feb3a
add in init
patil-suraj Nov 7, 2022
686d924
add in init
patil-suraj Nov 7, 2022
e96d564
fix elcip encode, default args
patil-suraj Nov 7, 2022
b09b8d7
allow passing retrieved_images
patil-suraj Nov 7, 2022
a502823
remove negative prompt
patil-suraj Nov 7, 2022
a04ca64
fix docs
patil-suraj Nov 7, 2022
86ea213
Starting retriever
isamu-isozaki Dec 27, 2022
db3f2ec
Adding retriever class and jupyter notebook for test
isamu-isozaki Dec 28, 2022
d4dfad9
Modified nb
isamu-isozaki Dec 28, 2022
19fc4c4
Made retriever!
isamu-isozaki Dec 29, 2022
ff40a90
Began working on training script
isamu-isozaki Dec 30, 2022
8d68c9c
Removed wandb
isamu-isozaki Jan 1, 2023
2c1114b
Removed unrelavant files
isamu-isozaki Jan 1, 2023
9645bb4
Merged with main
isamu-isozaki Jan 1, 2023
e6289fe
Merged with main
isamu-isozaki Jan 1, 2023
faa57b6
Started writing training script and adding clip-retrieval
isamu-isozaki Jan 2, 2023
e7679c8
Training setup
isamu-isozaki Jan 6, 2023
e921cea
Setup clip-retrieval training
isamu-isozaki Jan 7, 2023
2da8fb3
Attempt training
isamu-isozaki Jan 7, 2023
f4a33c1
Fixing precision error
isamu-isozaki Jan 7, 2023
097515a
Training fix up
isamu-isozaki Jan 13, 2023
651546c
Got config of model to fit in 6gb
isamu-isozaki Jan 16, 2023
b975fea
Adding run instructions
isamu-isozaki Jan 16, 2023
25377eb
Basic code for logging
isamu-isozaki Jan 16, 2023
20e3e5f
Added logging functionality
isamu-isozaki Jan 16, 2023
1700925
keep unet in fp32
isamu-isozaki Jan 17, 2023
5695768
Added revision dtype
isamu-isozaki Jan 17, 2023
1064168
Allow for multiple batch size
isamu-isozaki Jan 17, 2023
fc2abb0
Fixed concat bug
isamu-isozaki Jan 17, 2023
54f7664
Fixed concat bug #2
isamu-isozaki Jan 17, 2023
497b1ba
Added support for clip to fp16 embeddings
isamu-isozaki Jan 17, 2023
74e7e36
Hot fix vae scaling factor
isamu-isozaki Jan 17, 2023
5ca4321
Made basic colossalai script
isamu-isozaki Jan 17, 2023
b6bbde2
Merge branch 'main' of https://github.com/huggingface/diffusers into …
isamu-isozaki Jan 17, 2023
35d0eaf
Added colossalai
isamu-isozaki Jan 17, 2023
49232f9
Updated placement
isamu-isozaki Jan 17, 2023
17fdd53
Added context
isamu-isozaki Jan 17, 2023
dd3efb6
Removed accelerate from colossal
isamu-isozaki Jan 18, 2023
e362d7b
Fixed logger
isamu-isozaki Jan 18, 2023
cee6f05
Update readme
isamu-isozaki Jan 18, 2023
1153aa2
Pipeline fix attempt
isamu-isozaki Jan 18, 2023
09966dc
Pipeline fix attempt 2
isamu-isozaki Jan 18, 2023
1af2106
Logging devices
isamu-isozaki Jan 18, 2023
feb820c
Moving to unet device
isamu-isozaki Jan 18, 2023
bbc4bbd
Changed self.device to self.unet.device
isamu-isozaki Jan 18, 2023
db3391b
logging dtype
isamu-isozaki Jan 18, 2023
0750c26
Fixed bug
isamu-isozaki Jan 18, 2023
87ade1d
Returned autocast
isamu-isozaki Jan 18, 2023
43a332e
Fixed autocast
isamu-isozaki Jan 18, 2023
257562c
Using the none-static unet
isamu-isozaki Jan 18, 2023
da6adc0
Switch back to self.device
isamu-isozaki Jan 18, 2023
6772547
Hotfix device issue
isamu-isozaki Jan 18, 2023
1a59124
Made pipeline for colossalai
isamu-isozaki Jan 18, 2023
a79105f
Added debugging prints
isamu-isozaki Jan 18, 2023
703ca5c
Attempt converting back to normal tensor
isamu-isozaki Jan 18, 2023
ea1c760
Updated gitignore and readme
isamu-isozaki Jan 20, 2023
87430c9
Updated pipeline+started making inference script
isamu-isozaki Jan 22, 2023
ac29f83
Made basic adding code
isamu-isozaki Jan 22, 2023
faa13f1
Fixed typo
isamu-isozaki Jan 22, 2023
31212c6
Updated unet device
isamu-isozaki Jan 22, 2023
bd4c798
Added some images+updated readme
isamu-isozaki Jan 22, 2023
51e7b09
Update readme
isamu-isozaki Jan 23, 2023
4e966c1
Set default dataset to indexed oxford_pets
isamu-isozaki Jan 24, 2023
33d2bc1
Update db in readme
isamu-isozaki Jan 24, 2023
59792e5
Fixed rdm pipeline bug
isamu-isozaki Jan 24, 2023
1169038
Updated install instructions
isamu-isozaki Jan 24, 2023
b412723
Clone tensor for safety
isamu-isozaki Jan 24, 2023
1e6888f
Undid clone
isamu-isozaki Jan 25, 2023
821a76e
Modifying retriever so it's end to end
isamu-isozaki Feb 2, 2023
5511451
Made basic from_pretrained script
isamu-isozaki Feb 2, 2023
d945624
Removed wandb loggings
isamu-isozaki Feb 2, 2023
92fa8c4
Finished main funcs
isamu-isozaki Feb 2, 2023
c1b775a
Residual push
isamu-isozaki Feb 2, 2023
29fb937
Update logic
isamu-isozaki Feb 2, 2023
3c5195e
Made clip use optional
isamu-isozaki Feb 2, 2023
3762f73
Fix typo
isamu-isozaki Feb 2, 2023
539b6ff
Removing tokenizer
isamu-isozaki Feb 2, 2023
edbe294
Made clip more optional
isamu-isozaki Feb 2, 2023
2c15e97
Fixing for tests
isamu-isozaki Feb 3, 2023
7dc4542
Fixed bugs
isamu-isozaki Feb 3, 2023
b5d3125
Fixed wrong index bug
isamu-isozaki Feb 10, 2023
f0e7bf5
Added embedding func for any model
isamu-isozaki Feb 12, 2023
b4dba5a
Cleaned up code a bit
isamu-isozaki Feb 26, 2023
cb13a5c
Make retriever work with a general model
isamu-isozaki Feb 26, 2023
29223d8
Reformatting some code
isamu-isozaki Feb 26, 2023
5d1174e
More clean up
isamu-isozaki Feb 26, 2023
6118e0c
Removed device='cuda' and using model's device instead
isamu-isozaki Feb 26, 2023
0de7534
Removed clip retrieval+cleaning up inference script
isamu-isozaki Feb 26, 2023
d124f38
Fixed up inference script
isamu-isozaki Feb 27, 2023
80a677a
Cleaned up scripts a bit
isamu-isozaki Feb 27, 2023
72b0412
Removing frida pics altho they are cute
isamu-isozaki Mar 2, 2023
7643056
Removed training scripts+colossalai pipeline
isamu-isozaki Mar 22, 2023
e728f98
Ran black
isamu-isozaki Mar 22, 2023
a6dc7b5
Remove example
isamu-isozaki Mar 22, 2023
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -123,6 +123,7 @@
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
RDMPipeline,
)

try:
234 changes: 234 additions & 0 deletions src/diffusers/models/retriever.py
@@ -0,0 +1,234 @@
"""
Idea for structure
Retriever aggregates an Index class and a RetrieverConfig class
The Index class aggregates a Dataset and RetrieverConfig class
from_pretrained in the retriever's class, it takes in a huggingface path to a dataset, optional path to an index file+config file in huggingface if there is one
If an index file is provided, add that index to the dataset.
If the dataset doesn't have the column embedding or a corresponding index file, in the Index class, the index is computed based on the clip model defined in the config. Then add that to the index of the dataset. This is done in the Index class
In retrieve we just call the retrieve method in the Index class that gets knn based on the faiss embedding.
In the save_pretrained method, save index using save_faiss_index. Save this dataset along with config.
The call method will just call retrieve.
I'll also have a way to pass the clip model and its components via default arguments.
Test save_pretrained and from_pretrained methods on new dataset.
"""

from transformers import CLIPModel, CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig
from datasets import load_dataset, Image, load_dataset_builder, load_from_disk, Dataset
import torch
from typing import Callable, List, Optional, Union
import numpy as np
from ..utils import deprecate, logging
from transformers.models.rag.retrieval_rag import LegacyIndex, CustomHFIndex, CanonicalHFIndex

logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from diffusers.pipelines.rdm.pipeline_rdm import preprocess_images, normalize_images
import os
from torch.nn import Module


class IndexConfig(PretrainedConfig):
def __init__(
self,
clip_name_or_path="openai/clip-vit-large-patch14",
dataset_name="Isamu136/oxford_pets_with_l14_emb",
image_column="image",
index_name="embeddings",
index_path=None,
dataset_save_path=None,
dataset_set="train",
**kwargs,
):
super().__init__(**kwargs)
self.clip_name_or_path = clip_name_or_path
self.dataset_name = dataset_name
self.image_column = image_column
self.index_name = index_name
self.index_path = index_path
self.dataset_save_path = dataset_save_path
self.dataset_set = dataset_set


class Index:
"""
Each index for a retrieval model is specific to the clip model used and the dataset used.
"""

def __init__(self, config: IndexConfig, dataset: Dataset):
self.config = config
self.dataset = dataset
self.index_initialized = False
self.index_name = config.index_name
self.index_path = config.index_path
self.init_index()

def set_index_name(self, index_name: str):
self.index_name = index_name

def init_index(self):
if not self.index_initialized:
if self.index_path and self.index_name:
try:
self.dataset.load_faiss_index(self.index_name, self.index_path)
self.index_initialized = True
except:
logger.info("Index not initialized")
if self.index_name in self.dataset.features:
self.dataset.add_faiss_index(column=self.index_name)
self.index_initialized = True

def build_index(
self,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
torch_dtype=torch.float32,
):
if not self.index_initialized:
model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained(
self.config.clip_name_or_path
)
self.dataset = get_dataset_with_emb_from_model(
self.dataset,
model,
feature_extractor,
device=model.device,
image_column=self.config.image_column,
index_name=self.config.index_name,
)
self.init_index()

def retrieve_imgs(self, vec, k: int = 20):
vec = np.array(vec).astype(np.float32)
return self.dataset.get_nearest_examples(self.index_name, vec, k=k)

def retrieve_indices(self, vec, k: int = 20):
vec = np.array(vec).astype(np.float32)
return self.dataset.search(self.index_name, vec, k=k)


class Retriever:
def __init__(
self,
config: IndexConfig,
index: Index = None,
dataset: Dataset = None,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
):
self.config = config
self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)

@classmethod
def from_pretrained(
cls,
retriever_name_or_path: str,
index: Index = None,
dataset: Dataset = None,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
**kwargs,
):
config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
return cls(config, index=index, dataset=dataset, model=model, feature_extractor=feature_extractor)

@staticmethod
def _build_index(
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None
):
dataset = dataset or load_dataset(config.dataset_name)
dataset = dataset[config.dataset_set]
index = Index(config, dataset)
index.build_index(model=model, feature_extractor=feature_extractor)
return index

def save_pretrained(self, save_directory):
os.makedirs(save_directory, exist_ok=True)
if self.config.index_path is None:
index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
self.index.dataset.get_index(self.config.index_name).save(index_path)
self.config.index_path = index_path
if self.config.dataset_save_path is None:
dataset_save_path = os.path.join(save_directory, "hf_dataset")
# datasets don't support save_to_disk with indexes right now
faiss_index = self.index.dataset._indexes.pop(self.config.index_name)
self.index.dataset.save_to_disk(dataset_save_path)
self.index.dataset._indexes[self.config.index_name] = faiss_index
self.config.dataset_save_path = dataset_save_path
self.config.save_pretrained(save_directory)

def init_retrieval(self):
logger.info("initializing retrieval")
self.index.init_index()

def retrieve_imgs(self, embeddings: np.ndarray, k: int):
return self.index.retrieve_imgs(embeddings, k)

def retrieve_indices(self, embeddings: np.ndarray, k: int):
return self.index.retrieve_indices(embeddings, k)

def __call__(
self,
embeddings,
k: int = 20,
):
return self.index.retrieve_imgs(embeddings, k)


def map_txt_to_clip_feature(clip_model, tokenizer, prompt):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > tokenizer.model_max_length:
removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : tokenizer.model_max_length]
text_embeddings = clip_model.get_text_features(text_input_ids.to(clip_model.device))
text_embeddings = text_embeddings / torch.linalg.norm(text_embeddings, dim=-1, keepdim=True)
text_embeddings = text_embeddings[:, None, :]
return text_embeddings[0][0].cpu().detach().numpy()


def map_img_to_model_feature(model, feature_extractor, imgs):
for i, image in enumerate(imgs):
if not image.mode == "RGB":
imgs[i] = image.convert("RGB")
imgs = normalize_images(imgs)
retrieved_images = preprocess_images(imgs, feature_extractor).to(model.device)
image_embeddings = model(retrieved_images)
image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)
image_embeddings = image_embeddings[None, ...]
return image_embeddings


def get_dataset_with_emb_from_model(dataset, model, feature_extractor, image_column="image", index_name="embeddings"):
return dataset.map(
lambda example: {
index_name: map_img_to_model_feature(model, feature_extractor, [example[image_column]], model.device)
.cpu()
.detach()
.numpy()[0][0]
}
)


def get_dataset_with_emb_from_clip_model(
dataset, clip_model, feature_extractor, image_column="image", index_name="embeddings"
):
return dataset.map(
lambda example: {
index_name: map_img_to_model_feature(
clip_model.get_image_features, feature_extractor, [example[image_column]], clip_model.device
)
.cpu()
.detach()
.numpy()[0][0]
}
)
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -62,6 +62,7 @@
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
from .rdm import RDMPipeline

try:
if not is_onnx_available():
5 changes: 5 additions & 0 deletions src/diffusers/pipelines/rdm/__init__.py
@@ -0,0 +1,5 @@
from ...utils import is_torch_available, is_transformers_available


if is_transformers_available() and is_torch_available():
from .pipeline_rdm import RDMPipeline