Add model arcinstitute state #39480

Paper for reference! https://www.biorxiv.org/content/10.1101/2025.06.26.661135v2

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

[For maintainers] Suggested jobs to run (before merge) run-slow: auto

@abhinadduri for ref

cc @Cyrilvallez can you have a look?!

Hey! I'm a bit confused with the PR right now! What model are we adding? If it's

Hey @Cyrilvallez, Arc has created 2 models. We will clean up the PR before pinging for a proper review!

Nice, thanks for the explanations @FL33TW00D! Makes sense! Let me know when you believe this is ready for review then 🤗 Just a heads-up that we want the folder/file names as full snake_case, e.g.

thanks everyone! we are starting our review now, cc @Rive-001

class StateTxConfig(PretrainedConfig):
    r"""
    Configuration class for StateTx (State Transformer) model based on PertSetsPerturbationModel.
Low priority: We renamed the model class from PertSetsPerturbationModel to StateTransitionPerturbationModel.
    r"""
    Configuration class for StateTx (State Transformer) model based on PertSetsPerturbationModel.

    This model uses a bidirectional Llama transformer backbone to process perturbation data.
Low priority: We currently support bi-directional versions of both Llama and GPT-2.
from .configuration_state_tx import LlamaBidirectionalConfig, StateTxConfig


class SamplesLoss(nn.Module):
We currently support MSE and MMD loss functions. We use SamplesLoss from the geomloss library. See the SamplesLoss documentation.
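To illustrate the difference between the two loss types mentioned above, here is a minimal sketch in plain PyTorch. It is not the library implementation and not the code in this PR: `energy_mmd` is a hypothetical helper that computes a simple energy-distance form of MMD, assuming the kernel k(a, b) = -||a - b||, and is shown next to MSE only for comparison.

```python
import torch
import torch.nn.functional as F


def energy_mmd(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Energy-distance MMD between sample sets x (n, d) and y (m, d).

    Hypothetical stand-in for a set-level loss, using k(a, b) = -||a - b||.
    """
    d_xy = torch.cdist(x, y).mean()  # mean cross-set distance
    d_xx = torch.cdist(x, x).mean()  # mean within-set distance for x
    d_yy = torch.cdist(y, y).mean()  # mean within-set distance for y
    return d_xy - 0.5 * d_xx - 0.5 * d_yy


torch.manual_seed(0)
predictions = torch.randn(64, 16)
targets = torch.randn(64, 16) + 2.0  # same shape, shifted distribution

# MSE compares row i of predictions with row i of targets.
mse = F.mse_loss(predictions, targets)

# MMD compares the two sets as distributions, so row order does not matter.
mmd = energy_mmd(predictions, targets)
mmd_shuffled = energy_mmd(predictions, targets[torch.randperm(64)])
```

The set-level loss is unchanged when the target rows are shuffled, which is the property that makes it a candidate when predicted and observed cells are not paired one to one.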
        return F.mse_loss(predictions, targets)


class LatentToGeneDecoder(nn.Module):
We currently support user input for number of layers in the decoder, dimensions of those layers and optional residual connections between layers.
    def __init__(self, config: StateTxConfig):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(config.gene_dim, 1024, bias=True),
This doesn't seem correct, as the input dimension would be a latent dimension and not the gene dimensions. The gene dimension is the output dimension.
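A minimal sketch of the shape this comment asks for, assuming a plain MLP decoder. The class name `LatentToGeneDecoderSketch` and the arguments `latent_dim`, `hidden_dims`, and `residual` are hypothetical and only illustrate the direction latent -> hidden layers -> genes. They are not the PR's or the reference implementation's actual signature.

```python
import torch
from torch import nn


class LatentToGeneDecoderSketch(nn.Module):
    """Hypothetical MLP decoder mapping latent_dim -> hidden_dims -> gene_dim."""

    def __init__(self, latent_dim, gene_dim, hidden_dims=(1024, 1024), residual=False):
        super().__init__()
        self.residual = residual
        dims = [latent_dim, *hidden_dims]
        # One Linear + activation block per configured hidden layer.
        self.hidden = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(d_in, d_out, bias=True), nn.GELU())
                for d_in, d_out in zip(dims[:-1], dims[1:])
            ]
        )
        # Final projection: the gene dimension is the OUTPUT size.
        self.out = nn.Linear(dims[-1], gene_dim, bias=True)

    def forward(self, latent):
        h = latent
        for block in self.hidden:
            y = block(h)
            # Add a skip connection only where input and output shapes match.
            h = h + y if self.residual and y.shape == h.shape else y
        return self.out(h)


decoder = LatentToGeneDecoderSketch(
    latent_dim=128, gene_dim=2000, hidden_dims=(256, 256), residual=True
)
genes = decoder(torch.randn(4, 128))  # shape: (4, 2000)
```

In this sketch the first `nn.Linear` takes the latent size as its input features, and only the last layer refers to the gene dimension.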
        # batch_embeds = self.batch_encoder(batch_ids)  # (batch_size, hidden_dim)
        # # Add batch embedding to each position
        # combined_input = combined_input + batch_embeds.unsqueeze(1)
        batch_embeddings = self.batch_encoder(torch.zeros([512]).long()).unsqueeze(1)
Hardcoding batch embeddings to 0s might not be correct.
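One possible alternative, sketched under the assumption that `batch_encoder` is an `nn.Embedding` and that index 0 is an acceptable fallback when no batch IDs are supplied. The helper name `batch_embeddings_for` is hypothetical and not part of the PR.

```python
import torch
from torch import nn


def batch_embeddings_for(batch_encoder: nn.Embedding, batch_ids, batch_size: int) -> torch.Tensor:
    """Return batch embeddings of shape (batch_size, 1, hidden_dim)."""
    if batch_ids is None:
        # Assumed fallback: use index 0, sized from the actual batch
        # instead of a hardcoded 512.
        batch_ids = torch.zeros(
            batch_size, dtype=torch.long, device=batch_encoder.weight.device
        )
    # unsqueeze(1) lets the embedding broadcast over the sequence dimension.
    return batch_encoder(batch_ids).unsqueeze(1)


encoder = nn.Embedding(num_embeddings=8, embedding_dim=16)
with_ids = batch_embeddings_for(encoder, torch.tensor([0, 3, 5]), batch_size=3)
without_ids = batch_embeddings_for(encoder, None, batch_size=3)
```

With real IDs each sample gets its own batch embedding. The fallback reproduces the all-zeros behaviour of the current code while still following the actual batch size.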
        # Binary classification decoder
        # binary_input_dim = config.output_dim + config.d_model + config.z_dim_rd + config.z_dim_ds
        binary_input_dim = 4107
We might want the dimensions of this decoder to be based on the config values.
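A small sketch of deriving the size from config fields rather than a literal, following the commented-out expression in the diff. The `DecoderDims` dataclass and its default values are placeholders for illustration only. They are not taken from the PR or from any checkpoint, and they are not meant to reproduce the hardcoded value.

```python
from dataclasses import dataclass


@dataclass
class DecoderDims:
    # Placeholder values for illustration; not taken from the PR or a checkpoint.
    output_dim: int = 2048
    d_model: int = 512
    z_dim_rd: int = 1
    z_dim_ds: int = 10

    @property
    def binary_input_dim(self) -> int:
        # Mirrors the commented-out expression in the diff.
        return self.output_dim + self.d_model + self.z_dim_rd + self.z_dim_ds


dims = DecoderDims()
print(dims.binary_input_dim)  # 2571 for the placeholder values above
```

Computing the value in one place keeps the decoder consistent if any of the contributing config fields change.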
This PR adds the arc state model
Run embedding model via transformers
sanity.py
outputs
Compare to reference
next, apply this small patch so we can run the model file directly with a fixed input to compare with the impl above
file `compare.patch`
can be applied like
# save above as compare.patch
git apply compare.patch
run the model
output
!!! Using Flash Attention !!!
Input sum: tensor(-38.6611)
Mask sum: tensor(2560)
Output shape: tensor(-19.6819, grad_fn=<SumBackward0>)