-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
51 lines (37 loc) · 1.3 KB
/
main.py
File metadata and controls
51 lines (37 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Entry-point script: prepares the environment and plotting via run_setup,
# then launches the weak-supervision labeling experiment with the settings
# defined in the __main__ section below.
from run_setup import setup_env, setup_plot

# These calls deliberately run BEFORE the experiment module is imported.
# Presumably setup_env()/setup_plot() configure environment variables and/or
# plotting backends that the experiment's own imports read at import time
# (TODO confirm in run_setup) -- do not move the import below to the top.
setup_env()
setup_plot()

from weak_supervision_labeling.experiment import run  # noqa: E402

if __name__ == "__main__":
    # --- Experiment settings (every value passed to run() is named here) ---

    # model seed
    seed = 0
    # number of training epochs
    n_epochs = 450
    # dataset
    dataset = "emnist"  # available: "mnist", "emnist"
    # fraction of labeled data (between 0 and 1)
    label_map_frac_eval = 0.002
    # forwarded as run(label_map_stratified=...); presumably controls whether
    # the labeled subset is sampled stratified by class -- confirm in run()
    label_map_stratified = False
    # labeled-data fractions forwarded as run(label_map_fracs=...); presumably
    # the sweep used when comparing against the supervised baselines
    label_map_fracs = [5e-5, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 5e-1]

    # --- Output / visualization settings ---

    # whether to display titles in plots
    titles_plot = True
    # whether to skip the UMAP representation of the latent space
    skip_umap = True
    # whether to skip the t-SNE representation of the latent space
    skip_tsne = True
    # forwarded as run(save_gmvae_model=...); presumably persists the trained
    # GMVAE model -- confirm in run()
    save_gmvae_model = True

    # --- Baseline comparison ---

    # supervised baselines to compare
    supervised_baselines = ["logreg", "mlp", "xgboost"]  # available: "logreg", "mlp", "xgboost"
    # number of seeds to compare baselines to the implemented method
    label_map_n_seeds = 5

    run(
        seed=seed,
        n_epochs=n_epochs,
        dataset=dataset,
        titles_plot=titles_plot,
        skip_umap=skip_umap,
        skip_tsne=skip_tsne,
        label_map_frac_eval=label_map_frac_eval,
        label_map_stratified=label_map_stratified,
        label_map_fracs=label_map_fracs,
        supervised_baselines=supervised_baselines,
        label_map_n_seeds=label_map_n_seeds,
        save_gmvae_model=save_gmvae_model,
    )