diff --git a/bindsnet/encoding/encoders.py b/bindsnet/encoding/encoders.py index e17a4f22..111e939f 100644 --- a/bindsnet/encoding/encoders.py +++ b/bindsnet/encoding/encoders.py @@ -86,7 +86,7 @@ def __init__(self, time: int, dt: float = 1.0, **kwargs): class PoissonEncoder(Encoder): - def __init__(self, time: int, dt: float = 1.0, **kwargs): + def __init__(self, time: int, dt: float = 1.0, approx: bool = False, **kwargs): # language=rst """ Creates a callable PoissonEncoder which encodes as defined in @@ -94,8 +94,10 @@ def __init__(self, time: int, dt: float = 1.0, **kwargs): :param time: Length of Poisson spike train per input variable. :param dt: Simulation time step. + :param approx: Bool: use alternate faster, less accurate computation. + """ - super().__init__(time, dt=dt, **kwargs) + super().__init__(time, dt=dt, approx=approx, **kwargs) self.enc = encodings.poisson diff --git a/bindsnet/encoding/encodings.py b/bindsnet/encoding/encodings.py index 18809430..63ae10bf 100644 --- a/bindsnet/encoding/encodings.py +++ b/bindsnet/encoding/encodings.py @@ -98,7 +98,12 @@ def bernoulli( def poisson( - datum: torch.Tensor, time: int, dt: float = 1.0, device="cpu", **kwargs + datum: torch.Tensor, + time: int, + dt: float = 1.0, + device="cpu", + approx=False, + **kwargs ) -> torch.Tensor: # language=rst """ @@ -110,6 +115,8 @@ def poisson( :param datum: Tensor of shape ``[n_1, ..., n_k]``. :param time: Length of Poisson spike train per input variable. :param dt: Simulation time step. + :param device: target destination of poisson spikes. + :param approx: Bool: use alternate faster, less accurate computation. :return: Tensor of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes. """ assert (datum >= 0).all(), "Inputs must be non-negative" @@ -119,27 +126,35 @@ def poisson( datum = datum.flatten() time = int(time / dt) - # Compute firing rates in seconds as function of data intensity, - # accounting for simulation time step. - rate = torch.zeros(size, device=device) - rate[datum != 0] = 1 / datum[datum != 0] * (1000 / dt) + if approx: + # random normal power awful approximation + x = torch.randn((time, size), device=device).abs() + x = torch.pow(x, (datum * 0.11 + 5) / 50) + y = torch.tensor(x < 0.6, dtype=torch.bool, device=device) - # Create Poisson distribution and sample inter-spike intervals - # (incrementing by 1 to avoid zero intervals). - dist = torch.distributions.Poisson(rate=rate) - intervals = dist.sample(sample_shape=torch.Size([time + 1])) - intervals[:, datum != 0] += (intervals[:, datum != 0] == 0).float() - - # Calculate spike times by cumulatively summing over time dimension. - times = torch.cumsum(intervals, dim=0).long() - times[times >= time + 1] = 0 - - # Create tensor of spikes. - spikes = torch.zeros(time + 1, size, device=device).byte() - spikes[times, torch.arange(size)] = 1 - spikes = spikes[1:] - - return spikes.view(time, *shape) + return y.view(time, *shape).byte() + else: + # Compute firing rates in seconds as function of data intensity, + # accounting for simulation time step. + rate = torch.zeros(size, device=device) + rate[datum != 0] = 1 / datum[datum != 0] * (1000 / dt) + + # Create Poisson distribution and sample inter-spike intervals + # (incrementing by 1 to avoid zero intervals). + dist = torch.distributions.Poisson(rate=rate) + intervals = dist.sample(sample_shape=torch.Size([time + 1])) + intervals[:, datum != 0] += (intervals[:, datum != 0] == 0).float() + + # Calculate spike times by cumulatively summing over time dimension. + times = torch.cumsum(intervals, dim=0).long() + times[times >= time + 1] = 0 + + # Create tensor of spikes. + spikes = torch.zeros(time + 1, size, device=device).byte() + spikes[times, torch.arange(size)] = 1 + spikes = spikes[1:] + + return spikes.view(time, *shape) def rank_order( diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 62ffb3b5..a204639b 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -52,7 +52,9 @@ def __init__( elif isinstance(nu, float) or isinstance(nu, int): nu = [nu, nu] - self.nu = nu + self.nu = torch.zeros(2, dtype=torch.float) + self.nu[0] = nu[0] + self.nu[1] = nu[1] # Parameter update reduction across minibatch dimension. if reduction is None: @@ -64,7 +66,7 @@ def __init__( self.reduction = reduction # Weight decay. - self.weight_decay = weight_decay + self.weight_decay = 1.0 - weight_decay def update(self) -> None: # language=rst @@ -73,7 +75,7 @@ def update(self) -> None: """ # Implement weight decay. if self.weight_decay: - self.connection.w -= self.weight_decay * self.connection.w + self.connection.w *= self.weight_decay # Bound weights. if ( @@ -177,20 +179,21 @@ def _connection_update(self, **kwargs) -> None: """ batch_size = self.source.batch_size - source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float() - source_x = self.source.x.view(batch_size, -1).unsqueeze(2) - target_s = self.target.s.view(batch_size, -1).unsqueeze(1).float() - target_x = self.target.x.view(batch_size, -1).unsqueeze(1) - # Pre-synaptic update. if self.nu[0]: - update = self.reduction(torch.bmm(source_s, target_x), dim=0) - self.connection.w -= self.nu[0] * update + source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float() + target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0] + self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0) + del source_s, target_x # Post-synaptic update. if self.nu[1]: - update = self.reduction(torch.bmm(source_x, target_s), dim=0) - self.connection.w += self.nu[1] * update + target_s = ( + self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1] + ) + source_x = self.source.x.view(batch_size, -1).unsqueeze(2) + self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0) + del source_x, target_s super().update() diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 5a413912..6aa65aa3 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -101,7 +101,7 @@ def forward(self, x: torch.Tensor) -> None: if self.traces_additive: self.x += self.trace_scale * self.s.float() else: - self.x.masked_fill_(self.s != 0, 1) + self.x.masked_fill_(self.s, 1) if self.sum_input: # Add current input to running sum. @@ -125,7 +125,7 @@ def compute_decays(self, dt) -> None: """ Abstract base class method for setting decays. """ - self.dt = dt + self.dt = torch.tensor(dt) if self.traces: self.trace_decay = torch.exp( -self.dt / self.tc_trace @@ -560,6 +560,125 @@ def set_batch_size(self, batch_size) -> None: self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) +class BoostedLIFNodes(Nodes): + # Same as LIFNodes, faster: no rest, no reset, no lbound + def __init__( + self, + n: Optional[int] = None, + shape: Optional[Iterable[int]] = None, + traces: bool = False, + traces_additive: bool = False, + tc_trace: Union[float, torch.Tensor] = 20.0, + trace_scale: Union[float, torch.Tensor] = 1.0, + sum_input: bool = False, + thresh: Union[float, torch.Tensor] = 13.0, + refrac: Union[int, torch.Tensor] = 5, + tc_decay: Union[float, torch.Tensor] = 100.0, + **kwargs, + ) -> None: + # language=rst + """ + Instantiates a layer of LIF neurons. + + :param n: The number of neurons in the layer. + :param shape: The dimensionality of the layer. + :param traces: Whether to record spike traces. + :param traces_additive: Whether to record spike traces additively. + :param tc_trace: Time constant of spike trace decay. + :param trace_scale: Scaling factor for spike trace. + :param sum_input: Whether to sum all inputs. + :param thresh: Spike threshold voltage. + :param reset: Post-spike reset voltage. + :param refrac: Refractory (non-firing) period of the neuron. + :param tc_decay: Time constant of neuron voltage decay. + """ + super().__init__( + n=n, + shape=shape, + traces=traces, + traces_additive=traces_additive, + tc_trace=tc_trace, + trace_scale=trace_scale, + sum_input=sum_input, + ) + + self.register_buffer( + "thresh", torch.tensor(thresh, dtype=torch.float) + ) # Spike threshold voltage. + self.register_buffer( + "refrac", torch.tensor(refrac) + ) # Post-spike refractory period. + self.register_buffer( + "tc_decay", torch.tensor(tc_decay, dtype=torch.float) + ) # Time constant of neuron voltage decay. + self.register_buffer( + "decay", torch.zeros(*self.shape) + ) # Set in compute_decays. + self.register_buffer("v", torch.FloatTensor()) # Neuron voltages. + self.register_buffer( + "refrac_count", torch.tensor(0) + ) # Refractory period counters. + + def forward(self, x: torch.Tensor) -> None: + # language=rst + """ + Runs a single simulation step. + + :param x: Inputs to the layer. + """ + # Decay voltages. + self.v *= self.decay + + # Integrate inputs. + if x is not None: + x.masked_fill_(self.refrac_count > 0, 0.0) + + # Decrement refractory counters. + self.refrac_count -= self.dt + + if x is not None: + self.v += x + + # Check for spiking neurons. + self.s = self.v >= self.thresh + + # Refractoriness and voltage reset. + self.refrac_count.masked_fill_(self.s, self.refrac) + self.v.masked_fill_(self.s, 0) + + super().forward(x) + + def reset_state_variables(self) -> None: + # language=rst + """ + Resets relevant state variables. + """ + super().reset_state_variables() + self.v.fill_(0) # Neuron voltages. + self.refrac_count.zero_() # Refractory period counters. + + def compute_decays(self, dt) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super().compute_decays(dt=dt) + self.decay = torch.exp( + -self.dt / self.tc_decay + ) # Neuron voltage decay (per timestep). + + def set_batch_size(self, batch_size) -> None: + # language=rst + """ + Sets mini-batch size. Called when layer is added to a network. + + :param batch_size: Mini-batch size. + """ + super().set_batch_size(batch_size=batch_size) + self.v = torch.zeros(batch_size, *self.shape, device=self.v.device) + self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device) + + class CurrentLIFNodes(Nodes): # language=rst """ diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 2432011f..478a8280 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -52,7 +52,7 @@ def __init__( self.source = source self.target = target - self.nu = nu + # self.nu = nu self.weight_decay = weight_decay self.reduction = reduction @@ -163,7 +163,12 @@ def __init__( w = torch.clamp(w, self.wmin, self.wmax) self.w = Parameter(w, requires_grad=False) - self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), requires_grad=False) + + b = kwargs.get("b", None) + if b is not None: + self.b = Parameter(b, requires_grad=False) + else: + self.b = None def compute(self, s: torch.Tensor) -> torch.Tensor: # language=rst @@ -175,7 +180,10 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: decaying spike activation). """ # Compute multiplication of spike activations by weights and add bias. - post = s.float().view(s.size(0), -1) @ self.w + self.b + if self.b is None: + post = s.view(s.size(0), -1).float() @ self.w + else: + post = s.view(s.size(0), -1).float() @ self.w + self.b return post.view(s.size(0), *self.target.shape) def update(self, **kwargs) -> None: