diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py index ea572bc8..c5b19a5d 100644 --- a/bindsnet/network/monitors.py +++ b/bindsnet/network/monitors.py @@ -28,15 +28,16 @@ def __init__( state_vars: Iterable[str], time: Optional[int] = None, batch_size: int = 1, + device: str = "cpu", ): # language=rst """ Constructs a ``Monitor`` object. :param obj: An object to record state variables from during network simulation. - :param state_vars: Iterable of strings indicating names of state variables to - record. + :param state_vars: Iterable of strings indicating names of state variables to record. :param time: If not ``None``, pre-allocate memory for state variable recording. + :param device: Allow the monitor to be on a different device, separate from the Network device """ super().__init__() @@ -44,9 +45,14 @@ def __init__( self.state_vars = state_vars self.time = time self.batch_size = batch_size + self.device = device + + # if time is not specified, the monitor variable accumulates the logs + if self.time is None: + self.device = "cpu" - # Deal with time later, the same underlying list is used - self.recording = {v: [] for v in self.state_vars} + self.recording = [] + self.reset_state_variables() def get(self, var: str) -> torch.Tensor: # language=rst @@ -54,10 +60,15 @@ def get(self, var: str) -> torch.Tensor: Return recording to user. :param var: State variable recording to return. - :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is - the shape of the recorded state variable. + :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is the shape of the recorded state + variable. 
+ Note, if time == `None`, get returns the logs and empties the monitor variable + """ - return torch.cat(self.recording[var], 0) + return_logs = torch.cat(self.recording[var], 0) + if self.time is None: + self.recording[var] = [] + return return_logs def record(self) -> None: # language=rst @@ -66,20 +77,27 @@ def record(self) -> None: """ for v in self.state_vars: data = getattr(self.obj, v).unsqueeze(0) - self.recording[v].append(data.detach().clone()) - - # remove the oldest element (first in the list) - if self.time is not None: - for v in self.state_vars: - if len(self.recording[v]) > self.time: - self.recording[v].pop(0) + # self.recording[v].append(data.detach().clone().to(self.device)) + self.recording[v].append( + torch.empty_like(data, device=self.device, requires_grad=False).copy_( + data, non_blocking=True + ) + ) + # remove the oldest element (first in the list) + if self.time is not None: + self.recording[v].pop(0) def reset_state_variables(self) -> None: # language=rst """ - Resets recordings to empty ``torch.Tensor``s. + Resets recordings to empty ``List``s. """ - self.recording = {v: [] for v in self.state_vars} + if self.time is None: + self.recording = {v: [] for v in self.state_vars} + else: + self.recording = { + v: [[] for i in range(self.time)] for v in self.state_vars + } class NetworkMonitor(AbstractMonitor): diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 4b45b373..7c721383 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -420,7 +420,7 @@ class LIFNodes(Nodes): # language=rst """ Layer of `leaky integrate-and-fire (LIF) neurons - `_. + `_. """ def __init__( @@ -683,7 +683,7 @@ class CurrentLIFNodes(Nodes): # language=rst """ Layer of `current-based leaky integrate-and-fire (LIF) neurons - `_. + `_. Total synaptic input current is modeled as a decaying memory of input spikes multiplied by synaptic strengths. 
""" @@ -1148,7 +1148,7 @@ def set_batch_size(self, batch_size) -> None: class IzhikevichNodes(Nodes): # language=rst """ - Layer of Izhikevich neurons. + Layer of `Izhikevich neurons`_. """ def __init__( diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 387ed90d..17eea96a 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -160,7 +160,7 @@ def __init__( w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin) else: if self.wmin != -np.inf or self.wmax != np.inf: - w = torch.clamp(w, self.wmin, self.wmax) + w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax) self.w = Parameter(w, requires_grad=False) @@ -381,10 +381,10 @@ def normalize(self) -> None: if self.norm is not None: # get a view and modify in place w = self.w.view( - self.w.size(0) * self.w.size(1), self.w.size(2) * self.w.size(3) + self.w.shape[0] * self.w.shape[1], self.w.shape[2] * self.w.shape[3] ) - for fltr in range(w.size(0)): + for fltr in range(w.shape[0]): w[fltr] *= self.norm / w[fltr].sum(0) def reset_state_variables(self) -> None: diff --git a/examples/mnist/plots/assaiments/assaiments.png b/examples/mnist/plots/assaiments/assaiments.png new file mode 100644 index 00000000..488d785c Binary files /dev/null and b/examples/mnist/plots/assaiments/assaiments.png differ diff --git a/examples/mnist/plots/performance/performance.png b/examples/mnist/plots/performance/performance.png new file mode 100644 index 00000000..ca5f897b Binary files /dev/null and b/examples/mnist/plots/performance/performance.png differ diff --git a/examples/mnist/plots/weights/weights.1.png b/examples/mnist/plots/weights/weights.1.png new file mode 100644 index 00000000..1f9495b1 Binary files /dev/null and b/examples/mnist/plots/weights/weights.1.png differ