diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 8ef3cb7c..95b99d51 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -191,7 +191,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: def compute_window(self, s: torch.Tensor) -> torch.Tensor: # language=rst - """""" + """ """ if self.s_w == None: # Construct a matrix of shape batch size * window size * dimension of layer diff --git a/examples/mnist/SOM_LM-SNNs.py b/examples/mnist/SOM_LM-SNNs.py index e7726e0f..69b1be77 100644 --- a/examples/mnist/SOM_LM-SNNs.py +++ b/examples/mnist/SOM_LM-SNNs.py @@ -341,7 +341,7 @@ pbar = tqdm(total=n_test) for step, batch in enumerate(test_dataset): - if step > n_test: + if step >= n_test: break # Get next input sample. inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)} diff --git a/examples/mnist/eth_mnist.py b/examples/mnist/eth_mnist.py index 30fd3411..c96aa90a 100644 --- a/examples/mnist/eth_mnist.py +++ b/examples/mnist/eth_mnist.py @@ -14,11 +14,7 @@ from bindsnet.models import DiehlAndCook2015 from bindsnet.network.monitors import Monitor from bindsnet.utils import get_square_weights, get_square_assignments -from bindsnet.evaluation import ( - all_activity, - proportion_weighting, - assign_labels, -) +from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels from bindsnet.analysis.plotting import ( plot_input, plot_spikes, @@ -168,8 +164,8 @@ # Train the network. print("\nBegin training.\n") start = t() -labels = [] for epoch in range(n_epochs): + labels = [] if epoch % progress_interval == 0: print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) @@ -194,9 +190,7 @@ # Get network predictions. all_activity_pred = all_activity( - spikes=spike_record, - assignments=assignments, - n_labels=n_classes, + spikes=spike_record, assignments=assignments, n_labels=n_classes ) proportion_pred = proportion_weighting( spikes=spike_record, @@ -312,7 +306,7 @@ pbar = tqdm(total=n_test) for step, batch in enumerate(test_dataset): - if step > n_test: + if step >= n_test: break # Get next input sample. inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)}