Skip to content

Add model arcinstitute state#39480

Draft
drbh wants to merge 8 commits into
mainfrom
add-model-arcinstitute-state
Draft

Add model arcinstitute state#39480
drbh wants to merge 8 commits into
mainfrom
add-model-arcinstitute-state

Conversation

@drbh

@drbh drbh commented Jul 17, 2025

Copy link
Copy Markdown
Contributor

This PR adds the arc state model

Run embedding model via transformers

git clone https://github.com/huggingface/transformers
git checkout add-model-arcinstitute-state
uv run sanity.py

sanity.py

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "torch",
#     "transformers"
# ]
#
# [tool.uv.sources]
# transformers = { path = ".", editable = true }
# ///
import torch
from transformers import StateEmbeddingModel


model_name = "arcinstitute/SE-600M"
model = StateEmbeddingModel.from_pretrained(model_name)

torch.manual_seed(0)
input_ids = torch.randn((1, 1, 5120), dtype=torch.float32)
mask = torch.ones((1, 1, 5120), dtype=torch.bool)
mask[:, :, 2560:] = False # simulate half masking
print("Input sum:\t", input_ids.sum())
print("Mask sum:\t", mask.sum())

outputs = model(input_ids, mask)
print("Output sum:\t", outputs["gene_output"].sum())

outputs

Input sum:	 tensor(-38.6611)
Mask sum:	 tensor(2560)
Output sum:	 tensor(-19.6819, grad_fn=<SumBackward0>)

Compare to reference

git clone https://github.com/ArcInstitute/state.git
cd state
curl -OL https://huggingface.co/arcinstitute/SE-600M/resolve/main/se600m_epoch16.ckpt

next, apply this small patch so we can run the model file directly with a fixed input to compare with the impl above

file `compare.patch`
diff --git a/src/state/emb/nn/model.py b/src/state/emb/nn/model.py
index dbbefb3..42167a1 100644
--- a/src/state/emb/nn/model.py
+++ b/src/state/emb/nn/model.py
@@ -23,20 +23,20 @@ from torch.nn import TransformerEncoder, TransformerEncoderLayer, BCEWithLogitsL
 from tqdm.auto import tqdm
 from torch.optim.lr_scheduler import ChainedScheduler, LinearLR, CosineAnnealingLR, ReduceLROnPlateau
 
-from ..data import create_dataloader
-from ..utils import (
+from state.emb.data import create_dataloader
+from state.emb.utils import (
     compute_gene_overlap_cross_pert,
     get_embedding_cfg,
     get_dataset_cfg,
     compute_pearson_delta,
     compute_perturbation_ranking_score,
 )
-from ..eval.emb import cluster_embedding
-from .loss import WassersteinLoss, KLDivergenceLoss, MMDLoss, TabularLoss
+from state.emb.eval.emb import cluster_embedding
+from loss import WassersteinLoss, KLDivergenceLoss, MMDLoss, TabularLoss
 
 
-from .flash_transformer import FlashTransformerEncoderLayer
-from .flash_transformer import FlashTransformerEncoder
+from flash_transformer import FlashTransformerEncoderLayer
+from flash_transformer import FlashTransformerEncoder
 
 
 class SkipBlock(nn.Module):
@@ -196,7 +196,8 @@ class StateEmbeddingModel(L.LightningModule):
             self.dataset_embedder = nn.Linear(output_dim, 10)
 
             # Assume self.cfg.model.num_datasets is set to the number of unique datasets.
-            num_dataset = get_dataset_cfg(self.cfg).num_datasets
+            # num_dataset = get_dataset_cfg(self.cfg).num_datasets
+            num_dataset = 14420 
             self.dataset_encoder = nn.Sequential(
                 nn.Linear(output_dim, d_model),
                 nn.SiLU(),
@@ -686,3 +687,18 @@ class StateEmbeddingModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss", "interval": "step", "frequency": 1},
         }
+
+if __name__ == "__main__":
+    checkpoint = "/Users/drbh/Projects/state/se600m_epoch16.ckpt"
+    model = StateEmbeddingModel.load_from_checkpoint(checkpoint, dropout=0.0, strict=False)
+
+    torch.manual_seed(0)
+
+    input_ids = torch.randn((1, 1, 5120), dtype=torch.float32)
+    mask = torch.ones((1, 1, 5120), dtype=torch.bool)
+    mask[:, :, 2560:] = False
+    print("Input sum:\t", input_ids.sum())
+    print("Mask sum:\t", mask.sum())
+
+    output, embedding, dataset_emb = model(input_ids, mask)
+    print("Output shape:\t", output.sum())

can be applied like

# save above as compare.patch
git apply compare.patch

run the model

.venv/bin/python src/state/emb/nn/model.py

output

!!! Using Flash Attention !!!
Input sum: tensor(-38.6611)
Mask sum: tensor(2560)
Output shape: tensor(-19.6819, grad_fn=<SumBackward0>)

@FL33TW00D

Copy link
Copy Markdown

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto

@FL33TW00D

Copy link
Copy Markdown

@abhinadduri for ref

@ArthurZucker

Copy link
Copy Markdown
Collaborator

cc @Cyrilvallez camn you have a look?!

@Cyrilvallez

Copy link
Copy Markdown
Member

Hey! I'm a bit confused with the PR right now! What model are we adding? If it's arcinstitute, we only want modeling file refering to it! No general names such as StateEmbedding!
But I see that you incorporated modular, which is very good! So let's fix a bit/clarify the model names, fix the consistency issues and then we'll be good for a first review! 🤗

@FL33TW00D

Copy link
Copy Markdown

Hey @Cyrilvallez,
Thanks for taking a look!

Arc has created 2 models, StateEmbedding and StateTransition which would be good to add. They are the first group to surpass linear baselines for this problem.

We will clean up the PR before pinging for a proper review!

@Cyrilvallez

Copy link
Copy Markdown
Member

Nice, thanks for the explanations @FL33TW00D! Makes sense! Let me know when you believe this is ready for review then 🤗 Just a heads-up that we want the folder/files names as full snake_case, e.g. state_embedding, and the class name all prefixed with CamelCase, e.g. StateEmbeddingModule! 👌

@abhinadduri

Copy link
Copy Markdown

thanks everyone! we are starting our review now, cc @Rive-001


class StateTxConfig(PretrainedConfig):
r"""
Configuration class for StateTx (State Transformer) model based on PertSetsPerturbationModel.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low priority: We renamed the model class from PertSetsPerturbationModel to StateTransitionPerturbationModel.

r"""
Configuration class for StateTx (State Transformer) model based on PertSetsPerturbationModel.

This model uses a bidirectional Llama transformer backbone to process perturbation data.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low priority: We currently support bi-directional versions of both Llama and GPT-2.

from .configuration_state_tx import LlamaBidirectionalConfig, StateTxConfig


class SamplesLoss(nn.Module):

@Rive-001 Rive-001 Aug 11, 2025

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return F.mse_loss(predictions, targets)


class LatentToGeneDecoder(nn.Module):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We currently support user input for number of layers in the decoder, dimensions of those layers and optional residual connections between layers.

https://github.com/ArcInstitute/state/blob/be0006c4556327431bda29b6db1b7b223d9eda8c/src/state/tx/models/base.py#L15-L116

def __init__(self, config: StateTxConfig):
super().__init__()
self.decoder = nn.Sequential(
nn.Linear(config.gene_dim, 1024, bias=True),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem correct, as the input dimension would be a latent dimension and not the gene dimensions. The gene dimension is the output dimension.

https://github.com/ArcInstitute/state/blob/be0006c4556327431bda29b6db1b7b223d9eda8c/src/state/tx/models/base.py#L47

# batch_embeds = self.batch_encoder(batch_ids) # (batch_size, hidden_dim)
# # Add batch embedding to each position
# combined_input = combined_input + batch_embeds.unsqueeze(1)
batch_embeddings = self.batch_encoder(torch.zeros([512]).long()).unsqueeze(1)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardcoding batch embeddings to 0s might not be correct.

@ArthurZucker ArthurZucker requested review from Cyrilvallez and removed request for ArthurZucker August 13, 2025 09:02

# Binary classification decoder
# binary_input_dim = config.output_dim + config.d_model + config.z_dim_rd + config.z_dim_ds
binary_input_dim = 4107

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want the dimensions of this decoder to be based on the config values.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants