Skip to content
6 changes: 4 additions & 2 deletions bindsnet/encoding/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,18 @@ def __init__(self, time: int, dt: float = 1.0, **kwargs):


class PoissonEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, approx: bool = False, **kwargs):
        # language=rst
        """
        Creates a callable PoissonEncoder which encodes as defined in
        ``bindsnet.encoding.poisson``.

        :param time: Length of Poisson spike train per input variable.
        :param dt: Simulation time step.
        :param approx: Bool: use alternate faster, less accurate computation.
        """
        # Forward ``approx`` through the base class so it reaches the encoding
        # function's keyword arguments.
        super().__init__(time, dt=dt, approx=approx, **kwargs)

        self.enc = encodings.poisson
Expand Down
57 changes: 36 additions & 21 deletions bindsnet/encoding/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ def bernoulli(


def poisson(
    datum: torch.Tensor,
    time: int,
    dt: float = 1.0,
    device="cpu",
    approx=False,
    **kwargs
) -> torch.Tensor:
    # language=rst
    """
    Generates Poisson-distributed spike trains based on input intensity. Inputs must
    be non-negative; values are interpreted as firing rates in Hz. Inter-spike
    intervals (ISIs) for non-zero data are incremented by one to avoid zero intervals
    while maintaining ISI distributions.

    :param datum: Tensor of shape ``[n_1, ..., n_k]``.
    :param time: Length of Poisson spike train per input variable.
    :param dt: Simulation time step.
    :param device: target destination of poisson spikes.
    :param approx: Bool: use alternate faster, less accurate computation.
    :return: Tensor of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes.
    """
    assert (datum >= 0).all(), "Inputs must be non-negative"

    # Get shape and size of data, then flatten for per-variable processing.
    shape, size = datum.shape, datum.numel()
    datum = datum.flatten()
    time = int(time / dt)

    if approx:
        # Crude approximation: raise |N(0, 1)| samples to a data-dependent power and
        # threshold. Much faster than true Poisson sampling, statistically inaccurate.
        x = torch.randn((time, size), device=device).abs()
        x = torch.pow(x, (datum * 0.11 + 5) / 50)
        # The comparison already yields a bool tensor on the right device; wrapping
        # it in torch.tensor(...) would copy-construct and raise a UserWarning.
        y = x < 0.6

        return y.view(time, *shape).byte()
    else:
        # Mean inter-spike interval (in simulation steps) as a function of data
        # intensity: datum in Hz -> period of 1000/datum ms -> (1000/dt)/datum steps.
        # NOTE(review): this mean-ISI tensor is passed as the Poisson ``rate``,
        # matching the original implementation.
        rate = torch.zeros(size, device=device)
        rate[datum != 0] = 1 / datum[datum != 0] * (1000 / dt)

        # Create Poisson distribution and sample inter-spike intervals
        # (incrementing by 1 to avoid zero intervals).
        dist = torch.distributions.Poisson(rate=rate)
        intervals = dist.sample(sample_shape=torch.Size([time + 1]))
        intervals[:, datum != 0] += (intervals[:, datum != 0] == 0).float()

        # Calculate spike times by cumulatively summing over time dimension.
        times = torch.cumsum(intervals, dim=0).long()
        times[times >= time + 1] = 0

        # Create tensor of spikes; row 0 collects the spurious t=0 spikes and is
        # dropped before returning.
        spikes = torch.zeros(time + 1, size, device=device).byte()
        spikes[times, torch.arange(size)] = 1
        spikes = spikes[1:]

        return spikes.view(time, *shape)


def rank_order(
Expand Down
27 changes: 15 additions & 12 deletions bindsnet/learning/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(
elif isinstance(nu, float) or isinstance(nu, int):
nu = [nu, nu]

self.nu = nu
self.nu = torch.zeros(2, dtype=torch.float)
self.nu[0] = nu[0]
self.nu[1] = nu[1]

# Parameter update reduction across minibatch dimension.
if reduction is None:
Expand All @@ -64,7 +66,7 @@ def __init__(
self.reduction = reduction

# Weight decay.
self.weight_decay = weight_decay
self.weight_decay = 1.0 - weight_decay

def update(self) -> None:
# language=rst
Expand All @@ -73,7 +75,7 @@ def update(self) -> None:
"""
# Implement weight decay.
if self.weight_decay:
self.connection.w -= self.weight_decay * self.connection.w
self.connection.w *= self.weight_decay

# Bound weights.
if (
def _connection_update(self, **kwargs) -> None:
    # language=rst
    """
    Post-Pre learning rule update for ``Connection`` objects: weights are decreased
    where pre-synaptic spikes coincide with the post-synaptic trace, and increased
    where post-synaptic spikes coincide with the pre-synaptic trace.
    """
    batch_size = self.source.batch_size

    # Pre-synaptic update: depress by nu[0] * (pre spikes x post trace). The learning
    # rate is folded into the trace tensor before the batched outer product to avoid
    # an extra full-size temporary.
    if self.nu[0]:
        source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
        target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
        self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
        # Release temporaries promptly to reduce peak memory.
        del source_s, target_x

    # Post-synaptic update: potentiate by nu[1] * (pre trace x post spikes).
    if self.nu[1]:
        target_s = (
            self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
        )
        source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
        self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
        del source_x, target_s

    super().update()

Expand Down
123 changes: 121 additions & 2 deletions bindsnet/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(self, x: torch.Tensor) -> None:
if self.traces_additive:
self.x += self.trace_scale * self.s.float()
else:
self.x.masked_fill_(self.s != 0, 1)
self.x.masked_fill_(self.s, 1)

if self.sum_input:
# Add current input to running sum.
Expand All @@ -125,7 +125,7 @@ def compute_decays(self, dt) -> None:
"""
Abstract base class method for setting decays.
"""
self.dt = dt
self.dt = torch.tensor(dt)
if self.traces:
self.trace_decay = torch.exp(
-self.dt / self.tc_trace
Expand Down Expand Up @@ -560,6 +560,125 @@ def set_batch_size(self, batch_size) -> None:
self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)


class BoostedLIFNodes(Nodes):
    # language=rst
    """
    Faster variant of ``LIFNodes``: no rest potential, no reset voltage, and no lower
    bound on voltage. Voltage decays toward zero and is zeroed on spike.
    """

    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
        thresh: Union[float, torch.Tensor] = 13.0,
        refrac: Union[int, torch.Tensor] = 5,
        tc_decay: Union[float, torch.Tensor] = 100.0,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Instantiates a layer of boosted LIF neurons.

        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        :param thresh: Spike threshold voltage.
        :param refrac: Refractory (non-firing) period of the neuron.
        :param tc_decay: Time constant of neuron voltage decay.
        """
        super().__init__(
            n=n,
            shape=shape,
            traces=traces,
            traces_additive=traces_additive,
            tc_trace=tc_trace,
            trace_scale=trace_scale,
            sum_input=sum_input,
        )

        self.register_buffer(
            "thresh", torch.tensor(thresh, dtype=torch.float)
        )  # Spike threshold voltage.
        self.register_buffer(
            "refrac", torch.tensor(refrac)
        )  # Post-spike refractory period.
        self.register_buffer(
            "tc_decay", torch.tensor(tc_decay, dtype=torch.float)
        )  # Time constant of neuron voltage decay.
        self.register_buffer(
            "decay", torch.zeros(*self.shape)
        )  # Set in compute_decays.
        self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
        self.register_buffer(
            "refrac_count", torch.tensor(0)
        )  # Refractory period counters; replaced by a full tensor in set_batch_size.

    def forward(self, x: torch.Tensor) -> None:
        # language=rst
        """
        Runs a single simulation step.

        :param x: Inputs to the layer. NOTE: zeroed *in place* for refractory
            neurons — callers must not reuse ``x`` after this call.
        """
        # Decay voltages toward zero (no rest potential).
        self.v *= self.decay

        # Integrate inputs: refractory neurons receive no input.
        if x is not None:
            x.masked_fill_(self.refrac_count > 0, 0.0)

        # Decrement refractory counters.
        self.refrac_count -= self.dt

        if x is not None:
            self.v += x

        # Check for spiking neurons.
        self.s = self.v >= self.thresh

        # Refractoriness and voltage reset (to zero, not a configurable reset value).
        self.refrac_count.masked_fill_(self.s, self.refrac)
        self.v.masked_fill_(self.s, 0)

        super().forward(x)

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Resets relevant state variables.
        """
        super().reset_state_variables()
        self.v.fill_(0)  # Neuron voltages.
        self.refrac_count.zero_()  # Refractory period counters.

    def compute_decays(self, dt) -> None:
        # language=rst
        """
        Sets the relevant decays.
        """
        super().compute_decays(dt=dt)
        self.decay = torch.exp(
            -self.dt / self.tc_decay
        )  # Neuron voltage decay (per timestep).

    def set_batch_size(self, batch_size) -> None:
        # language=rst
        """
        Sets mini-batch size. Called when layer is added to a network.

        :param batch_size: Mini-batch size.
        """
        super().set_batch_size(batch_size=batch_size)
        self.v = torch.zeros(batch_size, *self.shape, device=self.v.device)
        self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)


class CurrentLIFNodes(Nodes):
# language=rst
"""
Expand Down
14 changes: 11 additions & 3 deletions bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
self.source = source
self.target = target

self.nu = nu
# self.nu = nu
self.weight_decay = weight_decay
self.reduction = reduction

Expand Down Expand Up @@ -163,7 +163,12 @@ def __init__(
w = torch.clamp(w, self.wmin, self.wmax)

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), requires_grad=False)

b = kwargs.get("b", None)
if b is not None:
self.b = Parameter(b, requires_grad=False)
else:
self.b = None

def compute(self, s: torch.Tensor) -> torch.Tensor:
    # language=rst
    """
    Compute pre-activations given incoming spikes, using connection weights.

    :param s: Incoming spikes of shape ``[batch, *source shape]``.
    :return: Incoming spikes multiplied by synaptic weights (with decaying spike
        activation), reshaped to ``[batch, *target.shape]``.
    """
    # Flatten spikes per batch element and multiply by weights; the bias is optional
    # (``self.b`` is None when no bias was supplied), so skip the add in that case.
    if self.b is None:
        post = s.view(s.size(0), -1).float() @ self.w
    else:
        post = s.view(s.size(0), -1).float() @ self.w + self.b
    return post.view(s.size(0), *self.target.shape)

def update(self, **kwargs) -> None:
Expand Down