From fecc12eed17c1cbd56d22c81d007bb79687f4d61 Mon Sep 17 00:00:00 2001 From: Christopher Earl <40307516+Cr0uton@users.noreply.github.com> Date: Thu, 10 Jun 2021 18:38:52 -0400 Subject: [PATCH] Update reservoir.py --- examples/mnist/reservoir.py | 38 ++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/examples/mnist/reservoir.py b/examples/mnist/reservoir.py index 4c5ef80b..1370c0dd 100644 --- a/examples/mnist/reservoir.py +++ b/examples/mnist/reservoir.py @@ -73,6 +73,7 @@ torch.set_num_threads(os.cpu_count() - 1) print("Running on Device = ", device) +# Create simple Torch NN network = Network(dt=dt) inpt = Input(784, shape=(1, 28, 28)) network.add_layer(inpt, name="I") @@ -84,6 +85,7 @@ network.add_connection(C1, source="I", target="O") network.add_connection(C2, source="O", target="O") +# Monitors for visualizing activity spikes = {} for l in network.layers: spikes[l] = Monitor(network.layers[l], ["s"], time=time, device=device) @@ -101,7 +103,7 @@ dataset = MNIST( PoissonEncoder(time=time, dt=dt), None, - root=os.path.join("..", "..", "data", "MNIST"), + root=os.path.join("..", "data", "MNIST"), download=True, transform=transforms.Compose( [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)] @@ -123,19 +125,26 @@ ) # Run training data on reservoir computer and store (spikes per neuron, label) per example. +# Note: Because this is a reservoir network, no adjustment of neuron parameters occurs in this phase. 
n_iters = examples training_pairs = [] pbar = tqdm(enumerate(dataloader)) for (i, dataPoint) in pbar: if i > n_iters: break + + # Extract & resize the MNIST sample's image data for training + # int(time / dt) -> length of spike train + # 28 x 28 -> size of sample datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device) label = dataPoint["label"] pbar.set_description_str("Train progress: (%d / %d)" % (i, n_iters)) + # Run network on sample image network.run(inputs={"I": datum}, time=time, input_time_dim=1) training_pairs.append([spikes["O"].get("s").sum(0), label]) + # Plot spiking activity using monitors if plot: inpt_axes, inpt_ims = plot_input( @@ -165,6 +174,7 @@ # Define logistic regression model using PyTorch. +# These neurons will take the reservoir's output as their input, and be trained to classify the images. class NN(nn.Module): def __init__(self, input_size, num_classes): super(NN, self).__init__() @@ -189,14 +199,26 @@ def forward(self, x): pbar = tqdm(enumerate(range(n_epochs))) for epoch, _ in pbar: avg_loss = 0 + + # Extract spike outputs from reservoir for a training sample + # i -> Loop index + # s -> Reservoir output spikes + # l -> Image label for i, (s, l) in enumerate(training_pairs): - # Forward + Backward + Optimize + + # Reset gradients to 0 optimizer.zero_grad() + + # Run spikes through logistic regression model outputs = model(s) + + # Calculate MSE label = torch.zeros(1, 1, 10).float().to(device) label[0, 0, l] = 1.0 loss = criterion(outputs.view(1, 1, -1), label) avg_loss += loss.data + + # Optimize parameters loss.backward() optimizer.step() @@ -205,17 +227,19 @@ def forward(self, x): % (epoch + 1, n_epochs, avg_loss / len(training_pairs)) ) +# Run same simulation on reservoir with testing data instead of training data +# (see training section for intuition) n_iters = examples test_pairs = [] pbar = tqdm(enumerate(dataloader)) for (i, dataPoint) in pbar: if i > n_iters: break - datum = 
dataPoint["encoded_image"].view(time, 1, 1, 28, 28).to(device) + datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device) label = dataPoint["label"] pbar.set_description_str("Testing progress: (%d / %d)" % (i, n_iters)) - network.run(inputs={"I": datum}, time=250, input_time_dim=1) + network.run(inputs={"I": datum}, time=time, input_time_dim=1) test_pairs.append([spikes["O"].get("s").sum(0), label]) if plot: @@ -227,12 +251,12 @@ def forward(self, x): ims=inpt_ims, ) spike_ims, spike_axes = plot_spikes( - {layer: spikes[layer].get("s").view(-1, 250) for layer in spikes}, + {layer: spikes[layer].get("s").view(time, -1) for layer in spikes}, axes=spike_axes, ims=spike_ims, ) voltage_ims, voltage_axes = plot_voltages( - {layer: voltages[layer].get("v").view(-1, 250) for layer in voltages}, + {layer: voltages[layer].get("v").view(time, -1) for layer in voltages}, ims=voltage_ims, axes=voltage_axes, ) @@ -244,7 +268,7 @@ def forward(self, x): plt.pause(1e-8) network.reset_state_variables() -# Test the Model +# Test model with previously trained logistic regression classifier correct, total = 0, 0 for s, label in test_pairs: outputs = model(s)