Adding example script with custom Loader for PyTorch API documentation #97

@timonmerk

Description

The CEBRA documentation is very comprehensive and describes the parameterization in great detail.
In its current form, however, the focus seems to be on the scikit-learn API, and there is no example script for the PyTorch API: https://cebra.ai/docs/usage.html

But for many options I am unsure how to set them through the scikit-learn API. For example, when using discrete behavioral data, it is currently not possible to specify empirical or discrete sampling:

prior: str = dataclasses.field(
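
To make the difference between the sampling modes concrete, here is a minimal NumPy sketch. It is not CEBRA's implementation: the helper names `sample_empirical` and `sample_uniform` and the toy labels are hypothetical, and it assumes that "empirical" means drawing samples with the label frequencies found in the data, while "uniform" means drawing every class equally often.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical imbalanced discrete labels: 90 samples of class 0, 10 of class 1.
labels = np.array([0] * 90 + [1] * 10)


def sample_empirical(labels, n, rng):
    """Draw sample indices uniformly, so class frequencies follow the data."""
    return rng.integers(0, len(labels), size=n)


def sample_uniform(labels, n, rng):
    """Draw a class uniformly first, then a random sample from that class."""
    classes = np.unique(labels)
    chosen = rng.choice(classes, size=n)
    return np.array([rng.choice(np.flatnonzero(labels == c)) for c in chosen])


empirical_labels = labels[sample_empirical(labels, 2000, rng)]
uniform_labels = labels[sample_uniform(labels, 2000, rng)]

# Fraction of class 1 in each batch of sampled labels.
print("empirical:", empirical_labels.mean())
print("uniform:  ", uniform_labels.mean())
```

With the empirical variant the rare class stays rare in the sampled batches, whereas the uniform variant over-samples it until both classes appear about equally often.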

I think this is also intended, so as not to overload the cebra.Cebra initialization or the model.fit() function with too many parameters?

Therefore I thought that maybe a minimal example could be added to usage.rst, showing how a data loader with parameters that are not exposed by the scikit-learn API can be used with PyTorch directly:

import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch

neural_data = cebra.load_data(file="neural_data.npz", key="neural")
# continuous_label = cebra.load_data(
#    file="auxiliary_behavior_data.h5",
#    key="auxiliary_variables",
#    columns=["continuous1", "continuous2", "continuous3"],
# )

discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define Cebra Dataset
InputData = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    # continuous=torch.from_numpy(np.array(continuous_label)).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to("cpu")

# 2. Define Cebra Model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=InputData.input_dimension,
    num_units=32,
    num_output=2,
).to("cpu")

InputData.configure_for(neural_model)

# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
    # temperature=0.001,
    # min_temperature=0.0001
).to("cpu")

Opt = torch.optim.Adam(
    list(neural_model.parameters()) + list(Crit.parameters()),
    # lr=0.001,
    weight_decay=0,
)

# 4. Initialize Cebra Model
cebra_model = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=Crit,
    optimizer=Opt,
    tqdm_on=True,
).to("cpu")

# 5. Define Data Loader
# loader = cebra.data.single_session.ContinuousDataLoader(
#    dataset=InputData, num_steps=1000, batch_size=200
# )
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=InputData, num_steps=1000, batch_size=200, prior="uniform"
)

# 6. Fit model
cebra_model.fit(loader=loader)

# 7. Transform embedding
TrainBatches = np.lib.stride_tricks.sliding_window_view(
    neural_data, len(neural_model.get_offset()), axis=0
)
X_train_emb = cebra_model.transform(
    torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to("cpu")
).to("cpu")

# 8. Potentially plot embedding
plot_embedding(
    X_train_emb,
    discrete_label[len(neural_model.get_offset()) - 1 :, 0],
    markersize=10,
)
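
Steps 7 and 8 rely on the windowed input and the sliced labels having the same length. The following NumPy-only sketch checks that alignment on toy data. The array sizes and the value of `window` are hypothetical stand-ins for `neural_data` and `len(neural_model.get_offset())`, and no CEBRA calls are involved.

```python
import numpy as np

# Hypothetical toy data: 8 time steps, 3 neurons.
neural = np.arange(8 * 3, dtype=np.float32).reshape(8, 3)
labels = np.arange(8)
window = 4  # stand-in for len(neural_model.get_offset())

# Slide a window along the time axis. NumPy appends the window axis last,
# so the result has shape (time - window + 1, neurons, window).
batches = np.lib.stride_tricks.sliding_window_view(neural, window, axis=0)

# Dropping the first (window - 1) labels pairs each window with the label
# of its last time step, as in the slicing used for plot_embedding above.
aligned_labels = labels[window - 1:]

print(batches.shape)
print(aligned_labels.shape)
```

The result is laid out as (batch, neurons, time). Whether pairing each window with its last time step is the right alignment depends on how the model's offset is defined, so it is worth checking against the model that is actually used.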

Labels

documentation (Improvements or additions to documentation), enhancement (New feature or request)
