Authors: Yang Chen, Xiaowei Xu, Shuai Wang, Chenhui Zhu, Ruxue Wen, Xubin Li, Tiezheng Ge, Limin Wang
📧 Primary Contact: [email protected]
This repository hosts the official open-source implementation for FlowBack.
FlowBack introduces a novel alignment strategy for Normalizing Flows (NFs). It works by aligning the features of the generative (reverse) pass with those from a pretrained vision encoder. This approach yields significant improvements in both generative quality and classification accuracy.
FlowBack is a joint project by Nanjing University and Alibaba Group.
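Conceptually, the alignment objective resembles a cosine-similarity feature-matching loss between projected flow features and frozen encoder features. The sketch below is illustrative only; the projection head `proj` and the tensor shapes are assumptions, not the exact implementation:

```python
import torch
import torch.nn.functional as F

def alignment_loss(flow_feats, encoder_feats, proj):
    """Align intermediate flow features with frozen encoder features.

    flow_feats:    (B, N, D_flow) features from the generative (reverse) pass
    encoder_feats: (B, N, D_enc)  features from a frozen pretrained vision encoder
    proj:          small trainable head mapping D_flow -> D_enc (assumed)
    """
    z = F.normalize(proj(flow_feats), dim=-1)   # project into the encoder's space
    t = F.normalize(encoder_feats, dim=-1)
    # negative cosine similarity, averaged over tokens and batch
    return -(z * t).sum(dim=-1).mean()
```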
- 🚀 Model Training: Easily train our model from scratch on your own datasets.
- 📊 FID Evaluation: Calculate the Fréchet Inception Distance (FID) to measure image generation quality.
- 🎯 Training-Free Classification: Reproduce the paper's classification accuracy metric.
- ✨ Linear Probing: Evaluate the quality of learned representations through linear probing on intermediate features.
Clone the repository:
```bash
git clone https://github.com/MCG-NJU/FlowBack.git
cd FlowBack
```
Create a virtual environment (recommended) and install the dependencies:

```bash
conda create -n flowback python=3.10
conda activate flowback
pip install -r requirements.txt
```
This section outlines the main workflows for using this repository: training a new model and evaluating it using various metrics.
To train the model, use the provided `train.sh` script:

```bash
bash scripts/train.sh
```

After training, you can evaluate your model using the following metrics.
To evaluate the Fréchet Inception Distance (FID), you first need a pre-computed statistics file for the target dataset (e.g., ImageNet).
- Prepare FID stats:

  ```bash
  python prepare_fid_stats.py --data /path/to/imagenet/
  ```

- Calculate FID:

  ```bash
  bash scripts/local-eval.sh
  ```
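For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. Given the pre-computed statistics (means and covariances), it can be computed as in this sketch:

```python
import numpy as np
from scipy import linalg

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between two Gaussians:
    ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})
    """
    diff = mu1 - mu2
    # matrix square root of the covariance product
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerics
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```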
To compute the classification accuracy metric proposed in our paper, run the `classify.sh` script with your trained model checkpoint.
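Conceptually, training-free classification with a class-conditional flow predicts the class under which the input has the highest log-likelihood. The sketch below assumes a hypothetical `model.log_prob(x, y)` API and is not the repository's exact interface:

```python
import torch

@torch.no_grad()
def classify(model, images, num_classes):
    """Training-free classification: predict argmax_y log p(x | y).

    `model.log_prob(x, y)` (hypothetical) returns per-sample log-likelihoods
    of the batch `x` under class label `y`.
    """
    batch = images.shape[0]
    scores = torch.stack([
        model.log_prob(images, torch.full((batch,), y, dtype=torch.long))
        for y in range(num_classes)
    ], dim=1)                      # (B, num_classes)
    return scores.argmax(dim=1)
```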
```bash
bash scripts/classify.sh
```

Linear probing is a two-step process to evaluate the quality of the model's internal representations.
- Step 1: Cache Features. First, use `cache_repa_features.py` to extract and save the intermediate features from your trained model.

  ```bash
  accelerate launch cache_repa_features.py
  ```
- Step 2: Train Linear Probe. Once the features are cached, run `lp.py` to train a linear classifier on top of these features and compute the final classification accuracy.

  ```bash
  accelerate launch lp.py
  ```
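The two steps above amount to fitting a linear classifier on frozen, cached features. A minimal sketch, with illustrative feature shapes and hyperparameters:

```python
import torch
import torch.nn as nn

def train_linear_probe(feats, labels, num_classes, epochs=10, lr=1e-3):
    """Fit a linear classifier on cached features.

    feats:  (N, D) frozen features extracted from the model
    labels: (N,)   integer class labels
    Returns the trained probe and its accuracy on `feats`.
    """
    probe = nn.Linear(feats.shape[1], num_classes)
    opt = torch.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = nn.functional.cross_entropy(probe(feats), labels)
        loss.backward()
        opt.step()
    acc = (probe(feats).argmax(dim=1) == labels).float().mean().item()
    return probe, acc
```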
This project is built upon TARFlow and FlowDCN. Thanks to the contributors of these great codebases.
