
Flowing Backwards: Improving Normalizing Flows via Reverse Representation Alignment

Authors: Yang Chen, Xiaowei Xu, Shuai Wang, Chenhui Zhu, Ruxue Wen, Xubin Li, Tiezheng Ge, Limin Wang

📧 Primary Contact: [email protected]

🔍 Overview

This repository hosts the official open-source implementation for FlowBack.

FlowBack introduces a novel alignment strategy for Normalizing Flows (NFs). It works by aligning the features of the generative (reverse) pass with those from a pretrained vision encoder. This approach yields significant improvements in both generative quality and classification accuracy.
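The alignment idea can be illustrated with a minimal sketch. This is an illustrative assumption, not the repository's actual code: the function name and the cosine-similarity form of the loss are hypothetical, chosen only to show what "aligning reverse-pass features with a frozen encoder's features" means.

```python
import numpy as np

def reverse_alignment_loss(gen_feats, enc_feats):
    """Illustrative alignment loss: one minus the mean cosine similarity
    between features from the generative (reverse) pass and features
    from a frozen pretrained vision encoder. Lower is better aligned."""
    g = gen_feats / np.linalg.norm(gen_feats, axis=-1, keepdims=True)
    e = enc_feats / np.linalg.norm(enc_feats, axis=-1, keepdims=True)
    return 1.0 - float(np.mean(np.sum(g * e, axis=-1)))

# Identical features are perfectly aligned (loss ~0 up to float error).
f = np.random.randn(4, 16)
print(reverse_alignment_loss(f, f))
```

Minimizing such a term alongside the usual flow likelihood objective is what pushes the generative pass toward encoder-like representations.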

FlowBack is a joint project by Nanjing University and Alibaba Group.


🌟 Features

  • 🚀 Model Training: Easily train our model from scratch on your own datasets.
  • 📊 FID Evaluation: Calculate the Fréchet Inception Distance (FID) to measure image generation quality.
  • 🎯 Training-Free Classification: Reproduce the paper's classification accuracy metric.
  • ✨ Linear Probing: Evaluate the quality of learned representations through linear probing on intermediate features.

⚙️ Installation

  1. Clone the repository:

    git clone https://github.com/MCG-NJU/FlowBack.git
    cd FlowBack
  2. Create a virtual environment (recommended):

    conda create -n flowback python=3.10
    conda activate flowback 
    pip install -r requirements.txt

🚀 Usage

This section outlines the main workflows for using this repository: training a new model and evaluating it using various metrics.

1. Model Training

To train the model, use the provided train.sh script.

bash scripts/train.sh

2. Evaluation

After training, you can evaluate your model using the following metrics.

📊 a) FID Score Evaluation

To evaluate the Fréchet Inception Distance (FID), first generate the pre-computed statistics file for the target dataset (e.g., ImageNet).

  1. Prepare FID stats:

    python prepare_fid_stats.py --data /path/to/imagenet/ 
  2. Calculate FID:

    bash scripts/local-eval.sh
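For reference, the quantity the FID scripts compute is the Fréchet distance between two Gaussians fitted to Inception features (mean and covariance of real vs. generated samples). The sketch below is a minimal NumPy version of that formula, not the repository's implementation; it uses the identity Tr((S1·S2)^½) = Tr((S1^½·S2·S1^½)^½) so that only symmetric PSD square roots are needed.

```python
import numpy as np

def _sqrtm_psd(a):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)
    return (v * np.sqrt(w)) @ v.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between Gaussians N(mu1, sigma1) and N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))."""
    s1_half = _sqrtm_psd(sigma1)
    covmean = _sqrtm_psd(s1_half @ sigma2 @ s1_half)
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

# Stats come from feature arrays: mu = feats.mean(0), sigma = np.cov(feats, rowvar=False).
```

Identical statistics give a distance of zero; with equal identity covariances the distance reduces to the squared mean difference.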

🎯 b) Classification Accuracy

To compute the classification accuracy metric as proposed in our paper, run the classify.sh script with your trained model checkpoint.

bash scripts/classify.sh
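The generic idea behind training-free generative classification is to label each input with the class whose conditional model assigns it the highest log-likelihood (a Bayes classifier with a uniform prior). The paper's exact procedure lives in classify.sh; the snippet below is only a toy illustration of that decision rule, with made-up numbers.

```python
import numpy as np

def classify_by_likelihood(log_probs):
    """Training-free generative classification: pick, for each input,
    the class whose conditional model (e.g., a class-conditional
    normalizing flow) assigns it the highest log-likelihood.
    log_probs: array of shape (num_inputs, num_classes) holding log p(x | y)."""
    return np.argmax(log_probs, axis=1)

# Toy example: two inputs scored under three class-conditional models.
log_probs = np.array([[-10.0, -2.0, -5.0],
                      [-1.0, -4.0, -3.0]])
print(classify_by_likelihood(log_probs))  # [1 0]
```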

✨ c) Linear Probing

Linear probing is a two-step process to evaluate the quality of the model's internal representations.

  1. Step 1: Cache features. Use cache_repa_features.py to extract and save the intermediate features from your trained model.

    accelerate launch cache_repa_features.py 
  2. Step 2: Train the linear probe. Once the features are cached, run lp.py to train a linear classifier on top of these features and compute the final classification accuracy.

    accelerate launch lp.py 
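Conceptually, the probe in step 2 is just a linear classifier fitted to the frozen cached features, so its accuracy measures how linearly separable the representation is. Below is a self-contained multinomial logistic-regression sketch on synthetic features; it stands in for lp.py, which may differ in optimizer and details.

```python
import numpy as np

def train_linear_probe(feats, labels, num_classes, lr=0.1, steps=200):
    """Multinomial logistic regression on frozen (cached) features."""
    n, d = feats.shape
    w = np.zeros((d, num_classes))
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        logits = feats @ w
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        w -= lr * feats.T @ (probs - onehot) / n     # cross-entropy gradient
    return w

def probe_accuracy(w, feats, labels):
    return float(np.mean(np.argmax(feats @ w, axis=1) == labels))

# Toy check: two well-separated 2-D clusters are linearly separable.
rng = np.random.default_rng(0)
feats = np.vstack([rng.normal(-2, 0.3, (50, 2)), rng.normal(2, 0.3, (50, 2))])
labels = np.array([0] * 50 + [1] * 50)
w = train_linear_probe(feats, labels, 2)
print(probe_accuracy(w, feats, labels))  # 1.0 on this separable toy data
```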

💐 Acknowledgements

This project is built upon TARFlow and FlowDCN. Thanks to the contributors of these great codebases.

About

[AAAI 2026] Flowing Backwards: Improving Normalizing Flows via Reverse Representation Alignment
