diff --git a/.gitignore b/.gitignore index 4c3b8d2e..8859d773 100644 --- a/.gitignore +++ b/.gitignore @@ -122,6 +122,9 @@ examples/saved_checkpoints # PyCharm project folder. .idea/ +# VS Code workspace +.vscode/ + # macOS .DS_Store @@ -129,3 +132,6 @@ figures/ # Analyzer log default directory. logs/ + +# PyTorch tensorboard log default directory. +runs/ \ No newline at end of file diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 1dd79aa2..b1e65c30 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -27,7 +27,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -59,9 +59,9 @@ def __init__( if (self.nu == torch.zeros(2)).all() and not isinstance(self, NoOp): warnings.warn( - f"nu is set to [0., 0.] for {type(self).__name__} learning rule. " + - "It will disable the learning process." - ) + f"nu is set to [0., 0.] for {type(self).__name__} learning rule. " + + "It will disable the learning process." + ) # Parameter update reduction across minibatch dimension. if reduction is None: @@ -86,7 +86,8 @@ def update(self) -> None: # Bound weights. 
if ( - self.connection.wmin != -np.inf or self.connection.wmax != np.inf + (self.connection.wmin != -np.inf).any() + or (self.connection.wmax != np.inf).any() ) and not isinstance(self, NoOp): self.connection.w.clamp_(self.connection.wmin, self.connection.wmax) @@ -103,7 +104,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -120,7 +121,7 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) def update(self, **kwargs) -> None: @@ -144,7 +145,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -162,7 +163,7 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) assert ( @@ -260,7 +261,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -278,13 +279,13 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) assert self.source.traces, "Pre-synaptic nodes must record spike traces." - assert ( - connection.wmin != -np.inf and connection.wmax != np.inf - ), "Connection must define finite wmin and wmax." + assert (connection.wmin != -np.inf).any() and ( + connection.wmax != np.inf + ).any(), "Connection must define finite wmin and wmax." 
self.wmin = connection.wmin self.wmax = connection.wmax @@ -398,7 +399,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -416,7 +417,7 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) assert ( @@ -503,7 +504,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -527,7 +528,7 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) if isinstance(connection, (Connection, LocalConnection)): @@ -697,7 +698,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -722,7 +723,7 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) if isinstance(connection, (Connection, LocalConnection)): @@ -901,7 +902,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -926,7 +927,7 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) # Trace is needed for computing epsilon. 
diff --git a/bindsnet/network/network.py b/bindsnet/network/network.py index bb7ef866..0bf02a08 100644 --- a/bindsnet/network/network.py +++ b/bindsnet/network/network.py @@ -231,7 +231,7 @@ def _get_inputs(self, layers: Iterable = None) -> Dict[str, torch.Tensor]: self.batch_size, target.res_window_size, *target.shape, - device=target.s.device + device=target.s.device, ) else: inputs[c[1]] = torch.zeros( @@ -308,8 +308,9 @@ def run( plt.show() """ # Check input type - assert type(inputs) == dict, ("'inputs' must be a dict of names of layers " + - f"(str) and relevant input tensors. Got {type(inputs).__name__} instead." + assert type(inputs) == dict, ( + "'inputs' must be a dict of names of layers " + + f"(str) and relevant input tensors. Got {type(inputs).__name__} instead." ) # Parse keyword arguments. clamps = kwargs.get("clamp", {}) diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 0fe16fb7..d65f82d6 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -23,7 +23,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -40,8 +40,10 @@ def __init__( :param LearningRule update_rule: Modifies connection parameters according to some rule. - :param float wmin: The minimum value on the connection weights. - :param float wmax: The maximum value on the connection weights. + :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or + tensor of same size as w :param float norm: Total weight per target neuron normalization. 
""" super().__init__() @@ -59,8 +61,16 @@ def __init__( from ..learning import NoOp self.update_rule = kwargs.get("update_rule", NoOp) - self.wmin = kwargs.get("wmin", -np.inf) - self.wmax = kwargs.get("wmax", np.inf) + + # Float32 necessary for comparisons with +/-inf + self.wmin = Parameter( + torch.as_tensor(kwargs.get("wmin", -np.inf), dtype=torch.float32), + requires_grad=False, + ) + self.wmax = Parameter( + torch.as_tensor(kwargs.get("wmax", np.inf), dtype=torch.float32), + requires_grad=False, + ) self.norm = kwargs.get("norm", None) self.decay = kwargs.get("decay", None) @@ -72,7 +82,7 @@ def __init__( nu=nu, reduction=reduction, weight_decay=weight_decay, - **kwargs + **kwargs, ) @abstractmethod @@ -127,7 +137,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -146,20 +156,22 @@ def __init__( some rule. :param torch.Tensor w: Strengths of synapses. :param torch.Tensor b: Target population bias. - :param float wmin: Minimum allowed value on the connection weights. - :param float wmax: Maximum allowed value on the connection weights. + :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param Union[float, torch.Tensor] wmax: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w :param float norm: Total weight per target neuron normalization constant. 
""" super().__init__(source, target, nu, reduction, weight_decay, **kwargs) w = kwargs.get("w", None) if w is None: - if self.wmin == -np.inf or self.wmax == np.inf: + if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any(): w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax) else: w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin) else: - if self.wmin != -np.inf or self.wmax != np.inf: + if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any(): w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax) self.w = Parameter(w, requires_grad=False) @@ -260,7 +272,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -283,8 +295,10 @@ def __init__( some rule. :param torch.Tensor w: Strengths of synapses. :param torch.Tensor b: Target population bias. - :param float wmin: Minimum allowed value on the connection weights. - :param float wmax: Maximum allowed value on the connection weights. + :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or + tensor of same size as w :param float norm: Total weight per target neuron normalization constant. 
""" super().__init__(source, target, nu, reduction, weight_decay, **kwargs) @@ -326,8 +340,9 @@ def __init__( ), error w = kwargs.get("w", None) + inf = torch.tensor(np.inf) if w is None: - if self.wmin == -np.inf or self.wmax == np.inf: + if (self.wmin == -inf).any() or (self.wmax == inf).any(): w = torch.clamp( torch.rand(self.out_channels, self.in_channels, *self.kernel_size), self.wmin, @@ -339,7 +354,7 @@ def __init__( ) w += self.wmin else: - if self.wmin != -np.inf or self.wmax != np.inf: + if (self.wmin == -inf).any() or (self.wmax == inf).any(): w = torch.clamp(w, self.wmin, self.wmax) self.w = Parameter(w, requires_grad=False) @@ -410,7 +425,7 @@ def __init__( stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -500,7 +515,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -528,8 +543,10 @@ def __init__( some rule. :param torch.Tensor w: Strengths of synapses. :param torch.Tensor b: Target population bias. - :param float wmin: Minimum allowed value on the connection weights. - :param float wmax: Maximum allowed value on the connection weights. + :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or + tensor of same size as w :param float norm: Total weight per target neuron normalization constant. :param Tuple[int, int] input_shape: Shape of input population if it's not ``[sqrt, sqrt]``. 
@@ -561,9 +578,10 @@ def __init__( conv_prod = int(np.prod(conv_size)) kernel_prod = int(np.prod(kernel_size)) - assert ( - target.n == n_filters * conv_prod - ), "Target layer size must be n_filters * (kernel_size ** 2)." + assert target.n == n_filters * conv_prod, ( + f"Total neurons in target layer must be {n_filters * conv_prod}. " + f"Got {target.n}." + ) locations = torch.zeros( kernel_size[0], kernel_size[1], conv_size[0], conv_size[1] @@ -584,20 +602,21 @@ def __init__( w = kwargs.get("w", None) if w is None: + # Calculate unbounded weights w = torch.zeros(source.n, target.n) for f in range(n_filters): for c in range(conv_prod): for k in range(kernel_prod): - if self.wmin == -np.inf or self.wmax == np.inf: - w[self.locations[k, c], f * conv_prod + c] = np.clip( - np.random.rand(), self.wmin, self.wmax - ) - else: - w[ - self.locations[k, c], f * conv_prod + c - ] = self.wmin + np.random.rand() * (self.wmax - self.wmin) + w[self.locations[k, c], f * conv_prod + c] = np.random.rand() + + # Bind weights to given range + if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any(): + w = torch.clamp(w, self.wmin, self.wmax) + else: + w = self.wmin + w * (self.wmax - self.wmin) + else: - if self.wmin != -np.inf or self.wmax != np.inf: + if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any(): w = torch.clamp(w, self.wmin, self.wmax) self.w = Parameter(w, requires_grad=False) @@ -671,7 +690,7 @@ def __init__( target: Nodes, nu: Optional[Union[float, Sequence[float]]] = None, weight_decay: float = 0.0, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -683,21 +702,23 @@ def __init__( Keyword arguments: :param LearningRule update_rule: Modifies connection parameters according to some rule. - :param torch.Tensor w: Strengths of synapses. - :param float wmin: Minimum allowed value on the connection weights. - :param float wmax: Maximum allowed value on the connection weights. + :param Union[float, torch.Tensor] w: Strengths of synapses. 
Can be single value or tensor of size ``target`` + :param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or + tensor of same size as w + :param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or + tensor of same size as w :param float norm: Total weight per target neuron normalization constant. """ super().__init__(source, target, nu, weight_decay, **kwargs) w = kwargs.get("w", None) if w is None: - if self.wmin == -np.inf or self.wmax == np.inf: + if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any(): w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax) else: w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin) else: - if self.wmin != -np.inf or self.wmax != np.inf: + if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any(): w = torch.clamp(w, self.wmin, self.wmax) self.w = Parameter(w, requires_grad=False) @@ -752,7 +773,7 @@ def __init__( nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = None, - **kwargs + **kwargs, ) -> None: # language=rst """ @@ -767,7 +788,7 @@ Keyword arguments: - :param torch.Tensor w: Strengths of synapses. + :param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format :param float sparsity: Fraction of sparse connections to use. :param LearningRule update_rule: Modifies connection parameters according to some rule. 
@@ -791,16 +812,17 @@ def __init__( i = torch.bernoulli( 1 - self.sparsity * torch.ones(*source.shape, *target.shape) ) - if self.wmin == -np.inf or self.wmax == np.inf: + if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any(): v = torch.clamp( - torch.rand(*source.shape, *target.shape)[i.bool()], + torch.rand(*source.shape, *target.shape), self.wmin, self.wmax, - ) + )[i.bool()] else: - v = self.wmin + torch.rand(*source.shape, *target.shape)[i.bool()] * ( - self.wmax - self.wmin - ) + v = ( + self.wmin + + torch.rand(*source.shape, *target.shape) * (self.wmax - self.wmin) + )[i.bool()] w = torch.sparse.FloatTensor(i.nonzero().t(), v) elif w is not None and self.sparsity is None: assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)" @@ -818,7 +840,8 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: :return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation). """ - return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1) + return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1) + # return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1) def update(self, **kwargs) -> None: # language=rst diff --git a/test/network/test_connections.py b/test/network/test_connections.py index 6f76cb28..3d201435 100644 --- a/test/network/test_connections.py +++ b/test/network/test_connections.py @@ -1,15 +1,33 @@ import torch -from bindsnet.network.nodes import LIFNodes +from bindsnet.network import Network +from bindsnet.network.nodes import Input, LIFNodes, SRM0Nodes from bindsnet.network.topology import * +from bindsnet.learning import ( + Hebbian, + PostPre, + WeightDependentPostPre, + MSTDP, + MSTDPET, + Rmax, + NoOp, +) + class TestConnection: """ Tests all stable groups of neurons / nodes. 
""" + def __init__(self): + if torch.cuda.is_available(): + self.device = torch.device("cuda:0") + else: + self.device = torch.device("cpu:0") + print(f"Using device '{self.device}' for the test") + def test_transfer(self): if not torch.cuda.is_available(): return @@ -29,7 +47,7 @@ def test_transfer(self): l_b = LIFNodes(shape=[1, 26, 26]) connection = conn_type(l_a, l_b, *args, **kwargs) - connection.to(torch.device("cuda:0")) + connection.to() connection_tensors = [ k @@ -54,8 +72,136 @@ def test_transfer(self): print(d, d == torch.device("cuda:0")) assert d == torch.device("cuda:0") + def test_weights(self, conn_type, shape_a, shape_b, shape_w, *args, **kwargs): + print("Testing:", conn_type) + time = 100 + weights = [None, torch.Tensor(*shape_w)] + wmins = [ + -np.inf, + 0, + torch.zeros(*shape_w), + torch.zeros(*shape_w).masked_fill( + torch.bernoulli(torch.rand(*shape_w)) == 1, -np.inf + ), + ] + wmaxes = [ + np.inf, + 0, + torch.ones(*shape_w), + torch.randn(*shape_w).masked_fill( + torch.bernoulli(torch.rand(*shape_w)) == 1, np.inf + ), + ] + update_rule = kwargs.get("update_rule", None) + for w in weights: + for wmin in wmins: + for wmax in wmaxes: + + ### Conditional checks ### + # WeightDependentPostPre does not handle infinite ranges + if ( + (torch.tensor(wmin, dtype=torch.float32) == -np.inf).any() + or (torch.tensor(wmax, dtype=torch.float32) == np.inf).any() + ) and update_rule == WeightDependentPostPre: + continue + + # Rmax only supported for Connection & LocalConnection + elif ( + not (conn_type == Connection or conn_type == LocalConnection) + and update_rule == Rmax + ): + return + + print( + f"- w: {type(w).__name__}, " + f"wmin: {type(wmax).__name__}, wmax: {type(wmax).__name__}" + ) + if kwargs.get("update_rule") == Rmax: + l_a = SRM0Nodes( + shape=shape_a, traces=True, traces_additive=True + ) + l_b = SRM0Nodes( + shape=shape_b, traces=True, traces_additive=True + ) + else: + l_a = LIFNodes(shape=shape_a, traces=True, traces_additive=True) + 
l_b = LIFNodes(shape=shape_b, traces=True, traces_additive=True) + + ### Create network ### + network = Network(dt=1.0) + network.add_layer( + Input(n=100, traces=True, traces_additive=True), name="input" + ) + network.add_layer(l_a, name="a") + network.add_layer(l_b, name="b") + + network.add_connection( + conn_type(l_a, l_b, w=w, wmin=wmin, wmax=wmax, *args, **kwargs), + source="a", + target="b", + ) + network.add_connection( + Connection( + wmin=0, + wmax=1, + source=network.layers["input"], + target=network.layers["a"], + **kwargs, + ), + source="input", + target="a", + ) + + ### Run network ### + network.run( + inputs={"input": torch.bernoulli(torch.rand(time, 100)).byte()}, + time=time, + reward=1, + ) + if __name__ == "__main__": tester = TestConnection() - tester.test_transfer() + # tester.test_transfer() + + # Connections with learning ability + conn_types = [Connection, Conv2dConnection, LocalConnection] + args = [ + [[100], [50], (100, 50)], + [ + [1, 28, 28], + [1, 26, 26], + (1, 1, 3, 3), + 3, + ], + [ + [1, 28, 28], + [1, 26, 26], + (784, 676), + 3, + 1, + 1, + ], + ] + for update_rule in ( + Hebbian, + PostPre, + WeightDependentPostPre, + MSTDP, + MSTDPET, + Rmax, + ): + print("Learning Rule:", update_rule) + for conn_type, arg in zip(conn_types, args): + tester.test_weights(conn_type, nu=1e-2, update_rule=update_rule, *arg) + + # Other connections + # Note: Does not include MaxPool2dConnection because this connection + # does not utilize weights and wmin/wmax + conn_types = [MeanFieldConnection] + args = [ + [[1, 28, 28], [1, 26, 26], (1, 26), 3, 1], + ] + for conn_type, arg in zip(conn_types, args): + tester.test_weights(conn_type, decay=1, update_rule=NoOp, *arg)