Description
The CEBRA documentation is very comprehensive and explains the parameterization in great detail.
In its current form, however, it seems to focus on the scikit-learn API, and there is no example script for using the PyTorch API: https://cebra.ai/docs/usage.html
For many options, though, I am unsure how to set them through the scikit-learn API. For example, when using discrete behavioral data, it is currently not possible to specify empirical or discrete sampling:
CEBRA/cebra/data/single_session.py, line 89 in 0378db0:
prior: str = dataclasses.field(
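To make concrete what such a prior option could control, here is a minimal pure-Python sketch. It is not CEBRA code: sample_indices and its arguments are hypothetical names, and it assumes that "empirical" means drawing samples in proportion to label frequency while "uniform" means drawing every class about equally often.

```python
import random
from collections import Counter, defaultdict


def sample_indices(labels, num_samples, prior="empirical", seed=0):
    """Draw sample indices from discrete labels (hypothetical helper, not CEBRA code).

    prior="empirical": pick indices uniformly, so classes appear in
    proportion to their frequency in `labels`.
    prior="uniform": pick a class uniformly first, then an index within
    that class, so every class is drawn about equally often.
    """
    rng = random.Random(seed)
    if prior == "empirical":
        return [rng.randrange(len(labels)) for _ in range(num_samples)]
    if prior == "uniform":
        by_class = defaultdict(list)
        for index, label in enumerate(labels):
            by_class[label].append(index)
        classes = sorted(by_class)
        return [
            rng.choice(by_class[rng.choice(classes)]) for _ in range(num_samples)
        ]
    raise ValueError(f"unknown prior: {prior!r}")


# Imbalanced toy labels: 90% class 0, 10% class 1.
labels = [0] * 90 + [1] * 10
for prior in ("empirical", "uniform"):
    drawn = sample_indices(labels, 2000, prior=prior)
    print(prior, Counter(labels[i] for i in drawn))
```

With imbalanced labels, the two settings give visibly different class counts in the drawn samples, which is why being able to choose between them matters for discrete behavioral data.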
I think this is also intended, so as not to overload the cebra.Cebra initialization or the model.fit() function with too many parameters?
I therefore thought it might help to add a minimal example to usage.rst showing how a data loader with parameters that are not exposed by the scikit-learn API can be used with PyTorch directly:
import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch
neural_data = cebra.load_data(file="neural_data.npz", key="neural")
# continuous_label = cebra.load_data(
#     file="auxiliary_behavior_data.h5",
#     key="auxiliary_variables",
#     columns=["continuous1", "continuous2", "continuous3"],
# )
discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5",
    key="auxiliary_variables",
    columns=["discrete"],
)
# 1. Define Cebra Dataset
InputData = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    # continuous=torch.from_numpy(np.array(continuous_label)).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to("cpu")
# 2. Define Cebra Model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=InputData.input_dimension,
    num_units=32,
    num_output=2,
).to("cpu")
InputData.configure_for(neural_model)
# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
    # temperature=0.001,
    # min_temperature=0.0001
).to("cpu")
Opt = torch.optim.Adam(
    list(neural_model.parameters()) + list(Crit.parameters()),
    # lr=0.001,
    weight_decay=0,
)
# 4. Initialize Cebra Solver (wraps the model, criterion and optimizer)
cebra_model = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=Crit,
    optimizer=Opt,
    tqdm_on=True,
).to("cpu")
# 5. Define Data Loader
# loader = cebra.data.single_session.ContinuousDataLoader(
#     dataset=InputData, num_steps=1000, batch_size=200
# )
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=InputData, num_steps=1000, batch_size=200, prior="uniform"
)
# 6. Fit model
cebra_model.fit(loader=loader)
# 7. Transform embedding
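# Cut the recording into overlapping windows, one per output sample;
# the window length equals the model's offset length.
TrainBatches = np.lib.stride_tricks.sliding_window_view(
    neural_data, len(neural_model.get_offset()), axis=0
)
X_train_emb = cebra_model.transform(
    torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to("cpu")
).to("cpu")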
# 8. Potentially plot embedding
plot_embedding(
    X_train_emb,
    # Drop the first labels so that there is one label per window.
    discrete_label[len(neural_model.get_offset()) - 1 :, 0],
    markersize=10,
)
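Steps 7 and 8 depend on the shapes that sliding_window_view produces, so here is a small NumPy-only sketch of that bookkeeping. The toy array sizes and the window length of 10 are assumptions chosen for illustration; no CEBRA code is involved.

```python
import numpy as np

# Toy recording: 12 time steps, 3 channels (stand-in for neural_data).
num_steps, num_channels, window = 12, 3, 10
data = np.arange(num_steps * num_channels, dtype=np.float32).reshape(
    num_steps, num_channels
)
labels = np.arange(num_steps)

# The window axis is appended last, giving
# (num_steps - window + 1, num_channels, window).
windows = np.lib.stride_tricks.sliding_window_view(data, window, axis=0)
print(windows.shape)  # (3, 3, 10)

# Window i holds time steps i .. i + window - 1, laid out as (channels, time).
assert np.array_equal(windows[1], data[1 : 1 + window].T)

# One label per window, taken at the window's last time step.
window_labels = labels[window - 1 :]
assert len(window_labels) == len(windows)
```

The snippet above pairs each window with the label of its last sample, mirroring the slice used in step 8. Pairing with a different sample inside the window would only change where the label slice starts and ends, not the number of windows.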